diff --git a/src/controllers/nostr/relay.ts b/src/controllers/nostr/relay.ts index 93ffb199..6b1c2fbc 100644 --- a/src/controllers/nostr/relay.ts +++ b/src/controllers/nostr/relay.ts @@ -56,7 +56,7 @@ function connectStream(socket: WebSocket, ip: string | undefined) { }; socket.onmessage = (e) => { - assertRateLimit(limiters.msg); + if (rateLimited(limiters.msg)) return; if (typeof e.data !== 'string') { socket.close(1003, 'Invalid message'); @@ -82,16 +82,17 @@ function connectStream(socket: WebSocket, ip: string | undefined) { } }; - function assertRateLimit(limiter: Pick): void { + function rateLimited(limiter: Pick): boolean { if (ip) { const client = limiter.client(ip); try { client.hit(); - } catch (error) { + } catch { socket.close(1008, 'Rate limit exceeded'); - throw error; + return true; } } + return false; } /** Handle client message. */ @@ -114,7 +115,7 @@ function connectStream(socket: WebSocket, ip: string | undefined) { /** Handle REQ. Start a subscription. */ async function handleReq([_, subId, ...filters]: NostrClientREQ): Promise { - assertRateLimit(limiters.req); + if (rateLimited(limiters.req)) return; const controller = new AbortController(); controllers.get(subId)?.abort(); @@ -156,11 +157,8 @@ function connectStream(socket: WebSocket, ip: string | undefined) { async function handleEvent([_, event]: NostrClientEVENT): Promise { relayEventsCounter.inc({ kind: event.kind.toString() }); - if (NKinds.ephemeral(event.kind)) { - assertRateLimit(limiters.ephemeral); - } else { - assertRateLimit(limiters.event); - } + const limiter = NKinds.ephemeral(event.kind) ? limiters.ephemeral : limiters.event; + if (rateLimited(limiter)) return; try { // This will store it (if eligible) and run other side-effects. @@ -187,7 +185,7 @@ function connectStream(socket: WebSocket, ip: string | undefined) { /** Handle COUNT. Return the number of events matching the filters. */ async function handleCount([_, subId, ...filters]: NostrClientCOUNT): Promise { - assertRateLimit(limiters.req); + if (rateLimited(limiters.req)) return; const store = await Storages.db(); const { count } = await store.count(filters, { timeout: Conf.db.timeouts.relay }); send(['COUNT', subId, { count, approximate: false }]);