Merge branch 'limiter' into 'main'

relay: stricter rate limits

See merge request soapbox-pub/ditto!625
This commit is contained in:
Alex Gleason 2025-01-25 21:37:33 +00:00
commit cd7c77420c
7 changed files with 266 additions and 16 deletions

View file

@ -1,6 +1,6 @@
import { Stickynotes } from '@soapbox/stickynotes';
import TTLCache from '@isaacs/ttlcache';
import {
NKinds,
NostrClientCLOSE,
NostrClientCOUNT,
NostrClientEVENT,
@ -19,14 +19,27 @@ import { RelayError } from '@/RelayError.ts';
import { Storages } from '@/storages.ts';
import { Time } from '@/utils/time.ts';
import { purifyEvent } from '@/utils/purify.ts';
import { MemoryRateLimiter } from '@/utils/ratelimiter/MemoryRateLimiter.ts';
import { MultiRateLimiter } from '@/utils/ratelimiter/MultiRateLimiter.ts';
import { RateLimiter } from '@/utils/ratelimiter/types.ts';
/** Limit of initial events returned for a subscription. */
const FILTER_LIMIT = 100;
const LIMITER_WINDOW = Time.minutes(1);
const LIMITER_LIMIT = 300;
const limiter = new TTLCache<string, number>();
const limiters = {
msg: new MemoryRateLimiter({ limit: 300, window: Time.minutes(1) }),
req: new MultiRateLimiter([
new MemoryRateLimiter({ limit: 15, window: Time.seconds(5) }),
new MemoryRateLimiter({ limit: 300, window: Time.minutes(5) }),
new MemoryRateLimiter({ limit: 1000, window: Time.hours(1) }),
]),
event: new MultiRateLimiter([
new MemoryRateLimiter({ limit: 10, window: Time.seconds(10) }),
new MemoryRateLimiter({ limit: 100, window: Time.hours(1) }),
new MemoryRateLimiter({ limit: 500, window: Time.days(1) }),
]),
ephemeral: new MemoryRateLimiter({ limit: 30, window: Time.seconds(10) }),
};
/** Connections for metrics purposes. */
const connections = new Set<WebSocket>();
@ -43,15 +56,7 @@ function connectStream(socket: WebSocket, ip: string | undefined) {
};
socket.onmessage = (e) => {
if (ip) {
const count = limiter.get(ip) ?? 0;
limiter.set(ip, count + 1, { ttl: LIMITER_WINDOW });
if (count > LIMITER_LIMIT) {
socket.close(1008, 'Rate limit exceeded');
return;
}
}
if (rateLimited(limiters.msg)) return;
if (typeof e.data !== 'string') {
socket.close(1003, 'Invalid message');
@ -77,6 +82,19 @@ function connectStream(socket: WebSocket, ip: string | undefined) {
}
};
function rateLimited(limiter: Pick<RateLimiter, 'client'>): boolean {
if (ip) {
const client = limiter.client(ip);
try {
client.hit();
} catch {
socket.close(1008, 'Rate limit exceeded');
return true;
}
}
return false;
}
/** Handle client message. */
function handleMsg(msg: NostrClientMsg) {
switch (msg[0]) {
@ -97,6 +115,8 @@ function connectStream(socket: WebSocket, ip: string | undefined) {
/** Handle REQ. Start a subscription. */
async function handleReq([_, subId, ...filters]: NostrClientREQ): Promise<void> {
if (rateLimited(limiters.req)) return;
const controller = new AbortController();
controllers.get(subId)?.abort();
controllers.set(subId, controller);
@ -136,6 +156,10 @@ function connectStream(socket: WebSocket, ip: string | undefined) {
/** Handle EVENT. Store the event. */
async function handleEvent([_, event]: NostrClientEVENT): Promise<void> {
relayEventsCounter.inc({ kind: event.kind.toString() });
const limiter = NKinds.ephemeral(event.kind) ? limiters.ephemeral : limiters.event;
if (rateLimited(limiter)) return;
try {
// This will store it (if eligible) and run other side-effects.
await pipeline.handleEvent(purifyEvent(event), { source: 'relay', signal: AbortSignal.timeout(1000) });
@ -161,6 +185,7 @@ function connectStream(socket: WebSocket, ip: string | undefined) {
/** Handle COUNT. Return the number of events matching the filters. */
async function handleCount([_, subId, ...filters]: NostrClientCOUNT): Promise<void> {
if (rateLimited(limiters.req)) return;
const store = await Storages.db();
const { count } = await store.count(filters, { timeout: Conf.db.timeouts.relay });
send(['COUNT', subId, { count, approximate: false }]);
@ -188,8 +213,11 @@ const relayController: AppController = (c, next) => {
const ip = c.req.header('x-real-ip');
if (ip) {
const count = limiter.get(ip) ?? 0;
if (count > LIMITER_LIMIT) {
const remaining = Object
.values(limiters)
.reduce((acc, limiter) => Math.min(acc, limiter.client(ip).remaining), Infinity);
if (remaining < 0) {
return c.json({ error: 'Rate limit exceeded' }, 429);
}
}

View file

@ -0,0 +1,31 @@
import { assertEquals, assertThrows } from '@std/assert';
import { MemoryRateLimiter } from './MemoryRateLimiter.ts';
import { RateLimitError } from './RateLimitError.ts';
Deno.test('MemoryRateLimiter', async (t) => {
const limit = 5;
const window = 100;
using limiter = new MemoryRateLimiter({ limit, window });
await t.step('can hit up to limit', () => {
for (let i = 0; i < limit; i++) {
const client = limiter.client('test');
assertEquals(client.hits, i);
client.hit();
}
});
await t.step('throws when hit if limit exceeded', () => {
assertThrows(() => limiter.client('test').hit(), RateLimitError);
});
await t.step('can hit after window resets', async () => {
await new Promise((resolve) => setTimeout(resolve, window + 1));
const client = limiter.client('test');
assertEquals(client.hits, 0);
client.hit();
});
});

View file

@ -0,0 +1,77 @@
import { RateLimitError } from './RateLimitError.ts';
import { RateLimiter, RateLimiterClient } from './types.ts';
interface MemoryRateLimiterOpts {
limit: number;
window: number;
}
export class MemoryRateLimiter implements RateLimiter {
private iid: number;
private previous = new Map<string, RateLimiterClient>();
private current = new Map<string, RateLimiterClient>();
constructor(private opts: MemoryRateLimiterOpts) {
this.iid = setInterval(() => {
this.previous = this.current;
this.current = new Map();
}, opts.window);
}
get limit(): number {
return this.opts.limit;
}
get window(): number {
return this.opts.window;
}
client(key: string): RateLimiterClient {
const curr = this.current.get(key);
const prev = this.previous.get(key);
if (curr) {
return curr;
}
if (prev && prev.resetAt > new Date()) {
this.current.set(key, prev);
this.previous.delete(key);
return prev;
}
const next = new MemoryRateLimiterClient(this);
this.current.set(key, next);
return next;
}
[Symbol.dispose](): void {
clearInterval(this.iid);
}
}
class MemoryRateLimiterClient implements RateLimiterClient {
private _hits: number = 0;
readonly resetAt: Date;
constructor(private limiter: MemoryRateLimiter) {
this.resetAt = new Date(Date.now() + limiter.window);
}
get hits(): number {
return this._hits;
}
get remaining(): number {
return this.limiter.limit - this.hits;
}
hit(n: number = 1): void {
this._hits += n;
if (this.remaining < 0) {
throw new RateLimitError(this.limiter, this);
}
}
}

View file

@ -0,0 +1,41 @@
import { assertEquals, assertThrows } from '@std/assert';
import { MemoryRateLimiter } from './MemoryRateLimiter.ts';
import { MultiRateLimiter } from './MultiRateLimiter.ts';
Deno.test('MultiRateLimiter', async (t) => {
using limiter1 = new MemoryRateLimiter({ limit: 5, window: 100 });
using limiter2 = new MemoryRateLimiter({ limit: 8, window: 200 });
const limiter = new MultiRateLimiter([limiter1, limiter2]);
await t.step('can hit up to first limit', () => {
for (let i = 0; i < limiter1.limit; i++) {
const client = limiter.client('test');
assertEquals(client.hits, i);
client.hit();
}
});
await t.step('throws when hit if first limit exceeded', () => {
assertThrows(() => limiter.client('test').hit(), Error);
});
await t.step('can hit up to second limit after the first window resets', async () => {
await new Promise((resolve) => setTimeout(resolve, limiter1.window + 1));
const limit = limiter2.limit - limiter1.limit - 1;
for (let i = 0; i < limit; i++) {
const client = limiter.client('test');
assertEquals(client.hits, i);
client.hit();
}
});
await t.step('throws when hit if second limit exceeded', () => {
assertEquals(limiter.client('test').limiter, limiter1);
assertThrows(() => limiter.client('test').hit(), Error);
assertEquals(limiter.client('test').limiter, limiter2);
});
});

View file

@ -0,0 +1,51 @@
import { RateLimiter, RateLimiterClient } from './types.ts';
export class MultiRateLimiter {
constructor(private limiters: RateLimiter[]) {}
client(key: string): MultiRateLimiterClient {
return new MultiRateLimiterClient(key, this.limiters);
}
}
class MultiRateLimiterClient implements RateLimiterClient {
constructor(private key: string, private limiters: RateLimiter[]) {
if (!limiters.length) {
throw new Error('No limiters provided');
}
}
/** Returns the _active_ limiter, which is either the first exceeded or the first. */
get limiter(): RateLimiter {
const exceeded = this.limiters.find((limiter) => limiter.client(this.key).remaining < 0);
return exceeded ?? this.limiters[0];
}
get hits(): number {
return this.limiter.client(this.key).hits;
}
get resetAt(): Date {
return this.limiter.client(this.key).resetAt;
}
get remaining(): number {
return this.limiter.client(this.key).remaining;
}
hit(n?: number): void {
let error: unknown;
for (const limiter of this.limiters) {
try {
limiter.client(this.key).hit(n);
} catch (e) {
error ??= e;
}
}
if (error instanceof Error) {
throw error;
}
}
}

View file

@ -0,0 +1,10 @@
import { RateLimiter, RateLimiterClient } from './types.ts';
export class RateLimitError extends Error {
constructor(
readonly limiter: RateLimiter,
readonly client: RateLimiterClient,
) {
super('Rate limit exceeded');
}
}

View file

@ -0,0 +1,12 @@
export interface RateLimiter extends Disposable {
readonly limit: number;
readonly window: number;
client(key: string): RateLimiterClient;
}
export interface RateLimiterClient {
readonly hits: number;
readonly resetAt: Date;
readonly remaining: number;
hit(n?: number): void;
}