diff --git a/apps/bot/package.json b/apps/bot/package.json index 31e65fa..eea1764 100644 --- a/apps/bot/package.json +++ b/apps/bot/package.json @@ -18,6 +18,7 @@ "@whiskeysockets/baileys": "7.0.0-rc10", "drizzle-orm": "^0.36.0", "luxon": "^3.5.0", + "p-limit": "^7.3.0", "pg": "^8.13.0", "pg-boss": "^12.18.2", "pino": "^9.5.0", diff --git a/apps/bot/src/scheduler/fire-reminder.test.ts b/apps/bot/src/scheduler/fire-reminder.test.ts new file mode 100644 index 0000000..d59886e --- /dev/null +++ b/apps/bot/src/scheduler/fire-reminder.test.ts @@ -0,0 +1,128 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +// Mock the per-key mutex module BEFORE importing fire-reminder so the +// runtime sees our spy when it dereferences `accountMutex.run`. +vi.mock("./per-key-mutex.js", () => { + return { + PerKeyMutex: class {}, + accountMutex: { + run: vi.fn(async (_key: string, fn: () => Promise) => fn()), + }, + }; +}); + +// Stub everything fire-reminder pulls in so the import succeeds without +// actually starting a Baileys session, hitting the DB, or talking to +// pg-boss. 
+const getReminderMock = vi.fn(); +vi.mock("../reminders/crud.js", () => ({ + getReminderWithDetails: (...args: unknown[]) => getReminderMock(...args), +})); +vi.mock("../db.js", () => ({ + db: { + insert: () => ({ values: () => ({ returning: async () => [{ id: "run-1" }] }) }), + update: () => ({ set: () => ({ where: async () => undefined }) }), + query: { + whatsappGroups: { findMany: async () => [] }, + mediaFiles: { findMany: async () => [] }, + }, + }, +})); +vi.mock("../whatsapp/session-manager.js", () => ({ + sessionManager: { getSession: () => null }, +})); +vi.mock("../ipc/notify.js", () => ({ pgNotifyWeb: vi.fn(async () => undefined) })); +vi.mock("../audit.js", () => ({ writeAuditLog: vi.fn(async () => undefined) })); +vi.mock("./pgboss-client.js", () => ({ getBoss: () => ({}) })); +vi.mock("./reminder-jobs.js", () => ({ scheduleReminderFire: vi.fn() })); + +import { fireReminder } from "./fire-reminder.js"; +import { accountMutex } from "./per-key-mutex.js"; + +describe("fireReminder", () => { + beforeEach(() => { + vi.mocked(accountMutex.run).mockClear(); + getReminderMock.mockReset(); + }); + + it("acquires accountMutex keyed by accountId for active reminders", async () => { + getReminderMock.mockResolvedValue({ + id: "r-1", + accountId: "acct-A", + status: "active", + targets: [], + messages: [], + createdBy: "op-1", + scheduleKind: "one_off", + rrule: null, + timezone: "Asia/Kuala_Lumpur", + name: "Test", + }); + + await fireReminder({ reminderId: "r-1" }); + + expect(accountMutex.run).toHaveBeenCalledTimes(1); + expect(accountMutex.run).toHaveBeenCalledWith("acct-A", expect.any(Function)); + }); + + it("does NOT acquire the mutex when the reminder is inactive", async () => { + getReminderMock.mockResolvedValue({ + id: "r-1", + accountId: "acct-A", + status: "ended", + targets: [], + messages: [], + createdBy: "op-1", + scheduleKind: "one_off", + rrule: null, + timezone: "Asia/Kuala_Lumpur", + name: "Test", + }); + + await fireReminder({ reminderId: 
"r-1" }); + + expect(accountMutex.run).not.toHaveBeenCalled(); + }); + + it("does NOT acquire the mutex when the reminder row is missing", async () => { + getReminderMock.mockResolvedValue(undefined); + + await fireReminder({ reminderId: "r-missing" }); + + expect(accountMutex.run).not.toHaveBeenCalled(); + }); + + it("uses different mutex keys for different accounts (cross-account isolation)", async () => { + getReminderMock.mockResolvedValueOnce({ + id: "r-A", + accountId: "acct-A", + status: "active", + targets: [], + messages: [], + createdBy: "op-1", + scheduleKind: "one_off", + rrule: null, + timezone: "Asia/Kuala_Lumpur", + name: "A", + }); + getReminderMock.mockResolvedValueOnce({ + id: "r-B", + accountId: "acct-B", + status: "active", + targets: [], + messages: [], + createdBy: "op-1", + scheduleKind: "one_off", + rrule: null, + timezone: "Asia/Kuala_Lumpur", + name: "B", + }); + + await fireReminder({ reminderId: "r-A" }); + await fireReminder({ reminderId: "r-B" }); + + const calls = vi.mocked(accountMutex.run).mock.calls; + expect(calls[0]?.[0]).toBe("acct-A"); + expect(calls[1]?.[0]).toBe("acct-B"); + }); +}); diff --git a/apps/bot/src/scheduler/fire-reminder.ts b/apps/bot/src/scheduler/fire-reminder.ts index 764d71e..585ce52 100644 --- a/apps/bot/src/scheduler/fire-reminder.ts +++ b/apps/bot/src/scheduler/fire-reminder.ts @@ -1,34 +1,60 @@ -import { eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { reminderRuns, reminderRunTargets, reminders } from "@cmbot/db"; +import { + generateWAMessageContent, + generateMessageID, + type AnyMessageContent, + type proto, + type WASocket, +} from "@whiskeysockets/baileys"; +import pLimit from "p-limit"; +import { readFile } from "node:fs/promises"; import { db } from "../db.js"; import { logger } from "../logger.js"; import { sessionManager } from "../whatsapp/session-manager.js"; -import { sendTextToGroup, sendMediaToGroup } from "../whatsapp/sender.js"; import { absoluteMediaPath, 
nextOccurrence, resolveDeliveryKind } from "@cmbot/shared"; -import { open as fsOpen } from "node:fs/promises"; import { env } from "../env.js"; import { writeAuditLog } from "../audit.js"; import { getReminderWithDetails } from "../reminders/crud.js"; import { getBoss } from "./pgboss-client.js"; import { scheduleReminderFire } from "./reminder-jobs.js"; import { pgNotifyWeb } from "../ipc/notify.js"; +import { accountMutex } from "./per-key-mutex.js"; +import { accountRateLimiter } from "./rate-limiter.js"; +import { MediaUploadCache } from "./media-upload-cache.js"; export type FireReminderPayload = { reminderId: string }; -/** - * Read the first N bytes of a file without slurping the whole thing. - * Used to sniff ISOBMFF brand bytes (HEIF, AVIF, QuickTime) so we - * can route mis-labelled uploads to the document path instead of - * letting Baileys' thumbnail extraction crash. - */ -async function readHeadBytes(filePath: string, n: number): Promise { - const fh = await fsOpen(filePath, "r"); - try { - const buf = new Uint8Array(n); - await fh.read(buf, 0, n, 0); - return buf; - } finally { - await fh.close(); +/** Random delay between same-group message parts. Just enough for + * visible ordering in the chat at WA's natural pace. */ +function partJitterMs(): number { + return 200 + Math.floor(Math.random() * 300); // 200..499 +} + +/** Baileys's WASocket exposes assertSessions on its internal interface, + * but it isn't part of the public type. Call it once per group before + * the first send so relayMessage doesn't trip on missing sessions. 
*/ +type SocketWithAssertSessions = WASocket & { + assertSessions?: (jids: string[], force: boolean) => Promise; +}; + +async function ensureGroupSessions(socket: WASocket, groupJid: string): Promise { + const internal = socket as SocketWithAssertSessions; + if (typeof internal.assertSessions !== "function") return; + const meta = await socket.groupMetadata(groupJid); + const participantJids = meta.participants.map((p) => p.id); + // Chunk so a single bad participant doesn't fail the whole group. + const CHUNK = 5; + for (let i = 0; i < participantJids.length; i += CHUNK) { + const chunk = participantJids.slice(i, i + CHUNK); + try { + await internal.assertSessions(chunk, true); + } catch (err) { + logger.warn( + { groupJid, err: (err as Error).message }, + "fire-reminder: assertSessions chunk failed", + ); + } } } @@ -43,12 +69,19 @@ export async function fireReminder(payload: FireReminderPayload): Promise return; } + // Per-account mutex: two reminders on the SAME account take turns + // (running them concurrently would double the effective send rate + // and risk a ban). Different accounts run in parallel. + await accountMutex.run(reminder.accountId, () => fireReminderInner(reminder)); +} + +async function fireReminderInner( + reminder: NonNullable>>, +): Promise { const [run] = await db .insert(reminderRuns) .values({ reminderId: reminder.id, - // Snapshot the name so the run row stays readable in history even - // after the reminder is deleted (FK is ON DELETE SET NULL). 
reminderName: reminder.name, status: "pending", }) @@ -58,120 +91,177 @@ export async function fireReminder(payload: FireReminderPayload): Promise const session = sessionManager.getSession(reminder.accountId); if (!session) { logger.warn({ reminderId: reminder.id }, "fire-reminder: account not connected"); - for (const target of reminder.targets) { - const g = await db.query.whatsappGroups.findFirst({ - where: (g, { eq }) => eq(g.id, target.groupId), - columns: { name: true }, - }); - await db.insert(reminderRunTargets).values({ - runId, - groupId: target.groupId, - groupLabel: g?.name ?? null, - status: "skipped", - error: "account not connected", - }); - } + await markAllSkipped(runId, reminder, "account not connected"); await db .update(reminderRuns) .set({ status: "skipped", errorSummary: "account not connected" }) .where(eq(reminderRuns.id, runId)); + await pgNotifyWeb({ type: "reminder.fired", reminderId: reminder.id, runId, status: "skipped" }); return; } - let allSent = true; - let anySent = false; - for (const target of reminder.targets) { - const group = await db.query.whatsappGroups.findFirst({ - where: (g, { eq }) => eq(g.id, target.groupId), - }); - if (!group) { - await db.insert(reminderRunTargets).values({ + // Up-front bulk loads. Drops ~3000 round-trips to ~3 for a 1000-group run. + const groupIds = reminder.targets.map((t) => t.groupId); + const groupRows = groupIds.length + ? await db.query.whatsappGroups.findMany({ where: (g) => inArray(g.id, groupIds) }) + : []; + const groupById = new Map(groupRows.map((g) => [g.id, g])); + + const mediaIds = Array.from( + new Set(reminder.messages.map((m) => m.mediaId).filter((id): id is string => Boolean(id))), + ); + const mediaRows = mediaIds.length + ? await db.query.mediaFiles.findMany({ where: (m) => inArray(m.id, mediaIds) }) + : []; + const mediaById = new Map(mediaRows.map((m) => [m.id, m])); + + // Pre-create run_target rows so the Activity tab shows progress mid-run. 
+ if (reminder.targets.length > 0) { + await db.insert(reminderRunTargets).values( + reminder.targets.map((t) => ({ runId, - groupId: target.groupId, - groupLabel: null, - status: "skipped", - error: "group missing from db", - }); - allSent = false; - continue; - } - const start = Date.now(); - try { - let lastMessageId: string | undefined; - for (const part of reminder.messages) { - if (part.kind === "text" && part.textContent) { - const r = await sendTextToGroup(session.socket, group.waGroupJid, part.textContent); - lastMessageId = r.messageId; - } else if (part.mediaId) { - const media = await db.query.mediaFiles.findFirst({ - where: (m, { eq }) => eq(m.id, part.mediaId!), - }); - if (!media) throw new Error(`media row missing: ${part.mediaId}`); - const filePath = absoluteMediaPath(media.storagePath, env.MEDIA_DIR); - // Resolve the actual delivery kind from mime + magic bytes. - // Sniffing the first 12 bytes catches HEIC/MOV uploads - // labelled with a misleading mime (e.g. iOS Safari) and - // routes them to the document path so the bot doesn't try - // to extract a thumbnail it can't decode. - const head = await readHeadBytes(filePath, 12); - const resolved = resolveDeliveryKind(media.mimeType, head); - // sendMediaToGroup accepts image / video / document. Audio - // collapses into the document path for now; the per-kind - // size cap was already applied at upload time. - const senderKind: "image" | "video" | "document" = - resolved === "image" || resolved === "video" ? resolved : "document"; - const r = await sendMediaToGroup(session.socket, group.waGroupJid, senderKind, filePath, { - caption: part.textContent ?? 
undefined, - mimeType: media.mimeType, - filename: media.filenameOriginal, - }); - lastMessageId = r.messageId; - } - // 1.5s jitter between message parts to stay under WA's rate limit - await new Promise((r) => setTimeout(r, 1500)); - } - await db.insert(reminderRunTargets).values({ - runId, - groupId: target.groupId, - groupLabel: group.name, - status: "sent", - waMessageId: lastMessageId ?? null, - latencyMs: Date.now() - start, - }); - anySent = true; - } catch (err) { - logger.error({ err, reminderId: reminder.id, groupId: target.groupId }, "fire-reminder: send failed"); - await db.insert(reminderRunTargets).values({ - runId, - groupId: target.groupId, - groupLabel: group.name, - status: "failed", - error: (err as Error).message, - }); - allSent = false; - } + groupId: t.groupId, + groupLabel: groupById.get(t.groupId)?.name ?? null, + status: "pending" as const, + })), + ); } - const status = allSent ? "success" : anySent ? "partial" : "failed"; - await db - .update(reminderRuns) - .set({ status }) - .where(eq(reminderRuns.id, runId)); - - // Notify the web so any open browsers can fire a notification. - // The web UI subscribes to `reminder.fired` via SSE and surfaces - // it as a desktop / mobile notification when the operator has - // opted in (Notification.permission === "granted"). - await pgNotifyWeb({ - type: "reminder.fired", - reminderId: reminder.id, - runId, - status, + // Per-run media upload cache. Each unique mediaId is prepared via + // generateWAMessageContent ONCE (which uploads to WA's CDN through + // the socket's waUploadToServer); the resulting proto.Message is + // reused for every group via socket.relayMessage. For 1000 groups + // × 5 MB image, this turns 5 GB of upload into 5 MB. 
+ const uploadCache = new MediaUploadCache(async (mediaId) => { + const media = mediaById.get(mediaId); + if (!media) throw new Error(`media row missing: ${mediaId}`); + const filePath = absoluteMediaPath(media.storagePath, env.MEDIA_DIR); + const buffer = await readFile(filePath); + const head = buffer.subarray(0, 12); + const resolved = resolveDeliveryKind(media.mimeType, head); + const senderKind: "image" | "video" | "document" = + resolved === "image" || resolved === "video" ? resolved : "document"; + const content: AnyMessageContent = + senderKind === "image" + ? { image: buffer, mimetype: media.mimeType } + : senderKind === "video" + ? { video: buffer, mimetype: media.mimeType } + : { + document: buffer, + fileName: media.filenameOriginal, + mimetype: media.mimeType, + }; + return generateWAMessageContent(content, { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + upload: (session.socket as any).waUploadToServer, + }); }); - // One-off reminders end after firing. Recurring reminders compute the - // next occurrence from the RRULE and re-arm the pg-boss job; only the - // last fire timestamp + updatedAt move forward. + // Per-account rate limiter — gates each socket send to stay within + // the account's safe band (BOT_MAX_SEND_PER_MINUTE, default 40). + const rateLimiter = accountRateLimiter.get(reminder.accountId); + + let sentCount = 0; + let failedCount = 0; + let skippedCount = 0; + + const groupConcurrency = pLimit(env.BOT_GROUP_CONCURRENCY); + + await Promise.all( + reminder.targets.map((target) => + groupConcurrency(async () => { + const group = groupById.get(target.groupId); + if (!group) { + await db + .update(reminderRunTargets) + .set({ status: "skipped", error: "group missing from db" }) + .where( + and( + eq(reminderRunTargets.runId, runId), + eq(reminderRunTargets.groupId, target.groupId), + ), + ); + skippedCount++; + return; + } + + const start = Date.now(); + try { + // Once per group, before the first send. 
sendMessage handles + // sessions internally; relayMessage does not. + await ensureGroupSessions(session.socket, group.waGroupJid); + + let lastMessageId: string | undefined; + for (const part of reminder.messages) { + await rateLimiter.acquire(); + if (part.kind === "text" && part.textContent) { + const r = await session.socket.sendMessage(group.waGroupJid, { + text: part.textContent, + }); + lastMessageId = r?.key?.id ?? undefined; + } else if (part.mediaId) { + const prebuilt = await uploadCache.get(part.mediaId); + if (part.textContent) injectCaption(prebuilt, part.textContent); + const messageId = generateMessageID(); + await session.socket.relayMessage(group.waGroupJid, prebuilt, { messageId }); + lastMessageId = messageId; + } + await new Promise((r) => setTimeout(r, partJitterMs())); + } + await db + .update(reminderRunTargets) + .set({ + status: "sent", + waMessageId: lastMessageId ?? null, + latencyMs: Date.now() - start, + }) + .where( + and( + eq(reminderRunTargets.runId, runId), + eq(reminderRunTargets.groupId, target.groupId), + ), + ); + sentCount++; + } catch (err) { + logger.error( + { err, reminderId: reminder.id, groupId: target.groupId }, + "fire-reminder: send failed", + ); + await db + .update(reminderRunTargets) + .set({ status: "failed", error: (err as Error).message }) + .where( + and( + eq(reminderRunTargets.runId, runId), + eq(reminderRunTargets.groupId, target.groupId), + ), + ); + failedCount++; + } + }), + ), + ); + + const total = reminder.targets.length; + let status: "success" | "partial" | "failed"; + let errorSummary: string | null = null; + if (total > 0 && sentCount === total) { // a run with zero targets must fall through to "failed", not pass as a vacuous success + status = "success"; + } else if (sentCount > 0) { + status = "partial"; + errorSummary = `${sentCount} of ${total} groups delivered (${failedCount} failed, ${skippedCount} skipped).`; + } else { + status = "failed"; + errorSummary = total === 0 ? "No targets attached to reminder." 
: `All ${total} sends failed.`; + } + + await db + .update(reminderRuns) + .set({ status, errorSummary }) + .where(eq(reminderRuns.id, runId)); + + await pgNotifyWeb({ type: "reminder.fired", reminderId: reminder.id, runId, status }); + if (reminder.scheduleKind === "one_off") { await db .update(reminders) @@ -202,8 +292,44 @@ export async function fireReminder(payload: FireReminderPayload): Promise action: "reminder.fired", targetType: "reminder", targetId: reminder.id, - payload: { runId, status }, + payload: { runId, status, sent: sentCount, failed: failedCount, skipped: skippedCount }, }); - logger.info({ reminderId: reminder.id, runId, status }, "fire-reminder: done"); + logger.info( + { reminderId: reminder.id, runId, status, sent: sentCount, failed: failedCount, skipped: skippedCount }, + "fire-reminder: done", + ); +} + +async function markAllSkipped( + runId: string, + reminder: NonNullable>>, + error: string, +): Promise { + if (reminder.targets.length === 0) return; + const rows = await db.query.whatsappGroups.findMany({ + where: (g) => inArray(g.id, reminder.targets.map((t) => t.groupId)), + columns: { id: true, name: true }, + }); + const labelById = new Map(rows.map((r) => [r.id, r.name])); + await db.insert(reminderRunTargets).values( + reminder.targets.map((t) => ({ + runId, + groupId: t.groupId, + groupLabel: labelById.get(t.groupId) ?? null, + status: "skipped" as const, + error, + })), + ); +} + +/** + * Inject the caption into the prebuilt media message. Baileys' relayMessage + * doesn't take a caption alongside the content; the protobuf already has + * the slot, so we mutate it just before relaying. 
+ */ +function injectCaption(msg: proto.IMessage, caption: string): void { + if (msg.imageMessage) msg.imageMessage.caption = caption; + else if (msg.videoMessage) msg.videoMessage.caption = caption; + else if (msg.documentMessage) msg.documentMessage.caption = caption; } diff --git a/apps/bot/src/scheduler/reminder-jobs.ts b/apps/bot/src/scheduler/reminder-jobs.ts index 61dbf98..ce948dd 100644 --- a/apps/bot/src/scheduler/reminder-jobs.ts +++ b/apps/bot/src/scheduler/reminder-jobs.ts @@ -1,18 +1,32 @@ import type { PgBoss } from "pg-boss"; import { logger } from "../logger.js"; +import { env } from "../env.js"; import { fireReminder, type FireReminderPayload } from "./fire-reminder.js"; export const REMINDER_FIRE_QUEUE = "reminder.fire"; export async function registerReminderJobs(boss: PgBoss): Promise { await boss.createQueue(REMINDER_FIRE_QUEUE); - await boss.work(REMINDER_FIRE_QUEUE, async (jobs) => { - const job = jobs[0]; - if (!job) return; - logger.debug({ jobId: job.id, payload: job.data }, "reminder.fire: handling"); - await fireReminder(job.data); - }); - logger.info("reminder.fire: handler registered"); + await boss.work( + REMINDER_FIRE_QUEUE, + { + // Up to BOT_FIRE_CONCURRENCY workers per node, each polling and + // processing independently. Combined with the per-account mutex + // inside fireReminder, this lets reminders on DIFFERENT accounts + // run in parallel while same-account reminders take turns. 
+ localConcurrency: env.BOT_FIRE_CONCURRENCY, + }, + async (jobs) => { + const job = jobs[0]; + if (!job) return; + logger.debug({ jobId: job.id, payload: job.data }, "reminder.fire: handling"); + await fireReminder(job.data); + }, + ); + logger.info( + { localConcurrency: env.BOT_FIRE_CONCURRENCY }, + "reminder.fire: handler registered", + ); } export async function scheduleReminderFire( diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4321d28..8a6002d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -32,6 +32,9 @@ importers: luxon: specifier: ^3.5.0 version: 3.7.2 + p-limit: + specifier: ^7.3.0 + version: 7.3.0 pg: specifier: ^8.13.0 version: 8.20.0 @@ -3713,6 +3716,10 @@ packages: resolution: {integrity: sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==} engines: {node: '>=6'} + p-limit@7.3.0: + resolution: {integrity: sha512-7cIXg/Z0M5WZRblrsOla88S4wAK+zOQQWeBYfV3qJuJXMr+LnbYjaadrFaS0JILfEDPVqHyKnZ1Z/1d6J9VVUw==} + engines: {node: '>=20'} + p-locate@4.1.0: resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==} engines: {node: '>=8'} @@ -4606,6 +4613,10 @@ packages: resolution: {integrity: sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==} engines: {node: '>=12'} + yocto-queue@1.2.2: + resolution: {integrity: sha512-4LCcse/U2MHZ63HAJVE+v71o7yOdIe4cZ70Wpf8D/IyjDKYQLV5GD46B+hSTjJsvV5PztjvHoU580EftxjDZFQ==} + engines: {node: '>=12.20'} + yocto-spinner@1.2.0: resolution: {integrity: sha512-Yw0hUB6UA3o4YUgKy3oSe9a4cxoaZ9sBfYDw+JSxo6Id0KoJGoxzPA24qqUXYKBWABs/zDSGTz9kww7t3F0XGw==} engines: {node: '>=18.19'} @@ -7734,6 +7745,10 @@ snapshots: dependencies: p-try: 2.2.0 + p-limit@7.3.0: + dependencies: + yocto-queue: 1.2.2 + p-locate@4.1.0: dependencies: p-limit: 2.3.0 @@ -8748,6 +8763,8 @@ snapshots: y18n: 5.0.8 yargs-parser: 21.1.1 + yocto-queue@1.2.2: {} + yocto-spinner@1.2.0: dependencies: yoctocolors: 2.1.2