diff --git a/.changeset/proud-nails-grin.md b/.changeset/proud-nails-grin.md new file mode 100644 index 0000000000..9c4f822f6c --- /dev/null +++ b/.changeset/proud-nails-grin.md @@ -0,0 +1,5 @@ +--- +"@trigger.dev/core": patch +--- + +Add optional placement tags to dequeued messages for targeted scheduling diff --git a/apps/supervisor/src/env.ts b/apps/supervisor/src/env.ts index fd6bd61050..dfe5237912 100644 --- a/apps/supervisor/src/env.ts +++ b/apps/supervisor/src/env.ts @@ -15,7 +15,7 @@ const Env = z.object({ OTEL_EXPORTER_OTLP_ENDPOINT: z.string().url(), // set on the runners // Workload API settings (coordinator mode) - the workload API is what the run controller connects to - TRIGGER_WORKLOAD_API_ENABLED: BoolEnv.default("true"), + TRIGGER_WORKLOAD_API_ENABLED: BoolEnv.default(true), TRIGGER_WORKLOAD_API_PROTOCOL: z .string() .transform((s) => z.enum(["http", "https"]).parse(s.toLowerCase())) @@ -32,7 +32,7 @@ const Env = z.object({ RUNNER_PRETTY_LOGS: BoolEnv.default(false), // Dequeue settings (provider mode) - TRIGGER_DEQUEUE_ENABLED: BoolEnv.default("true"), + TRIGGER_DEQUEUE_ENABLED: BoolEnv.default(true), TRIGGER_DEQUEUE_INTERVAL_MS: z.coerce.number().int().default(250), TRIGGER_DEQUEUE_IDLE_INTERVAL_MS: z.coerce.number().int().default(1000), TRIGGER_DEQUEUE_MAX_RUN_COUNT: z.coerce.number().int().default(10), @@ -77,6 +77,10 @@ const Env = z.object({ KUBERNETES_EPHEMERAL_STORAGE_SIZE_LIMIT: z.string().default("10Gi"), KUBERNETES_EPHEMERAL_STORAGE_SIZE_REQUEST: z.string().default("2Gi"), + // Placement tags settings + PLACEMENT_TAGS_ENABLED: BoolEnv.default(false), + PLACEMENT_TAGS_PREFIX: z.string().default("node.cluster.x-k8s.io"), + // Metrics METRICS_ENABLED: BoolEnv.default(true), METRICS_COLLECT_DEFAULTS: BoolEnv.default(true), diff --git a/apps/supervisor/src/envUtil.ts b/apps/supervisor/src/envUtil.ts index 95d44d6c45..917f984cc3 100644 --- a/apps/supervisor/src/envUtil.ts +++ b/apps/supervisor/src/envUtil.ts @@ -3,7 +3,7 @@ import { 
SimpleStructuredLogger } from "@trigger.dev/core/v3/utils/structuredLog const logger = new SimpleStructuredLogger("env-util"); -export const BoolEnv = z.preprocess((val) => { +const baseBoolEnv = z.preprocess((val) => { if (typeof val !== "string") { return val; } @@ -11,6 +11,11 @@ export const BoolEnv = z.preprocess((val) => { return ["true", "1"].includes(val.toLowerCase().trim()); }, z.boolean()); +// Create a type-safe version that only accepts boolean defaults +export const BoolEnv = baseBoolEnv as Omit & { + default: (value: boolean) => z.ZodDefault; +}; + export const AdditionalEnvVars = z.preprocess((val) => { if (typeof val !== "string") { return val; diff --git a/apps/supervisor/src/index.ts b/apps/supervisor/src/index.ts index 83fe89c1ed..1ed00edad6 100644 --- a/apps/supervisor/src/index.ts +++ b/apps/supervisor/src/index.ts @@ -247,6 +247,7 @@ class ManagedSupervisor { nextAttemptNumber: message.run.attemptNumber, snapshotId: message.snapshot.id, snapshotFriendlyId: message.snapshot.friendlyId, + placementTags: message.placementTags, }); // Disabled for now diff --git a/apps/supervisor/src/workloadManager/kubernetes.ts b/apps/supervisor/src/workloadManager/kubernetes.ts index 81618e8eb5..2b5547f3a8 100644 --- a/apps/supervisor/src/workloadManager/kubernetes.ts +++ b/apps/supervisor/src/workloadManager/kubernetes.ts @@ -4,7 +4,7 @@ import { type WorkloadManagerCreateOptions, type WorkloadManagerOptions, } from "./types.js"; -import type { EnvironmentType, MachinePreset } from "@trigger.dev/core/v3"; +import type { EnvironmentType, MachinePreset, PlacementTag } from "@trigger.dev/core/v3"; import { env } from "../env.js"; import { type K8sApi, createK8sApi, type k8s } from "../clients/kubernetes.js"; import { getRunnerId } from "../util.js"; @@ -13,6 +13,11 @@ type ResourceQuantities = { [K in "cpu" | "memory" | "ephemeral-storage"]?: string; }; +interface PlacementConfig { + enabled: boolean; + prefix: string; +} + export class KubernetesWorkloadManager 
implements WorkloadManager { private readonly logger = new SimpleStructuredLogger("kubernetes-workload-provider"); private k8s: K8sApi; @@ -28,6 +33,56 @@ export class KubernetesWorkloadManager implements WorkloadManager { } } + private get placementConfig(): PlacementConfig { + return { + enabled: env.PLACEMENT_TAGS_ENABLED, + prefix: env.PLACEMENT_TAGS_PREFIX, + }; + } + + private addPlacementTags( + podSpec: Omit, + placementTags?: PlacementTag[] + ): Omit { + if (!this.placementConfig.enabled || !placementTags || placementTags.length === 0) { + return podSpec; + } + + const nodeSelector: Record = { ...podSpec.nodeSelector }; + + // Convert placement tags to nodeSelector labels + for (const tag of placementTags) { + const labelKey = `${this.placementConfig.prefix}/${tag.key}`; + + // Print warnings (if any) + this.printTagWarnings(tag); + + // For now we only support single values via nodeSelector + nodeSelector[labelKey] = tag.values?.[0] ?? ""; + } + + return { + ...podSpec, + nodeSelector, + }; + } + + private printTagWarnings(tag: PlacementTag) { + if (!tag.values || tag.values.length === 0) { + // No values provided + this.logger.warn( + "[KubernetesWorkloadManager] Placement tag has no values, using empty string", + tag + ); + } else if (tag.values.length > 1) { + // Multiple values provided + this.logger.warn( + "[KubernetesWorkloadManager] Placement tag has multiple values, only using first one", + tag + ); + } + } + async create(opts: WorkloadManagerCreateOptions) { this.logger.log("[KubernetesWorkloadManager] Creating container", { opts }); @@ -48,7 +103,7 @@ export class KubernetesWorkloadManager implements WorkloadManager { }, }, spec: { - ...this.#defaultPodSpec, + ...this.addPlacementTags(this.#defaultPodSpec, opts.placementTags), terminationGracePeriodSeconds: 60 * 60, containers: [ { diff --git a/apps/supervisor/src/workloadManager/types.ts b/apps/supervisor/src/workloadManager/types.ts index b3cd418f1e..64573fb3b9 100644 --- 
a/apps/supervisor/src/workloadManager/types.ts +++ b/apps/supervisor/src/workloadManager/types.ts @@ -1,4 +1,4 @@ -import { type EnvironmentType, type MachinePreset } from "@trigger.dev/core/v3"; +import type { EnvironmentType, MachinePreset, PlacementTag } from "@trigger.dev/core/v3"; export interface WorkloadManagerOptions { workloadApiProtocol: "http" | "https"; @@ -23,6 +23,7 @@ export interface WorkloadManagerCreateOptions { version: string; nextAttemptNumber?: number; dequeuedAt: Date; + placementTags?: PlacementTag[]; // identifiers envId: string; envType: EnvironmentType; diff --git a/apps/webapp/app/env.server.ts b/apps/webapp/app/env.server.ts index 6a78a1cc95..1a49acddbc 100644 --- a/apps/webapp/app/env.server.ts +++ b/apps/webapp/app/env.server.ts @@ -761,6 +761,8 @@ const EnvironmentSchema = z.object({ .int() .default(60_000 * 5), // 5 minutes + BATCH_TRIGGER_CACHED_RUNS_CHECK_ENABLED: BoolEnv.default(false), + BATCH_TRIGGER_WORKER_ENABLED: z.string().default(process.env.WORKER_ENABLED ?? 
"true"), BATCH_TRIGGER_WORKER_CONCURRENCY_WORKERS: z.coerce.number().int().default(2), BATCH_TRIGGER_WORKER_CONCURRENCY_TASKS_PER_WORKER: z.coerce.number().int().default(10), diff --git a/apps/webapp/app/routes/resources.orgs.$organizationSlug.select-plan.tsx b/apps/webapp/app/routes/resources.orgs.$organizationSlug.select-plan.tsx index 90095f342c..8299d775f2 100644 --- a/apps/webapp/app/routes/resources.orgs.$organizationSlug.select-plan.tsx +++ b/apps/webapp/app/routes/resources.orgs.$organizationSlug.select-plan.tsx @@ -42,6 +42,7 @@ import { redirectWithErrorMessage } from "~/models/message.server"; import { logger } from "~/services/logger.server"; import { setPlan } from "~/services/platform.v3.server"; import { requireUser } from "~/services/session.server"; +import { engine } from "~/v3/runEngine.server"; import { cn } from "~/utils/cn"; import { sendToPlain } from "~/utils/plain.server"; @@ -152,7 +153,9 @@ export async function action({ request, params }: ActionFunctionArgs) { } } - return setPlan(organization, request, form.callerPath, payload); + return setPlan(organization, request, form.callerPath, payload, { + invalidateBillingCache: engine.invalidateBillingCache.bind(engine), + }); } const pricingDefinitions = { diff --git a/apps/webapp/app/runEngine/concerns/queues.server.ts b/apps/webapp/app/runEngine/concerns/queues.server.ts index 60cf20b14f..0e213a58d2 100644 --- a/apps/webapp/app/runEngine/concerns/queues.server.ts +++ b/apps/webapp/app/runEngine/concerns/queues.server.ts @@ -177,8 +177,11 @@ export class DefaultQueueManager implements QueueManager { return task.queue.name ?? 
defaultQueueName; } - async validateQueueLimits(environment: AuthenticatedEnvironment): Promise { - const queueSizeGuard = await guardQueueSizeLimitsForEnv(this.engine, environment); + async validateQueueLimits( + environment: AuthenticatedEnvironment, + itemsToAdd?: number + ): Promise { + const queueSizeGuard = await guardQueueSizeLimitsForEnv(this.engine, environment, itemsToAdd); logger.debug("Queue size guard result", { queueSizeGuard, diff --git a/apps/webapp/app/runEngine/services/batchTrigger.server.ts b/apps/webapp/app/runEngine/services/batchTrigger.server.ts index beadcc9cf7..21893948d4 100644 --- a/apps/webapp/app/runEngine/services/batchTrigger.server.ts +++ b/apps/webapp/app/runEngine/services/batchTrigger.server.ts @@ -1,24 +1,25 @@ import { - BatchTriggerTaskV2RequestBody, - BatchTriggerTaskV3RequestBody, - BatchTriggerTaskV3Response, - IOPacket, + type BatchTriggerTaskV2RequestBody, + type BatchTriggerTaskV3RequestBody, + type BatchTriggerTaskV3Response, + type IOPacket, packetRequiresOffloading, parsePacket, } from "@trigger.dev/core/v3"; import { BatchId, RunId } from "@trigger.dev/core/v3/isomorphic"; -import { BatchTaskRun, Prisma } from "@trigger.dev/database"; +import { type BatchTaskRun, Prisma } from "@trigger.dev/database"; import { Evt } from "evt"; import { z } from "zod"; -import { prisma, PrismaClientOrTransaction } from "~/db.server"; +import { prisma, type PrismaClientOrTransaction } from "~/db.server"; import { env } from "~/env.server"; -import { AuthenticatedEnvironment } from "~/services/apiAuth.server"; +import type { AuthenticatedEnvironment } from "~/services/apiAuth.server"; import { logger } from "~/services/logger.server"; -import { getEntitlement } from "~/services/platform.v3.server"; import { batchTriggerWorker } from "~/v3/batchTriggerWorker.server"; +import { DefaultQueueManager } from "../concerns/queues.server"; +import { DefaultTriggerTaskValidator } from "../validators/triggerTaskValidator"; import { 
downloadPacketFromObjectStore, uploadPacketToObjectStore } from "../../v3/r2.server"; import { ServiceValidationError, WithRunEngine } from "../../v3/services/baseService.server"; -import { OutOfEntitlementError, TriggerTaskService } from "../../v3/services/triggerTask.server"; +import { TriggerTaskService } from "../../v3/services/triggerTask.server"; import { startActiveSpan } from "../../v3/tracer.server"; const PROCESSING_BATCH_SIZE = 50; @@ -36,6 +37,7 @@ export const BatchProcessingOptions = z.object({ strategy: BatchProcessingStrategy, parentRunId: z.string().optional(), resumeParentOnCompletion: z.boolean().optional(), + planType: z.string().optional(), }); export type BatchProcessingOptions = z.infer; @@ -53,6 +55,8 @@ export type BatchTriggerTaskServiceOptions = { export class RunEngineBatchTriggerService extends WithRunEngine { private _batchProcessingStrategy: BatchProcessingStrategy; public onBatchTaskRunCreated: Evt = new Evt(); + private readonly queueConcern: DefaultQueueManager; + private readonly validator: DefaultTriggerTaskValidator; constructor( batchProcessingStrategy?: BatchProcessingStrategy, @@ -60,6 +64,9 @@ export class RunEngineBatchTriggerService extends WithRunEngine { ) { super({ prisma }); + this.queueConcern = new DefaultQueueManager(this._prisma, this._engine); + this.validator = new DefaultTriggerTaskValidator(); + // Eric note: We need to force sequential processing because when doing parallel, we end up with high-contention on the parent run lock // because we are triggering a lot of runs at once, and each one is trying to lock the parent run. // by forcing sequential, we are only ever locking the parent run for a single run at a time. 
@@ -80,13 +87,18 @@ export class RunEngineBatchTriggerService extends WithRunEngine { span.setAttribute("batchId", friendlyId); - if (environment.type !== "DEVELOPMENT") { - const result = await getEntitlement(environment.organizationId); - if (result && result.hasAccess === false) { - throw new OutOfEntitlementError(); - } + // Validate entitlement and extract planType for batch runs + const entitlementValidation = await this.validator.validateEntitlement({ + environment, + }); + + if (!entitlementValidation.ok) { + throw entitlementValidation.error; } + // Extract plan type from entitlement response + const planType = entitlementValidation.plan?.type; + // Upload to object store const payloadPacket = await this.#handlePayloadPacket( body.items, @@ -99,7 +111,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine { payloadPacket, environment, body, - options + options, + planType ); if (!batch) { @@ -152,7 +165,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine { payloadPacket: IOPacket, environment: AuthenticatedEnvironment, body: BatchTriggerTaskV2RequestBody, - options: BatchTriggerTaskServiceOptions = {} + options: BatchTriggerTaskServiceOptions = {}, + planType?: string ) { if (body.items.length <= ASYNC_BATCH_PROCESS_SIZE_THRESHOLD) { const batch = await this._prisma.batchTaskRun.create({ @@ -191,6 +205,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options, parentRunId: body.parentRunId, resumeParentOnCompletion: body.resumeParentOnCompletion, + planType, }); switch (result.status) { @@ -220,6 +235,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: "sequential", parentRunId: body.parentRunId, resumeParentOnCompletion: body.resumeParentOnCompletion, + planType, }); return batch; @@ -242,6 +258,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: "sequential", parentRunId: body.parentRunId, resumeParentOnCompletion: body.resumeParentOnCompletion, 
+ planType, }); return batch; @@ -285,6 +302,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: this._batchProcessingStrategy, parentRunId: body.parentRunId, resumeParentOnCompletion: body.resumeParentOnCompletion, + planType, }); break; @@ -307,6 +325,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: this._batchProcessingStrategy, parentRunId: body.parentRunId, resumeParentOnCompletion: body.resumeParentOnCompletion, + planType, }) ) ); @@ -410,6 +429,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options: $options, parentRunId: options.parentRunId, resumeParentOnCompletion: options.resumeParentOnCompletion, + planType: options.planType, }); switch (result.status) { @@ -443,6 +463,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: options.strategy, parentRunId: options.parentRunId, resumeParentOnCompletion: options.resumeParentOnCompletion, + planType: options.planType, }); } @@ -470,6 +491,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: options.strategy, parentRunId: options.parentRunId, resumeParentOnCompletion: options.resumeParentOnCompletion, + planType: options.planType, }); } else { await this.#enqueueBatchTaskRun({ @@ -486,6 +508,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { strategy: options.strategy, parentRunId: options.parentRunId, resumeParentOnCompletion: options.resumeParentOnCompletion, + planType: options.planType, }); } @@ -503,6 +526,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options, parentRunId, resumeParentOnCompletion, + planType, }: { batch: BatchTaskRun; environment: AuthenticatedEnvironment; @@ -512,6 +536,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options?: BatchTriggerTaskServiceOptions; parentRunId?: string | undefined; resumeParentOnCompletion?: boolean | undefined; + planType?: string; }): Promise< | { 
status: "COMPLETE" } | { status: "INCOMPLETE"; workingIndex: number } @@ -520,6 +545,35 @@ export class RunEngineBatchTriggerService extends WithRunEngine { // Grab the next PROCESSING_BATCH_SIZE items const itemsToProcess = items.slice(currentIndex, currentIndex + batchSize); + const newRunCount = await this.#countNewRuns(environment, itemsToProcess); + + // Only validate queue size if we have new runs to create, i.e. they're not all cached + if (newRunCount > 0) { + const queueSizeGuard = await this.queueConcern.validateQueueLimits(environment, newRunCount); + + logger.debug("Queue size guard result for chunk", { + batchId: batch.friendlyId, + currentIndex, + runCount: batch.runCount, + newRunCount, + queueSizeGuard, + }); + + if (!queueSizeGuard.ok) { + return { + status: "ERROR", + error: `Cannot trigger ${newRunCount} new tasks as the queue size limit for this environment has been reached. The maximum size is ${queueSizeGuard.maximumSize}`, + workingIndex: currentIndex, + }; + } + } else { + logger.debug("[RunEngineBatchTrigger][processBatchTaskRun] All runs are cached", { + batchId: batch.friendlyId, + currentIndex, + runCount: batch.runCount, + }); + } + logger.debug("[RunEngineBatchTrigger][processBatchTaskRun] Processing batch items", { batchId: batch.friendlyId, currentIndex, @@ -540,6 +594,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options, parentRunId, resumeParentOnCompletion, + planType, }); if (!run) { @@ -615,6 +670,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options, parentRunId, resumeParentOnCompletion, + planType, }: { batch: BatchTaskRun; environment: AuthenticatedEnvironment; @@ -623,6 +679,7 @@ export class RunEngineBatchTriggerService extends WithRunEngine { options?: BatchTriggerTaskServiceOptions; parentRunId: string | undefined; resumeParentOnCompletion: boolean | undefined; + planType?: string; }) { logger.debug("[RunEngineBatchTrigger][processBatchTaskRunItem] Processing item", { 
batchId: batch.friendlyId, @@ -649,6 +706,8 @@ export class RunEngineBatchTriggerService extends WithRunEngine { spanParentAsLink: options?.spanParentAsLink, batchId: batch.id, batchIndex: currentIndex, + skipChecks: true, // Skip entitlement and queue checks since we already validated at batch/chunk level + planType, // Pass planType from batch-level entitlement check }, "V2" ); @@ -691,4 +750,85 @@ export class RunEngineBatchTriggerService extends WithRunEngine { }; }); } + + #groupItemsByTaskIdentifier( + items: BatchTriggerTaskV2RequestBody["items"] + ): Record { + return items.reduce((acc, item) => { + if (!item.options?.idempotencyKey) return acc; + + if (!acc[item.task]) { + acc[item.task] = []; + } + acc[item.task].push(item); + return acc; + }, {} as Record); + } + + async #countNewRuns( + environment: AuthenticatedEnvironment, + items: BatchTriggerTaskV2RequestBody["items"] + ): Promise { + // If cached runs check is disabled, return the total number of items + if (!env.BATCH_TRIGGER_CACHED_RUNS_CHECK_ENABLED) { + return items.length; + } + + // Group items by taskIdentifier for efficient lookup + const itemsByTask = this.#groupItemsByTaskIdentifier(items); + + // If no items have idempotency keys, all are new runs + if (Object.keys(itemsByTask).length === 0) { + return items.length; + } + + // Fetch cached runs for each task identifier separately to make use of the index + const cachedRuns = await Promise.all( + Object.entries(itemsByTask).map(([taskIdentifier, taskItems]) => + this._prisma.taskRun.findMany({ + where: { + runtimeEnvironmentId: environment.id, + taskIdentifier, + idempotencyKey: { + in: taskItems.map((i) => i.options?.idempotencyKey).filter(Boolean), + }, + }, + select: { + idempotencyKey: true, + idempotencyKeyExpiresAt: true, + }, + }) + ) + ).then((results) => results.flat()); + + // Create a Map for O(1) lookups instead of O(m) find operations + const cachedRunsMap = new Map(cachedRuns.map((run) => [run.idempotencyKey, run])); + + // 
Count items that are NOT cached (or have expired cache) + let newRunCount = 0; + const now = new Date(); + + for (const item of items) { + const idempotencyKey = item.options?.idempotencyKey; + + if (!idempotencyKey) { + // No idempotency key = always a new run + newRunCount++; + continue; + } + + const cachedRun = cachedRunsMap.get(idempotencyKey); + + if (!cachedRun) { + // No cached run = new run + newRunCount++; + } else if (cachedRun.idempotencyKeyExpiresAt && cachedRun.idempotencyKeyExpiresAt < now) { + // Expired cached run = new run + newRunCount++; + } + // else: valid cached run = not a new run + } + + return newRunCount; + } } diff --git a/apps/webapp/app/runEngine/services/triggerTask.server.ts b/apps/webapp/app/runEngine/services/triggerTask.server.ts index 5ec3f29dd8..4eece3b939 100644 --- a/apps/webapp/app/runEngine/services/triggerTask.server.ts +++ b/apps/webapp/app/runEngine/services/triggerTask.server.ts @@ -123,13 +123,35 @@ export class RunEngineTriggerTaskService { throw tagValidation.error; } - // Validate entitlement - const entitlementValidation = await this.validator.validateEntitlement({ - environment, - }); + // Validate entitlement (unless skipChecks is enabled) + let planType: string | undefined; + + if (!options.skipChecks) { + const entitlementValidation = await this.validator.validateEntitlement({ + environment, + }); - if (!entitlementValidation.ok) { - throw entitlementValidation.error; + if (!entitlementValidation.ok) { + throw entitlementValidation.error; + } + + // Extract plan type from entitlement response + planType = entitlementValidation.plan?.type; + } else { + // When skipChecks is enabled, planType should be passed via options + planType = options.planType; + + if (!planType) { + logger.warn("Plan type not set but skipChecks is enabled", { + taskId, + environment: { + id: environment.id, + type: environment.type, + projectId: environment.projectId, + organizationId: environment.organizationId, + }, + }); + } } const 
[parseDelayError, delayUntil] = await tryCatch(parseDelay(body.options?.delay)); @@ -313,6 +335,7 @@ export class RunEngineTriggerTaskService { scheduleInstanceId: options.scheduleInstanceId, createdAt: options.overrideCreatedAt, bulkActionId: body.options?.bulkActionId, + planType, }, this.prisma ); diff --git a/apps/webapp/app/runEngine/types.ts b/apps/webapp/app/runEngine/types.ts index 40a70678e0..b1aa8b7715 100644 --- a/apps/webapp/app/runEngine/types.ts +++ b/apps/webapp/app/runEngine/types.ts @@ -1,12 +1,7 @@ -import { BackgroundWorker, TaskRun } from "@trigger.dev/database"; - -import { - IOPacket, - RunChainState, - TaskRunError, - TriggerTaskRequestBody, -} from "@trigger.dev/core/v3"; -import { AuthenticatedEnvironment } from "~/services/apiAuth.server"; +import type { BackgroundWorker, TaskRun } from "@trigger.dev/database"; +import type { IOPacket, TaskRunError, TriggerTaskRequestBody } from "@trigger.dev/core/v3"; +import type { AuthenticatedEnvironment } from "~/services/apiAuth.server"; +import type { ReportUsagePlan } from "@trigger.dev/platform"; export type TriggerTaskServiceOptions = { idempotencyKey?: string; @@ -22,6 +17,7 @@ export type TriggerTaskServiceOptions = { skipChecks?: boolean; oneTimeUseToken?: string; overrideCreatedAt?: Date; + planType?: string; }; // domain/triggerTask.ts @@ -66,7 +62,10 @@ export interface QueueManager { lockedBackgroundWorker?: LockedBackgroundWorker ): Promise; getQueueName(request: TriggerTaskRequest): Promise; - validateQueueLimits(env: AuthenticatedEnvironment): Promise; + validateQueueLimits( + env: AuthenticatedEnvironment, + itemsToAdd?: number + ): Promise; getWorkerQueue( env: AuthenticatedEnvironment, regionOverride?: string @@ -112,9 +111,19 @@ export type ValidationResult = error: Error; }; +export type EntitlementValidationResult = + | { + ok: true; + plan?: ReportUsagePlan; + } + | { + ok: false; + error: Error; + }; + export interface TriggerTaskValidator { validateTags(params: 
TagValidationParams): ValidationResult; - validateEntitlement(params: EntitlementValidationParams): Promise; + validateEntitlement(params: EntitlementValidationParams): Promise; validateMaxAttempts(params: MaxAttemptsValidationParams): ValidationResult; validateParentRun(params: ParentRunValidationParams): ValidationResult; } diff --git a/apps/webapp/app/runEngine/validators/triggerTaskValidator.ts b/apps/webapp/app/runEngine/validators/triggerTaskValidator.ts index e63bdacfb5..93eb22258c 100644 --- a/apps/webapp/app/runEngine/validators/triggerTaskValidator.ts +++ b/apps/webapp/app/runEngine/validators/triggerTaskValidator.ts @@ -4,8 +4,9 @@ import { getEntitlement } from "~/services/platform.v3.server"; import { MAX_ATTEMPTS, OutOfEntitlementError } from "~/v3/services/triggerTask.server"; import { isFinalRunStatus } from "~/v3/taskStatus"; import { EngineServiceValidationError } from "../concerns/errors"; -import { +import type { EntitlementValidationParams, + EntitlementValidationResult, MaxAttemptsValidationParams, ParentRunValidationParams, TagValidationParams, @@ -37,7 +38,9 @@ export class DefaultTriggerTaskValidator implements TriggerTaskValidator { return { ok: true }; } - async validateEntitlement(params: EntitlementValidationParams): Promise { + async validateEntitlement( + params: EntitlementValidationParams + ): Promise { const { environment } = params; if (environment.type === "DEVELOPMENT") { @@ -53,7 +56,7 @@ export class DefaultTriggerTaskValidator implements TriggerTaskValidator { }; } - return { ok: true }; + return { ok: true, plan: result?.plan }; } validateMaxAttempts(params: MaxAttemptsValidationParams): ValidationResult { diff --git a/apps/webapp/app/services/platform.v3.server.ts b/apps/webapp/app/services/platform.v3.server.ts index 138fa287dc..1263b864e1 100644 --- a/apps/webapp/app/services/platform.v3.server.ts +++ b/apps/webapp/app/services/platform.v3.server.ts @@ -10,6 +10,8 @@ import { type MachineCode, type 
UpdateBillingAlertsRequest, type BillingAlertsResult, + type ReportUsageResult, + type ReportUsagePlan, } from "@trigger.dev/platform"; import { createCache, DefaultStatefulContext, Namespace } from "@unkey/cache"; import { MemoryStore } from "@unkey/cache/stores"; @@ -285,7 +287,8 @@ export async function setPlan( organization: { id: string; slug: string }, request: Request, callerPath: string, - plan: SetPlanBody + plan: SetPlanBody, + opts?: { invalidateBillingCache?: (orgId: string) => void } ) { if (!client) { throw redirectWithErrorMessage(callerPath, request, "Error setting plan"); @@ -308,6 +311,8 @@ export async function setPlan( } case "free_connected": { if (result.accepted) { + // Invalidate billing cache since plan changed + opts?.invalidateBillingCache?.(organization.id); return redirect(newProjectPath(organization, "You're on the Free plan.")); } else { return redirectWithErrorMessage( @@ -321,6 +326,8 @@ export async function setPlan( return redirect(result.checkoutUrl); } case "updated_subscription": { + // Invalidate billing cache since subscription changed + opts?.invalidateBillingCache?.(organization.id); return redirectWithSuccessMessage( callerPath, request, @@ -328,6 +335,8 @@ export async function setPlan( ); } case "canceled_subscription": { + // Invalidate billing cache since subscription was canceled + opts?.invalidateBillingCache?.(organization.id); return redirectWithSuccessMessage(callerPath, request, "Subscription canceled."); } } @@ -425,7 +434,9 @@ export async function reportComputeUsage(request: Request) { }); } -export async function getEntitlement(organizationId: string) { +export async function getEntitlement( + organizationId: string +): Promise { if (!client) return undefined; try { diff --git a/apps/webapp/app/utils/boolEnv.ts b/apps/webapp/app/utils/boolEnv.ts index a2609034e3..824292447e 100644 --- a/apps/webapp/app/utils/boolEnv.ts +++ b/apps/webapp/app/utils/boolEnv.ts @@ -1,9 +1,14 @@ import { z } from "zod"; -export 
const BoolEnv = z.preprocess((val) => { +const baseBoolEnv = z.preprocess((val) => { if (typeof val !== "string") { return val; } return ["true", "1"].includes(val.toLowerCase().trim()); }, z.boolean()); + +// Create a type-safe version that only accepts boolean defaults +export const BoolEnv = baseBoolEnv as Omit & { + default: (value: boolean) => z.ZodDefault; +}; diff --git a/apps/webapp/app/v3/runEngine.server.ts b/apps/webapp/app/v3/runEngine.server.ts index 30f344e724..730e156e0d 100644 --- a/apps/webapp/app/v3/runEngine.server.ts +++ b/apps/webapp/app/v3/runEngine.server.ts @@ -1,7 +1,7 @@ import { RunEngine } from "@internal/run-engine"; import { $replica, prisma } from "~/db.server"; import { env } from "~/env.server"; -import { defaultMachine } from "~/services/platform.v3.server"; +import { defaultMachine, getCurrentPlan } from "~/services/platform.v3.server"; import { singleton } from "~/utils/singleton"; import { allMachines } from "./machinePresets.server"; import { meter, tracer } from "./tracer.server"; @@ -105,6 +105,30 @@ function createRunEngine() { SUSPENDED: env.RUN_ENGINE_TIMEOUT_SUSPENDED, }, retryWarmStartThresholdMs: env.RUN_ENGINE_RETRY_WARM_START_THRESHOLD_MS, + billing: { + getCurrentPlan: async (orgId: string) => { + const plan = await getCurrentPlan(orgId); + + if (!plan) { + return { + isPaying: false, + type: "free", + }; + } + + if (!plan.v3Subscription) { + return { + isPaying: false, + type: "free", + }; + } + + return { + isPaying: plan.v3Subscription.isPaying, + type: plan.v3Subscription.plan?.type ?? 
"free", + }; + }, + }, }); return engine; diff --git a/apps/webapp/app/v3/services/triggerTask.server.ts b/apps/webapp/app/v3/services/triggerTask.server.ts index bfb31e2499..f2e0d3c08a 100644 --- a/apps/webapp/app/v3/services/triggerTask.server.ts +++ b/apps/webapp/app/v3/services/triggerTask.server.ts @@ -33,6 +33,7 @@ export type TriggerTaskServiceOptions = { queueTimestamp?: Date; overrideCreatedAt?: Date; replayedFromTaskRunFriendlyId?: string; + planType?: string; }; export class OutOfEntitlementError extends Error { diff --git a/apps/webapp/package.json b/apps/webapp/package.json index d01407d17d..b6341aca5f 100644 --- a/apps/webapp/package.json +++ b/apps/webapp/package.json @@ -113,7 +113,7 @@ "@trigger.dev/core": "workspace:*", "@trigger.dev/database": "workspace:*", "@trigger.dev/otlp-importer": "workspace:*", - "@trigger.dev/platform": "1.0.17", + "@trigger.dev/platform": "1.0.18", "@trigger.dev/redis-worker": "workspace:*", "@trigger.dev/sdk": "workspace:*", "@types/pg": "8.6.6", diff --git a/internal-packages/cache/src/index.ts b/internal-packages/cache/src/index.ts index e5844d910b..d378191d29 100644 --- a/internal-packages/cache/src/index.ts +++ b/internal-packages/cache/src/index.ts @@ -3,6 +3,8 @@ export { DefaultStatefulContext, Namespace, type Cache as UnkeyCache, + type CacheError, } from "@unkey/cache"; +export { type Result, Ok, Err } from "@unkey/error"; export { MemoryStore } from "@unkey/cache/stores"; export { RedisCacheStore } from "./stores/redis.js"; diff --git a/internal-packages/database/prisma/migrations/20250814092224_add_task_run_plan_type/migration.sql b/internal-packages/database/prisma/migrations/20250814092224_add_task_run_plan_type/migration.sql new file mode 100644 index 0000000000..2513c02622 --- /dev/null +++ b/internal-packages/database/prisma/migrations/20250814092224_add_task_run_plan_type/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "TaskRun" ADD COLUMN "planType" TEXT; \ No newline at end of file diff 
--git a/internal-packages/database/prisma/schema.prisma b/internal-packages/database/prisma/schema.prisma index ba69f0f04f..1f0397904d 100644 --- a/internal-packages/database/prisma/schema.prisma +++ b/internal-packages/database/prisma/schema.prisma @@ -707,6 +707,9 @@ model TaskRun { /// Run error error Json? + /// Organization's billing plan type (cached for fallback when billing API fails) + planType String? + maxDurationInSeconds Int? @@unique([oneTimeUseToken]) diff --git a/internal-packages/run-engine/src/engine/billingCache.ts b/internal-packages/run-engine/src/engine/billingCache.ts new file mode 100644 index 0000000000..45fd3dc382 --- /dev/null +++ b/internal-packages/run-engine/src/engine/billingCache.ts @@ -0,0 +1,92 @@ +import { + createCache, + DefaultStatefulContext, + MemoryStore, + Namespace, + Ok, + RedisCacheStore, + type UnkeyCache, + type CacheError, + type Result, +} from "@internal/cache"; +import type { RedisOptions } from "@internal/redis"; +import type { Logger } from "@trigger.dev/core/logger"; +import type { RunEngineOptions } from "./types.js"; + +// Cache TTLs for billing information - shorter than other caches since billing can change +const BILLING_FRESH_TTL = 60000 * 5; // 5 minutes +const BILLING_STALE_TTL = 60000 * 10; // 10 minutes + +export type BillingPlan = { + isPaying: boolean; + type: "free" | "paid" | "enterprise"; +}; + +export type BillingCacheOptions = { + billingOptions?: RunEngineOptions["billing"]; + redisOptions: RedisOptions; + logger: Logger; +}; + +export class BillingCache { + private readonly cache: UnkeyCache<{ + currentPlan: BillingPlan; + }>; + private readonly logger: Logger; + private readonly billingOptions?: RunEngineOptions["billing"]; + + constructor(options: BillingCacheOptions) { + this.logger = options.logger; + this.billingOptions = options.billingOptions; + + // Initialize cache + const ctx = new DefaultStatefulContext(); + const memory = new MemoryStore({ persistentMap: new Map() }); + const 
redisCacheStore = new RedisCacheStore({ + name: "billing-cache", + connection: { + ...options.redisOptions, + keyPrefix: "engine:billing:cache:", + }, + useModernCacheKeyBuilder: true, + }); + + this.cache = createCache({ + currentPlan: new Namespace(ctx, { + stores: [memory, redisCacheStore], + fresh: BILLING_FRESH_TTL, + stale: BILLING_STALE_TTL, + }), + }); + } + + /** + * Gets the current billing plan for an organization + * Returns a Result that allows the caller to handle errors and missing values + */ + async getCurrentPlan(orgId: string): Promise> { + if (!this.billingOptions?.getCurrentPlan) { + // Return a successful result with default free plan + return Ok({ isPaying: false, type: "free" }); + } + + return await this.cache.currentPlan.swr(orgId, async () => { + // This is safe because options can't change at runtime + const planResult = await this.billingOptions!.getCurrentPlan(orgId); + return { isPaying: planResult.isPaying, type: planResult.type }; + }); + } + + /** + * Invalidates the billing cache for an organization when their plan changes + * Runs in background and handles all errors internally + */ + invalidate(orgId: string): void { + this.cache.currentPlan.remove(orgId).catch((error) => { + this.logger.warn("Failed to invalidate billing cache", { + orgId, + error: error instanceof Error ? 
error.message : String(error), + }); + }); + } +} diff --git a/internal-packages/run-engine/src/engine/index.ts b/internal-packages/run-engine/src/engine/index.ts index 20ef7d53e8..f71c62dcd3 100644 --- a/internal-packages/run-engine/src/engine/index.ts +++ b/internal-packages/run-engine/src/engine/index.ts @@ -1,3 +1,4 @@ +import { BillingCache } from "./billingCache.js"; import { createRedisClient, Redis } from "@internal/redis"; import { getMeter, Meter, startSpan, trace, Tracer } from "@internal/tracing"; import { Logger } from "@trigger.dev/core/logger"; @@ -78,6 +79,8 @@ export class RunEngine { pendingVersionSystem: PendingVersionSystem; raceSimulationSystem: RaceSimulationSystem = new RaceSimulationSystem(); + private readonly billingCache: BillingCache; + constructor(private readonly options: RunEngineOptions) { this.logger = options.logger ?? new Logger("RunEngine", this.options.logLevel ?? "info"); this.prisma = options.prisma; @@ -292,11 +295,18 @@ export class RunEngine { redisOptions: this.options.cache?.redis ?? this.options.runLock.redis, }); + this.billingCache = new BillingCache({ + billingOptions: this.options.billing, + redisOptions: this.options.cache?.redis ?? this.options.runLock.redis, + logger: this.logger, + }); + this.dequeueSystem = new DequeueSystem({ resources, executionSnapshotSystem: this.executionSnapshotSystem, runAttemptSystem: this.runAttemptSystem, machines: this.options.machines, + billingCache: this.billingCache, }); } @@ -354,6 +364,7 @@ export class RunEngine { scheduleInstanceId, createdAt, bulkActionId, + planType, }: TriggerParams, tx?: PrismaClientOrTransaction ): Promise { @@ -429,6 +440,7 @@ export class RunEngine { scheduleInstanceId, createdAt, bulkActionGroupIds: bulkActionId ? 
[bulkActionId] : undefined, + planType, executionSnapshots: { create: { engine: "V2", @@ -1347,4 +1359,12 @@ export class RunEngine { orgId: run.organizationId!, })); } + + /** + * Invalidates the billing cache for an organization when their plan changes + * Runs in background and handles all errors internally + */ + invalidateBillingCache(orgId: string): void { + this.billingCache.invalidate(orgId); + } } diff --git a/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts b/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts index 0e1319a5d6..4e41e8fc6f 100644 --- a/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts +++ b/internal-packages/run-engine/src/engine/systems/dequeueSystem.ts @@ -1,6 +1,7 @@ +import type { BillingCache } from "../billingCache.js"; import { startSpan } from "@internal/tracing"; import { assertExhaustive } from "@trigger.dev/core"; -import { DequeuedMessage, RetryOptions } from "@trigger.dev/core/v3"; +import { DequeuedMessage, RetryOptions, placementTag } from "@trigger.dev/core/v3"; import { getMaxDuration } from "@trigger.dev/core/v3/isomorphic"; import { PrismaClientOrTransaction } from "@trigger.dev/database"; import { getRunWithBackgroundWorkerTasks } from "../db/worker.js"; @@ -17,6 +18,7 @@ export type DequeueSystemOptions = { machines: RunEngineOptions["machines"]; executionSnapshotSystem: ExecutionSnapshotSystem; runAttemptSystem: RunAttemptSystem; + billingCache: BillingCache; }; export class DequeueSystem { @@ -380,6 +382,30 @@ export class DequeueSystem { const currentAttemptNumber = lockedTaskRun.attemptNumber ?? 
0; const nextAttemptNumber = currentAttemptNumber + 1; + // Get billing information if available, with fallback to TaskRun.planType + const billingResult = await this.options.billingCache.getCurrentPlan(orgId); + + let isPaying: boolean; + if (billingResult.err || !billingResult.val) { + // Fallback to stored planType on TaskRun if billing cache fails or returns no value + this.$.logger.warn( + "Billing cache failed or returned no value, falling back to TaskRun.planType", + { + orgId, + runId, + error: + billingResult.err instanceof Error + ? billingResult.err.message + : String(billingResult.err), + currentPlan: billingResult.val, + } + ); + + isPaying = (lockedTaskRun.planType ?? "free") !== "free"; + } else { + isPaying = billingResult.val.isPaying; + } + const newSnapshot = await this.executionSnapshotSystem.createExecutionSnapshot( prisma, { @@ -448,6 +474,7 @@ export class DequeueSystem { project: { id: lockedTaskRun.projectId, }, + placementTags: [placementTag("paid", isPaying ? "true" : "false")], } satisfies DequeuedMessage; } ); diff --git a/internal-packages/run-engine/src/engine/types.ts b/internal-packages/run-engine/src/engine/types.ts index 7f22b6770d..9abf2acfa2 100644 --- a/internal-packages/run-engine/src/engine/types.ts +++ b/internal-packages/run-engine/src/engine/types.ts @@ -5,7 +5,6 @@ import { MachinePreset, MachinePresetName, RetryOptions, - RunChainState, TriggerTraceContext, } from "@trigger.dev/core/v3"; import { PrismaClient, PrismaReplicaClient } from "@trigger.dev/database"; @@ -14,6 +13,7 @@ import { FairQueueSelectionStrategyOptions } from "../run-queue/fairQueueSelecti import { MinimalAuthenticatedEnvironment } from "../shared/index.js"; import { LockRetryConfig } from "./locking.js"; import { workerCatalog } from "./workerCatalog.js"; +import { type BillingPlan } from "./billingCache.js"; export type RunEngineOptions = { prisma: PrismaClient; @@ -30,6 +30,9 @@ export type RunEngineOptions = { machines: Record; baseCostInCents: 
number; }; + billing?: { + getCurrentPlan: (orgId: string) => Promise; + }; queue: { redis: RedisOptions; shardCount?: number; @@ -133,6 +136,7 @@ export type TriggerParams = { scheduleInstanceId?: string; createdAt?: Date; bulkActionId?: string; + planType?: string; }; export type EngineWorker = Worker; diff --git a/packages/core/src/v3/schemas/runEngine.ts b/packages/core/src/v3/schemas/runEngine.ts index ef6ef170ce..5bfd8fe1d7 100644 --- a/packages/core/src/v3/schemas/runEngine.ts +++ b/packages/core/src/v3/schemas/runEngine.ts @@ -224,6 +224,17 @@ export const DequeueMessageCheckpoint = z.object({ }); export type DequeueMessageCheckpoint = z.infer; +export const PlacementTag = z.object({ + key: z.string(), + values: z.array(z.string()).optional(), +}); +export type PlacementTag = z.infer; + +/** Helper functions for placement tags. In the future this will be able to support multiple values and operators. For now it's just a single value. */ +export function placementTag(key: string, value: string): PlacementTag { + return { key, values: [value] }; +} + /** This is sent to a Worker when a run is dequeued (a new run or continuing run) */ export const DequeuedMessage = z.object({ version: z.literal("1"), @@ -261,5 +272,6 @@ export const DequeuedMessage = z.object({ project: z.object({ id: z.string(), }), + placementTags: z.array(PlacementTag).optional(), }); export type DequeuedMessage = z.infer; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 15490dbedd..a2fe411c9f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -444,8 +444,8 @@ importers: specifier: workspace:* version: link:../../internal-packages/otlp-importer '@trigger.dev/platform': - specifier: 1.0.17 - version: 1.0.17 + specifier: 1.0.18 + version: 1.0.18 '@trigger.dev/redis-worker': specifier: workspace:* version: link:../../packages/redis-worker @@ -19779,8 +19779,8 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /@trigger.dev/platform@1.0.17: - resolution: {integrity: 
sha512-cR05nn8HnP03h/bmRN6O/EKgvQncbs3Y/7fp1QboEDWn6rJTRrWJpZVrA3ZQ32SIW1qvHuZLcB1OVaEsJk2wjA==} + /@trigger.dev/platform@1.0.18: + resolution: {integrity: sha512-7huIRYY9+QzoV9b8lIr7GGLhLSrt2mu/LX+aENO2Jch8C0SAKuztBdJk/zi9NXYhmQzkpS2ASWGukf4qOAIwXg==} dependencies: zod: 3.23.8 dev: false