diff --git a/.changeset/fifty-crabs-arrive.md b/.changeset/fifty-crabs-arrive.md new file mode 100644 index 000000000..8a3c84ea4 --- /dev/null +++ b/.changeset/fifty-crabs-arrive.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": patch +--- + +you can now call stagehand.metrics to get token usage metrics. you can also set logInferenceToFile in stagehand config to log the entire call/response history from stagehand & the LLM. diff --git a/lib/StagehandPage.ts b/lib/StagehandPage.ts index 2d581ac00..f7f211cbf 100644 --- a/lib/StagehandPage.ts +++ b/lib/StagehandPage.ts @@ -89,6 +89,7 @@ export class StagehandPage { if (this.llmClient) { this.actHandler = new StagehandActHandler({ + stagehand: this.stagehand, verbose: this.stagehand.verbose, llmProvider: this.stagehand.llmProvider, enableCaching: this.stagehand.enableCaching, diff --git a/lib/handlers/actHandler.ts b/lib/handlers/actHandler.ts index 0c135d0a3..c6da68138 100644 --- a/lib/handlers/actHandler.ts +++ b/lib/handlers/actHandler.ts @@ -16,6 +16,7 @@ import { ObserveResult, ActOptions, ObserveOptions, + StagehandFunctionName, } from "@/types/stagehand"; import { MethodHandlerContext, SupportedPlaywrightAction } from "@/types/act"; import { buildActObservePrompt } from "../prompt"; @@ -23,13 +24,14 @@ import { methodHandlerMap, fallbackLocatorMethod, } from "./handlerUtils/actHandlerUtils"; - +import { Stagehand } from "@/lib"; /** * NOTE: Vision support has been removed from this version of Stagehand. * If useVision or verifierUseVision is set to true, a warning is logged and * the flow continues as if vision = false. */ export class StagehandActHandler { + private readonly stagehand: Stagehand; private readonly stagehandPage: StagehandPage; private readonly verbose: 0 | 1 | 2; private readonly llmProvider: LLMProvider; @@ -44,6 +46,7 @@ export class StagehandActHandler { private readonly waitForCaptchaSolves: boolean; constructor({ + stagehand, verbose, llmProvider, enableCaching, @@ -53,6 +56,7 @@ export class StagehandActHandler { selfHeal, waitForCaptchaSolves, }: { + stagehand: Stagehand; verbose: 0 | 1 | 2; llmProvider: LLMProvider; enableCaching: boolean; @@ -64,6 +68,7 @@ export class StagehandActHandler { selfHeal: boolean; waitForCaptchaSolves: boolean; }) { + this.stagehand = stagehand; this.verbose = verbose; this.llmProvider = llmProvider; this.enableCaching = enableCaching; @@ -337,7 +342,7 @@ export class StagehandActHandler { }); // Always use text-based DOM verification (no vision). 
- actionCompleted = await verifyActCompletion({ + const verifyResult = await verifyActCompletion({ goal: action, steps, llmProvider: this.llmProvider, @@ -345,7 +350,9 @@ export class StagehandActHandler { domElements, logger: this.logger, requestId, + logInferenceToFile: this.stagehand.logInferenceToFile, }); + actionCompleted = verifyResult.completed; this.logger({ category: "action", @@ -362,6 +369,12 @@ export class StagehandActHandler { }, }, }); + this.stagehand.updateMetrics( + StagehandFunctionName.ACT, + verifyResult.prompt_tokens, + verifyResult.completion_tokens, + verifyResult.inference_time_ms, + ); } return actionCompleted; @@ -681,6 +694,15 @@ export class StagehandActHandler { requestId, variables, userProvidedInstructions: this.userProvidedInstructions, + onActMetrics: (promptTokens, completionTokens, inferenceTimeMs) => { + this.stagehand.updateMetrics( + StagehandFunctionName.ACT, + promptTokens, + completionTokens, + inferenceTimeMs, + ); + }, + logInferenceToFile: this.stagehand.logInferenceToFile, }); this.logger({ diff --git a/lib/handlers/extractHandler.ts b/lib/handlers/extractHandler.ts index b66ebc19c..d57ef5242 100644 --- a/lib/handlers/extractHandler.ts +++ b/lib/handlers/extractHandler.ts @@ -5,7 +5,7 @@ import { extract } from "../inference"; import { LLMClient } from "../llm/LLMClient"; import { formatText } from "../utils"; import { StagehandPage } from "../StagehandPage"; -import { Stagehand } from "../index"; +import { Stagehand, StagehandFunctionName } from "../index"; import { pageTextSchema } from "../../types/page"; const PROXIMITY_THRESHOLD = 15; @@ -353,13 +353,24 @@ export class StagehandExtractHandler { requestId, userProvidedInstructions: this.userProvidedInstructions, logger: this.logger, + logInferenceToFile: this.stagehand.logInferenceToFile, }); const { metadata: { completed }, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + inference_time_ms: inferenceTimeMs, ...output } = extractionResponse; + this.stagehand.updateMetrics( + StagehandFunctionName.EXTRACT, + promptTokens, + completionTokens, + inferenceTimeMs, + ); + // **11:** Handle the extraction response and log the results this.logger({ category: "extraction", @@ -481,13 +492,24 @@ export class StagehandExtractHandler { isUsingTextExtract: false, userProvidedInstructions: this.userProvidedInstructions, logger: this.logger, + logInferenceToFile: this.stagehand.logInferenceToFile, }); const { metadata: { completed }, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + inference_time_ms: inferenceTimeMs, ...output } = extractionResponse; + this.stagehand.updateMetrics( + StagehandFunctionName.EXTRACT, + promptTokens, + completionTokens, + inferenceTimeMs, + ); + this.logger({ category: "extraction", message: "received extraction response", diff --git a/lib/handlers/observeHandler.ts b/lib/handlers/observeHandler.ts index 6e42dccfe..38e59ea84 100644 --- a/lib/handlers/observeHandler.ts +++ b/lib/handlers/observeHandler.ts @@ -1,5 +1,5 @@ import { LogLine } from "../../types/log"; -import { Stagehand } from "../index"; +import { Stagehand, StagehandFunctionName } from "../index"; import { observe } from "../inference"; import { LLMClient } from "../llm/LLMClient"; import { StagehandPage } from "../StagehandPage"; @@ -113,8 +113,22 @@ export class StagehandObserveHandler { logger: this.logger, isUsingAccessibilityTree: useAccessibilityTree, returnAction, + logInferenceToFile: this.stagehand.logInferenceToFile, }); + const { + prompt_tokens = 0, + 
completion_tokens = 0, + inference_time_ms = 0, + } = observationResponse; + + this.stagehand.updateMetrics( + StagehandFunctionName.OBSERVE, + prompt_tokens, + completion_tokens, + inference_time_ms, + ); + //Add iframes to the observation response if there are any on the page if (iframes.length > 0) { iframes.forEach((iframe) => { diff --git a/lib/index.ts b/lib/index.ts index d35782d36..3cdcf2d11 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -26,6 +26,8 @@ import { ObserveOptions, ObserveResult, AgentConfig, + StagehandMetrics, + StagehandFunctionName, } from "../types/stagehand"; import { StagehandContext } from "./StagehandContext"; import { StagehandPage } from "./StagehandPage"; @@ -382,7 +384,7 @@ export class Stagehand { public readonly selfHeal: boolean; private cleanupCalled = false; public readonly actTimeoutMs: number; - + public readonly logInferenceToFile?: boolean; protected setActivePage(page: StagehandPage): void { this.stagehandPage = page; } @@ -396,6 +398,63 @@ export class Stagehand { return this.stagehandPage.page; } + public stagehandMetrics: StagehandMetrics = { + actPromptTokens: 0, + actCompletionTokens: 0, + actInferenceTimeMs: 0, + extractPromptTokens: 0, + extractCompletionTokens: 0, + extractInferenceTimeMs: 0, + observePromptTokens: 0, + observeCompletionTokens: 0, + observeInferenceTimeMs: 0, + totalPromptTokens: 0, + totalCompletionTokens: 0, + totalInferenceTimeMs: 0, + }; + + public get metrics(): StagehandMetrics { + return this.stagehandMetrics; + } + + public updateMetrics( + functionName: StagehandFunctionName, + promptTokens: number, + completionTokens: number, + inferenceTimeMs: number, + ): void { + switch (functionName) { + case StagehandFunctionName.ACT: + this.stagehandMetrics.actPromptTokens += promptTokens; + this.stagehandMetrics.actCompletionTokens += completionTokens; + this.stagehandMetrics.actInferenceTimeMs += inferenceTimeMs; + break; + + case StagehandFunctionName.EXTRACT: + this.stagehandMetrics.extractPromptTokens += promptTokens; + this.stagehandMetrics.extractCompletionTokens += completionTokens; + this.stagehandMetrics.extractInferenceTimeMs += inferenceTimeMs; + break; + + case StagehandFunctionName.OBSERVE: + this.stagehandMetrics.observePromptTokens += promptTokens; + this.stagehandMetrics.observeCompletionTokens += completionTokens; + this.stagehandMetrics.observeInferenceTimeMs += inferenceTimeMs; + break; + } + this.updateTotalMetrics(promptTokens, completionTokens, inferenceTimeMs); + } + + private updateTotalMetrics( + promptTokens: number, + completionTokens: number, + inferenceTimeMs: number, + ): void { + this.stagehandMetrics.totalPromptTokens += promptTokens; + this.stagehandMetrics.totalCompletionTokens += completionTokens; + this.stagehandMetrics.totalInferenceTimeMs += inferenceTimeMs; + } + constructor( { env, @@ -419,6 +478,7 @@ export class Stagehand { selfHeal = true, waitForCaptchaSolves = false, actTimeoutMs = 60_000, + logInferenceToFile = false, }: ConstructorParams = { env: "BROWSERBASE", }, @@ -473,6 +533,7 @@ export class Stagehand { if (this.usingAPI) { this.registerSignalHandlers(); } + this.logInferenceToFile = logInferenceToFile; } private registerSignalHandlers() { diff --git a/lib/inference.ts b/lib/inference.ts index ff0043a6f..100d576b8 100644 --- a/lib/inference.ts +++ b/lib/inference.ts @@ -2,7 +2,7 @@ import { z } from "zod"; import { ActCommandParams, ActCommandResult } from "../types/act"; import { VerifyActCompletionParams } from "../types/inference"; import { LogLine } from 
"../types/log"; -import { ChatMessage, LLMClient } from "./llm/LLMClient"; +import { ChatMessage, LLMClient, LLMResponse } from "./llm/LLMClient"; import { actTools, buildActSystemPrompt, @@ -18,6 +18,47 @@ import { buildVerifyActCompletionSystemPrompt, buildVerifyActCompletionUserPrompt, } from "./prompt"; +import { + appendSummary, + writeTimestampedTxtFile, +} from "@/lib/inferenceLogUtils"; + +/** + * Replaces <|VARIABLE|> placeholders in a text with user-provided values. + */ +export function fillInVariables( + text: string, + variables: Record, +) { + let processedText = text; + Object.entries(variables).forEach(([key, value]) => { + const placeholder = `<|${key.toUpperCase()}|>`; + processedText = processedText.replace(placeholder, value); + }); + return processedText; +} + +/** Simple usage shape if your LLM returns usage tokens. */ +interface LLMUsage { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +/** + * For calls that use a schema: the LLMClient may return { data: T; usage?: LLMUsage } + */ +interface LLMParsedResponse { + data: T; + usage?: LLMUsage; +} + +export interface VerifyActCompletionResult { + completed: boolean; + prompt_tokens: number; + completion_tokens: number; + inference_time_ms: number; +} export async function verifyActCompletion({ goal, @@ -26,61 +67,111 @@ export async function verifyActCompletion({ domElements, logger, requestId, -}: VerifyActCompletionParams): Promise { + logInferenceToFile = false, +}: VerifyActCompletionParams & { + logInferenceToFile?: boolean; +}): Promise { const verificationSchema = z.object({ completed: z.boolean().describe("true if the goal is accomplished"), }); - type VerificationResponse = z.infer; - const response = await llmClient.createChatCompletion({ - options: { - messages: [ - buildVerifyActCompletionSystemPrompt(), - buildVerifyActCompletionUserPrompt(goal, steps, domElements), - ], - temperature: 0.1, - top_p: 1, - frequency_penalty: 0, - presence_penalty: 0, - response_model: { - name: "Verification", - schema: verificationSchema, - }, + const messages: ChatMessage[] = [ + buildVerifyActCompletionSystemPrompt(), + buildVerifyActCompletionUserPrompt(goal, steps, domElements), + ]; + + let callFile = ""; + let callTimestamp = ""; + if (logInferenceToFile) { + const callResult = writeTimestampedTxtFile("act_summary", "verify_call", { requestId, - }, - logger, - }); + modelCall: "verifyActCompletion", + messages, + }); + callFile = callResult.fileName; + callTimestamp = callResult.timestamp; + } + + const start = Date.now(); + const rawResponse = + await llmClient.createChatCompletion({ + options: { + messages, + temperature: 0.1, + top_p: 1, + frequency_penalty: 0, + presence_penalty: 0, + response_model: { + name: "Verification", + schema: verificationSchema, + }, + requestId, + }, + logger, + }); + const end = Date.now(); + const inferenceTimeMs = end - start; + + const parsedResponse = rawResponse as LLMParsedResponse; + const verificationData = parsedResponse.data; + const verificationUsage = parsedResponse.usage; + + let responseFile = ""; + if (logInferenceToFile) { + const responseResult = writeTimestampedTxtFile( + "act_summary", + "verify_response", + { + requestId, + modelResponse: "verifyActCompletion", + rawResponse: verificationData, + }, + ); + responseFile = responseResult.fileName; - if (!response || typeof response !== "object") { + appendSummary("act", { + act_inference_type: "verifyActCompletion", + timestamp: callTimestamp, + LLM_input_file: callFile, + 
LLM_output_file: responseFile, + prompt_tokens: verificationUsage?.prompt_tokens ?? 0, + completion_tokens: verificationUsage?.completion_tokens ?? 0, + inference_time_ms: inferenceTimeMs, + }); + } + + if (!verificationData || typeof verificationData !== "object") { logger({ category: "VerifyAct", - message: "Unexpected response format: " + JSON.stringify(response), + message: "Unexpected response format: " + JSON.stringify(parsedResponse), }); - return false; + return { + completed: false, + prompt_tokens: verificationUsage?.prompt_tokens ?? 0, + completion_tokens: verificationUsage?.completion_tokens ?? 0, + inference_time_ms: inferenceTimeMs, + }; } - - if (response.completed === undefined) { + if (verificationData.completed === undefined) { logger({ category: "VerifyAct", message: "Missing 'completed' field in response", }); - return false; + return { + completed: false, + prompt_tokens: verificationUsage?.prompt_tokens ?? 0, + completion_tokens: verificationUsage?.completion_tokens ?? 0, + inference_time_ms: inferenceTimeMs, + }; } - return response.completed; -} - -export function fillInVariables( - text: string, - variables: Record, -) { - let processedText = text; - Object.entries(variables).forEach(([key, value]) => { - const placeholder = `<|${key.toUpperCase()}|>`; - processedText = processedText.replace(placeholder, value); - }); - return processedText; + return { + completed: verificationData.completed, + prompt_tokens: verificationUsage?.prompt_tokens ?? 0, + completion_tokens: verificationUsage?.completion_tokens ?? 0, + inference_time_ms: inferenceTimeMs, + }; } export async function act({ @@ -93,39 +184,95 @@ export async function act({ requestId, variables, userProvidedInstructions, -}: ActCommandParams): Promise { + onActMetrics, + logInferenceToFile = false, +}: ActCommandParams & { + onActMetrics?: ( + promptTokens: number, + completionTokens: number, + inferenceTimeMs: number, + ) => void; + logInferenceToFile?: boolean; +}): Promise { const messages: ChatMessage[] = [ buildActSystemPrompt(userProvidedInstructions), buildActUserPrompt(action, steps, domElements, variables), ]; - const response = await llmClient.createChatCompletion({ + let callFile = ""; + let callTimestamp = ""; + if (logInferenceToFile) { + const callResult = writeTimestampedTxtFile("act_summary", "act_call", { + requestId, + modelCall: "act", + messages, + }); + callFile = callResult.fileName; + callTimestamp = callResult.timestamp; + } + + const start = Date.now(); + const rawResponse = await llmClient.createChatCompletion({ options: { messages, temperature: 0.1, top_p: 1, frequency_penalty: 0, presence_penalty: 0, - tool_choice: "auto" as const, + tool_choice: "auto", tools: actTools, requestId, }, logger, }); + const end = Date.now(); + const inferenceTimeMs = end - start; + + let responseFile = ""; + if (logInferenceToFile) { + const responseResult = writeTimestampedTxtFile( + "act_summary", + "act_response", + { + requestId, + modelResponse: "act", + rawResponse, + }, + ); + responseFile = responseResult.fileName; + } + + const usageData = rawResponse.usage; + const promptTokens = usageData?.prompt_tokens ?? 0; + const completionTokens = usageData?.completion_tokens ?? 
0; - const toolCalls = response.choices[0].message.tool_calls; + if (logInferenceToFile) { + appendSummary("act", { + act_inference_type: "act", + timestamp: callTimestamp, + LLM_input_file: callFile, + LLM_output_file: responseFile, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + inference_time_ms: inferenceTimeMs, + }); + } + if (onActMetrics) { + onActMetrics(promptTokens, completionTokens, inferenceTimeMs); + } + + const toolCalls = rawResponse.choices?.[0]?.message?.tool_calls; if (toolCalls && toolCalls.length > 0) { if (toolCalls[0].function.name === "skipSection") { return null; } - return JSON.parse(toolCalls[0].function.arguments); } else { if (retries >= 2) { logger({ category: "Act", - message: "No tool calls found in response", + message: "No tool calls found in response after multiple retries.", }); return null; } @@ -138,6 +285,10 @@ export async function act({ retries: retries + 1, logger, requestId, + variables, + userProvidedInstructions, + onActMetrics, + logInferenceToFile, }); } } @@ -154,6 +305,7 @@ export async function extract({ logger, isUsingTextExtract, userProvidedInstructions, + logInferenceToFile = false, }: { instruction: string; previouslyExtractedContent: object; @@ -166,48 +318,129 @@ export async function extract({ isUsingTextExtract?: boolean; userProvidedInstructions?: string; logger: (message: LogLine) => void; + logInferenceToFile?: boolean; }) { + const metadataSchema = z.object({ + progress: z + .string() + .describe( + "progress of what has been extracted so far, as concise as possible", + ), + completed: z + .boolean() + .describe( + "true if the goal is now accomplished. Use this conservatively, only when sure that the goal has been completed.", + ), + }); + type ExtractionResponse = z.infer; type MetadataResponse = z.infer; - // TODO: antipattern + const isUsingAnthropic = llmClient.type === "anthropic"; - const extractionResponse = await llmClient.createChatCompletion({ - options: { - messages: [ - buildExtractSystemPrompt( - isUsingAnthropic, - isUsingTextExtract, - userProvidedInstructions, - ), - buildExtractUserPrompt(instruction, domElements, isUsingAnthropic), - ], - response_model: { - schema: schema, - name: "Extraction", + const extractCallMessages: ChatMessage[] = [ + buildExtractSystemPrompt( + isUsingAnthropic, + isUsingTextExtract, + userProvidedInstructions, + ), + buildExtractUserPrompt(instruction, domElements, isUsingAnthropic), + ]; + + let extractCallFile = ""; + let extractCallTimestamp = ""; + if (logInferenceToFile) { + const { fileName, timestamp } = writeTimestampedTxtFile( + "extract_summary", + "extract_call", + { + requestId, + modelCall: "extract", + messages: extractCallMessages, }, - temperature: 0.1, - top_p: 1, - frequency_penalty: 0, - presence_penalty: 0, - requestId, - }, - logger, - }); + ); + extractCallFile = fileName; + extractCallTimestamp = timestamp; + } + + const extractStartTime = Date.now(); + const extractionResponse = + await llmClient.createChatCompletion({ + options: { + messages: extractCallMessages, + response_model: { + schema, + name: "Extraction", + }, + temperature: 0.1, + top_p: 1, + frequency_penalty: 0, + presence_penalty: 0, + requestId, + }, + logger, + }); + const extractEndTime = Date.now(); + + const { data: extractedData, usage: extractUsage } = + extractionResponse as LLMParsedResponse; + + let extractResponseFile = ""; + if (logInferenceToFile) { + const { fileName } = writeTimestampedTxtFile( + "extract_summary", + "extract_response", + { + requestId, + 
modelResponse: "extract", + rawResponse: extractedData, + }, + ); + extractResponseFile = fileName; + + appendSummary("extract", { + extract_inference_type: "extract", + timestamp: extractCallTimestamp, + LLM_input_file: extractCallFile, + LLM_output_file: extractResponseFile, + prompt_tokens: extractUsage?.prompt_tokens ?? 0, + completion_tokens: extractUsage?.completion_tokens ?? 0, + inference_time_ms: extractEndTime - extractStartTime, + }); + } + + const refineCallMessages: ChatMessage[] = [ + buildRefineSystemPrompt(), + buildRefineUserPrompt( + instruction, + previouslyExtractedContent, + extractedData, + ), + ]; + let refineCallFile = ""; + let refineCallTimestamp = ""; + if (logInferenceToFile) { + const { fileName, timestamp } = writeTimestampedTxtFile( + "extract_summary", + "refine_call", + { + requestId, + modelCall: "refine", + messages: refineCallMessages, + }, + ); + refineCallFile = fileName; + refineCallTimestamp = timestamp; + } + + const refineStartTime = Date.now(); const refinedResponse = await llmClient.createChatCompletion({ options: { - messages: [ - buildRefineSystemPrompt(), - buildRefineUserPrompt( - instruction, - previouslyExtractedContent, - extractionResponse, - ), - ], + messages: refineCallMessages, response_model: { - schema: schema, + schema, name: "RefinedExtraction", }, temperature: 0.1, @@ -218,32 +451,66 @@ export async function extract({ }, logger, }); + const refineEndTime = Date.now(); - const metadataSchema = z.object({ - progress: z - .string() - .describe( - "progress of what has been extracted so far, as concise as possible", - ), - completed: z - .boolean() - .describe( - "true if the goal is now accomplished. Use this conservatively, only when you are sure that the goal has been completed.", - ), - }); + const { data: refinedResponseData, usage: refinedResponseUsage } = + refinedResponse as LLMParsedResponse; + let refineResponseFile = ""; + if (logInferenceToFile) { + const { fileName } = writeTimestampedTxtFile( + "extract_summary", + "refine_response", + { + requestId, + modelResponse: "refine", + rawResponse: refinedResponseData, + }, + ); + refineResponseFile = fileName; + + appendSummary("extract", { + extract_inference_type: "refine", + timestamp: refineCallTimestamp, + LLM_input_file: refineCallFile, + LLM_output_file: refineResponseFile, + prompt_tokens: refinedResponseUsage?.prompt_tokens ?? 0, + completion_tokens: refinedResponseUsage?.completion_tokens ?? 
0, + inference_time_ms: refineEndTime - refineStartTime, + }); + } + + const metadataCallMessages: ChatMessage[] = [ + buildMetadataSystemPrompt(), + buildMetadataPrompt( + instruction, + refinedResponseData, + chunksSeen, + chunksTotal, + ), + ]; + + let metadataCallFile = ""; + let metadataCallTimestamp = ""; + if (logInferenceToFile) { + const { fileName, timestamp } = writeTimestampedTxtFile( + "extract_summary", + "metadata_call", + { + requestId, + modelCall: "metadata", + messages: metadataCallMessages, + }, + ); + metadataCallFile = fileName; + metadataCallTimestamp = timestamp; + } + + const metadataStartTime = Date.now(); const metadataResponse = await llmClient.createChatCompletion({ options: { - messages: [ - buildMetadataSystemPrompt(), - buildMetadataPrompt( - instruction, - refinedResponse, - chunksSeen, - chunksTotal, - ), - ], + messages: metadataCallMessages, response_model: { name: "Metadata", schema: metadataSchema, @@ -256,10 +523,66 @@ export async function extract({ }, logger, }); + const metadataEndTime = Date.now(); + + const { + data: { + completed: metadataResponseCompleted, + progress: metadataResponseProgress, + }, + usage: metadataResponseUsage, + } = metadataResponse as LLMParsedResponse; + + let metadataResponseFile = ""; + if (logInferenceToFile) { + const { fileName } = writeTimestampedTxtFile( + "extract_summary", + "metadata_response", + { + requestId, + modelResponse: "metadata", + completed: metadataResponseCompleted, + progress: metadataResponseProgress, + }, + ); + metadataResponseFile = fileName; + + appendSummary("extract", { + extract_inference_type: "metadata", + timestamp: metadataCallTimestamp, + LLM_input_file: metadataCallFile, + LLM_output_file: metadataResponseFile, + prompt_tokens: metadataResponseUsage?.prompt_tokens ?? 0, + completion_tokens: metadataResponseUsage?.completion_tokens ?? 0, + inference_time_ms: metadataEndTime - metadataStartTime, + }); + } + + const totalPromptTokens = + (extractUsage?.prompt_tokens ?? 0) + + (refinedResponseUsage?.prompt_tokens ?? 0) + + (metadataResponseUsage?.prompt_tokens ?? 0); + + const totalCompletionTokens = + (extractUsage?.completion_tokens ?? 0) + + (refinedResponseUsage?.completion_tokens ?? 0) + + (metadataResponseUsage?.completion_tokens ?? 
0); + + const totalInferenceTimeMs = + extractEndTime - + extractStartTime + + (refineEndTime - refineStartTime) + + (metadataEndTime - metadataStartTime); return { - ...refinedResponse, - metadata: metadataResponse, + ...refinedResponseData, + metadata: { + completed: metadataResponseCompleted, + progress: metadataResponseProgress, + }, + prompt_tokens: totalPromptTokens, + completion_tokens: totalCompletionTokens, + inference_time_ms: totalInferenceTimeMs, }; } @@ -272,6 +595,7 @@ export async function observe({ userProvidedInstructions, logger, returnAction = false, + logInferenceToFile = false, }: { instruction: string; domElements: string; @@ -281,6 +605,7 @@ export async function observe({ logger: (message: LogLine) => void; isUsingAccessibilityTree?: boolean; returnAction?: boolean; + logInferenceToFile?: boolean; }) { const observeSchema = z.object({ elements: z @@ -321,49 +646,98 @@ export async function observe({ type ObserveResponse = z.infer; - const observationResponse = - await llmClient.createChatCompletion({ - options: { - messages: [ - buildObserveSystemPrompt( - userProvidedInstructions, - isUsingAccessibilityTree, - ), - buildObserveUserMessage( - instruction, - domElements, - isUsingAccessibilityTree, - ), - ], - response_model: { - schema: observeSchema, - name: "Observation", - }, - temperature: 0.1, - top_p: 1, - frequency_penalty: 0, - presence_penalty: 0, + const messages: ChatMessage[] = [ + buildObserveSystemPrompt( + userProvidedInstructions, + isUsingAccessibilityTree, + ), + buildObserveUserMessage(instruction, domElements, isUsingAccessibilityTree), + ]; + + let callTimestamp = ""; + let callFile = ""; + if (logInferenceToFile) { + const { fileName, timestamp } = writeTimestampedTxtFile( + "observe_summary", + "observe_call", + { requestId, + modelCall: "observe", + messages, }, - logger, + ); + callFile = fileName; + callTimestamp = timestamp; + } + + const start = Date.now(); + const rawResponse = await llmClient.createChatCompletion({ + options: { + messages, + response_model: { + schema: observeSchema, + name: "Observation", + }, + temperature: 0.1, + top_p: 1, + frequency_penalty: 0, + presence_penalty: 0, + requestId, + }, + logger, + }); + const end = Date.now(); + const usageTimeMs = end - start; + + const { data: observeData, usage: observeUsage } = + rawResponse as LLMParsedResponse; + const promptTokens = observeUsage?.prompt_tokens ?? 0; + const completionTokens = observeUsage?.completion_tokens ?? 0; + + let responseFile = ""; + if (logInferenceToFile) { + const { fileName: responseFileName } = writeTimestampedTxtFile( + "observe_summary", + "observe_response", + { + requestId, + modelResponse: "observe", + rawResponse: observeData, + }, + ); + responseFile = responseFileName; + + appendSummary("observe", { + observe_inference_type: "observe", + timestamp: callTimestamp, + LLM_input_file: callFile, + LLM_output_file: responseFile, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + inference_time_ms: usageTimeMs, }); - const parsedResponse = { - elements: - observationResponse.elements?.map((el) => { - const base = { - elementId: Number(el.elementId), - description: String(el.description), + } + + const parsedElements = + observeData.elements?.map((el) => { + const base = { + elementId: Number(el.elementId), + description: String(el.description), + }; + if (returnAction) { + return { + ...base, + method: String(el.method), + arguments: el.arguments, }; + } + return base; + }) ?? []; - return returnAction - ? 
{ - ...base, - method: String(el.method), - arguments: el.arguments, - } - : base; - }) ?? [], - } satisfies { elements: { elementId: number; description: string }[] }; - - return parsedResponse; + return { + elements: parsedElements, + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + inference_time_ms: usageTimeMs, + }; } diff --git a/lib/inferenceLogUtils.ts b/lib/inferenceLogUtils.ts new file mode 100644 index 000000000..4a5f8693b --- /dev/null +++ b/lib/inferenceLogUtils.ts @@ -0,0 +1,114 @@ +import path from "path"; +import fs from "fs"; + +/** + * Create (or ensure) a parent directory named "inference_summary". + */ +function ensureInferenceSummaryDir(): string { + const inferenceDir = path.join(process.cwd(), "inference_summary"); + if (!fs.existsSync(inferenceDir)) { + fs.mkdirSync(inferenceDir, { recursive: true }); + } + return inferenceDir; } + +/** + * Appends a new entry to the act_summary.json file, then writes the file back out. + */ +export function appendSummary<T>(inferenceType: string, entry: T) { + const summaryPath = getSummaryJsonPath(inferenceType); + const arrayKey = `${inferenceType}_summary`; + + const existingData = readSummaryFile(inferenceType); + existingData[arrayKey].push(entry); + + fs.writeFileSync(summaryPath, JSON.stringify(existingData, null, 2)); +} + +/** A simple timestamp utility for filenames. */ +function getTimestamp(): string { + return new Date() + .toISOString() + .replace(/[^0-9T]/g, "") + .replace("T", "_"); +} + +/** + * Writes `data` as JSON into a file in `directory`, using a prefix plus timestamp. + * Returns both the file name and the timestamp used, so you can log them. + */ +export function writeTimestampedTxtFile( + directory: string, + prefix: string, + data: unknown, +): { fileName: string; timestamp: string } { + const baseDir = ensureInferenceSummaryDir(); + + const subDir = path.join(baseDir, directory); + if (!fs.existsSync(subDir)) { + fs.mkdirSync(subDir, { recursive: true }); + } + + const timestamp = getTimestamp(); + const fileName = `${timestamp}_${prefix}.txt`; + const filePath = path.join(subDir, fileName); + + fs.writeFileSync( + filePath, + JSON.stringify(data, null, 2).replace(/\\n/g, "\n"), + ); + + return { fileName, timestamp }; +} + +/** + * Returns the path to the `<inferenceType>_summary.json` file. + * + * For example, if `inferenceType = "act"`, this will be: + * `./inference_summary/act_summary/act_summary.json` + */ +function getSummaryJsonPath(inferenceType: string): string { + const baseDir = ensureInferenceSummaryDir(); + const subDir = path.join(baseDir, `${inferenceType}_summary`); + if (!fs.existsSync(subDir)) { + fs.mkdirSync(subDir, { recursive: true }); + } + return path.join(subDir, `${inferenceType}_summary.json`); } + +/** + * Reads the `<inferenceType>_summary.json` file, returning an object + * with the top-level array named `<inferenceType>_summary`, if it exists. + * + * E.g. if inferenceType is "act", we expect a shape like: + * { + * "act_summary": [ ... ] + * } + * + * If the file or array is missing, returns { "<inferenceType>_summary": [] }. + */ +function readSummaryFile(inferenceType: string): Record<string, unknown[]> { + const summaryPath = getSummaryJsonPath(inferenceType); + + // The top-level array key, e.g.
"act_summary", "observe_summary", "extract_summary" + const arrayKey = `${inferenceType}_summary`; + + if (!fs.existsSync(summaryPath)) { + return { [arrayKey]: [] }; + } + + try { + const raw = fs.readFileSync(summaryPath, "utf8"); + const parsed = JSON.parse(raw); + if ( + parsed && + typeof parsed === "object" && + Array.isArray(parsed[arrayKey]) + ) { + return parsed; + } + } catch { + // If we fail to parse for any reason, fall back to empty array + } + return { [arrayKey]: [] }; +} diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 824238cbf..a7d97316d 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -179,7 +179,6 @@ export class AnthropicClient extends LLMClient { }, ], }; - if ( options.image.description && Array.isArray(screenshotMessage.content) @@ -254,6 +253,13 @@ export class AnthropicClient extends LLMClient { }, }); + // We'll compute usage data from the response + const usageData = { + prompt_tokens: response.usage.input_tokens, + completion_tokens: response.usage.output_tokens, + total_tokens: response.usage.input_tokens + response.usage.output_tokens, + }; + const transformedResponse: LLMResponse = { id: response.id, object: "chat.completion", @@ -280,12 +286,7 @@ export class AnthropicClient extends LLMClient { finish_reason: response.stop_reason, }, ], - usage: { - prompt_tokens: response.usage.input_tokens, - completion_tokens: response.usage.output_tokens, - total_tokens: - response.usage.input_tokens + response.usage.output_tokens, - }, + usage: usageData, }; logger({ @@ -308,11 +309,17 @@ export class AnthropicClient extends LLMClient { const toolUse = response.content.find((c) => c.type === "tool_use"); if (toolUse && "input" in toolUse) { const result = toolUse.input; + + const finalParsedResponse = { + data: result, + usage: usageData, + } as unknown as T; + if (this.enableCaching) { - this.cache.set(cacheOptions, result, options.requestId); + this.cache.set(cacheOptions, finalParsedResponse, options.requestId); } - return result as T; // anthropic returns this as `unknown`, so we need to cast + return finalParsedResponse; } else { if (!retries || retries < 5) { return this.createChatCompletion({ diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index 1f060a510..d893b69ec 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -92,7 +92,9 @@ export abstract class LLMClient { this.userProvidedInstructions = userProvidedInstructions; } - abstract createChatCompletion( - options: CreateChatCompletionOptions, - ): Promise; + abstract createChatCompletion< + T = LLMResponse & { + usage?: LLMResponse["usage"]; + }, + >(options: CreateChatCompletionOptions): Promise; } diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 4089e0085..ff8e8eff8 100644 --- a/lib/llm/OpenAIClient.ts +++ b/lib/llm/OpenAIClient.ts @@ -430,7 +430,10 @@ export class OpenAIClient extends LLMClient { ); } - return parsedData; + return { + data: parsedData, + usage: response.usage, + } as T; } if (this.enableCaching) { diff --git a/types/stagehand.ts b/types/stagehand.ts index 217bacf0d..9edd5b08b 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -43,6 +43,7 @@ export interface ConstructorParams { waitForCaptchaSolves?: boolean; localBrowserLaunchOptions?: LocalBrowserLaunchOptions; actTimeoutMs?: number; + logInferenceToFile?: boolean; } export interface InitOptions { @@ -172,6 +173,21 @@ export interface LocalBrowserLaunchOptions { cookies?: Cookie[]; } +export interface StagehandMetrics { + 
actPromptTokens: number; + actCompletionTokens: number; + actInferenceTimeMs: number; + extractPromptTokens: number; + extractCompletionTokens: number; + extractInferenceTimeMs: number; + observePromptTokens: number; + observeCompletionTokens: number; + observeInferenceTimeMs: number; + totalPromptTokens: number; + totalCompletionTokens: number; + totalInferenceTimeMs: number; +} + /** * Options for executing a task with an agent */ @@ -222,3 +238,9 @@ export interface AgentConfig { */ options?: Record<string, unknown>; } + +export enum StagehandFunctionName { + ACT = "ACT", + EXTRACT = "EXTRACT", + OBSERVE = "OBSERVE", +}
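
Usage note (illustrative, not part of the diff): the changeset above adds two user-facing pieces — the `stagehand.metrics` getter backed by the new `StagehandMetrics` interface, and the `logInferenceToFile` constructor option, which writes each LLM call/response pair plus a running `*_summary.json` under `inference_summary/`. Below is a minimal sketch of how they would be exercised, assuming the existing public Stagehand surface (`init()`, `stagehand.page`, `page.act()`, `page.extract()`, `close()`); the URL and instruction strings are placeholders only.

```ts
import { Stagehand } from "@browserbasehq/stagehand";
import { z } from "zod";

async function main() {
  // logInferenceToFile mirrors the new ConstructorParams field; when true,
  // inference calls and responses are written under ./inference_summary/.
  const stagehand = new Stagehand({ env: "LOCAL", logInferenceToFile: true });
  await stagehand.init();

  const page = stagehand.page;
  await page.goto("https://example.com"); // placeholder URL
  await page.act({ action: "click the first link" }); // placeholder instruction
  await page.extract({
    instruction: "extract the page title",
    schema: z.object({ title: z.string() }),
  });

  // Counters accumulate per function (act/extract/observe) plus totals,
  // per the StagehandMetrics interface introduced in this diff.
  const { actPromptTokens, totalPromptTokens, totalInferenceTimeMs } =
    stagehand.metrics;
  console.log({ actPromptTokens, totalPromptTokens, totalInferenceTimeMs });

  await stagehand.close();
}

main().catch(console.error);
```

Because every act/extract/observe inference funnels its token usage through `updateMetrics`, `stagehand.metrics` reflects cumulative usage for the whole session rather than only the most recent call.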