diff --git a/extensions/cli/src/commands/ls.ts b/extensions/cli/src/commands/ls.ts index d78ae162d80..529c711b899 100644 --- a/extensions/cli/src/commands/ls.ts +++ b/extensions/cli/src/commands/ls.ts @@ -1,10 +1,9 @@ import { render } from "ink"; import React from "react"; -import { getAccessToken, loadAuthConfig } from "../auth/workos.js"; -import { env } from "../env.js"; import { listSessions, loadSessionById } from "../session.js"; import { SessionSelector } from "../ui/SessionSelector.js"; +import { ApiRequestError, post } from "../util/apiClient.js"; import { logger } from "../util/logger.js"; import { chat } from "./chat.js"; @@ -27,26 +26,19 @@ function setSessionId(sessionId: string): void { } export async function getTunnelForAgent(agentId: string): Promise { - const authConfig = loadAuthConfig(); - const accessToken = getAccessToken(authConfig); - - const resp = await fetch( - new URL(`agents/${encodeURIComponent(agentId)}/tunnel`, env.apiBase), - { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${accessToken}`, - }, - }, - ); - if (!resp.ok) { - throw new Error( - `Failed to get tunnel for agent ${agentId}: ${await resp.text()}`, + try { + const response = await post<{ url: string }>( + `agents/${encodeURIComponent(agentId)}/tunnel`, ); + return response.data.url; + } catch (error) { + if (error instanceof ApiRequestError) { + throw new Error( + `Failed to get tunnel for agent ${agentId}: ${error.response || error.statusText}`, + ); + } + throw error; } - const data = await resp.json(); - return data.url; } /** diff --git a/extensions/cli/src/commands/remote.test.ts b/extensions/cli/src/commands/remote.test.ts index 625ac18a820..a8025a4d7d9 100644 --- a/extensions/cli/src/commands/remote.test.ts +++ b/extensions/cli/src/commands/remote.test.ts @@ -8,11 +8,13 @@ vi.mock("../env.js"); vi.mock("../telemetry/telemetryService.js"); vi.mock("../ui/index.js"); vi.mock("../util/git.js"); +vi.mock("../util/exit.js"); const mockWorkos = vi.mocked(await import("../auth/workos.js")); const mockEnv = vi.mocked(await import("../env.js")); const mockGit = vi.mocked(await import("../util/git.js")); const mockStartRemoteTUIChat = vi.mocked(await import("../ui/index.js")); +const mockExit = vi.mocked(await import("../util/exit.js")); // Mock fetch globally const mockFetch = vi.fn(); @@ -62,6 +64,10 @@ describe("remote command", () => { mockFetch.mockResolvedValue({ ok: true, + headers: { + get: (name: string) => + name === "content-type" ? "application/json" : null, + }, json: async () => ({ id: "test-agent-id", url: "ws://test-url.com", @@ -70,6 +76,9 @@ describe("remote command", () => { }); mockStartRemoteTUIChat.startRemoteTUIChat.mockResolvedValue({} as any); + + // Mock gracefulExit to prevent process.exit during tests + mockExit.gracefulExit.mockResolvedValue(undefined); }); it("should include idempotency key in request body when provided", async () => { @@ -151,10 +160,18 @@ describe("remote command", () => { mockFetch .mockResolvedValueOnce({ ok: true, + headers: { + get: (name: string) => + name === "content-type" ? "application/json" : null, + }, json: async () => ({ url: "ws://tunnel-url.com", port: 9090 }), }) .mockResolvedValue({ ok: true, + headers: { + get: (name: string) => + name === "content-type" ? "application/json" : null, + }, json: async () => ({ id: "test-agent-id", url: "ws://test-url.com", @@ -187,6 +204,10 @@ describe("remote command", () => { mockFetch.mockResolvedValueOnce({ ok: true, + headers: { + get: (name: string) => + name === "content-type" ? "application/json" : null, + }, json: async () => tunnelResponse, }); diff --git a/extensions/cli/src/commands/remote.ts b/extensions/cli/src/commands/remote.ts index 61110f2e14c..919c720fd91 100644 --- a/extensions/cli/src/commands/remote.ts +++ b/extensions/cli/src/commands/remote.ts @@ -1,9 +1,13 @@ import chalk from "chalk"; -import { getAccessToken, loadAuthConfig } from "../auth/workos.js"; import { env } from "../env.js"; import { telemetryService } from "../telemetry/telemetryService.js"; import { startRemoteTUIChat } from "../ui/index.js"; +import { + ApiRequestError, + AuthenticationRequiredError, + post, +} from "../util/apiClient.js"; import { gracefulExit } from "../util/exit.js"; import { getRepoUrl } from "../util/git.js"; import { logger } from "../util/logger.js"; @@ -28,12 +32,6 @@ type AgentCreationResponse = TunnelResponse & { id: string; }; -class AuthenticationRequiredError extends Error { - constructor() { - super("Not authenticated. Please run 'cn login' first."); - } -} - export async function remote( prompt: string | undefined, options: RemoteCommandOptions = {}, @@ -46,19 +44,12 @@ export async function remote( return; } - const accessToken = requireAccessToken(); - if (options.id) { - await connectExistingAgent( - options.id, - accessToken, - actualPrompt, - options.start, - ); + await connectExistingAgent(options.id, actualPrompt, options.start); return; } - await createAndConnectRemoteEnvironment(accessToken, actualPrompt, options); + await createAndConnectRemoteEnvironment(actualPrompt, options); } catch (error) { await handleRemoteError(error); } @@ -99,29 +90,12 @@ async function connectToRemoteUrl( await launchRemoteTUI(remoteUrl, prompt); } -function requireAccessToken(): string { - const authConfig = loadAuthConfig(); - - if (!authConfig) { - throw new AuthenticationRequiredError(); - } - - const accessToken = getAccessToken(authConfig); - - if (!accessToken) { - throw new AuthenticationRequiredError(); - } - - return accessToken; -} - async function connectExistingAgent( agentId: string, - accessToken: string, prompt: string | undefined, startOnly?: boolean, ) { - const tunnel = await fetchAgentTunnel(agentId, accessToken); + const tunnel = await fetchAgentTunnel(agentId); if (startOnly) { printStartJson({ @@ -142,30 +116,24 @@ async function connectExistingAgent( } async function createAndConnectRemoteEnvironment( - accessToken: string, prompt: string | undefined, options: RemoteCommandOptions, ) { const requestBody = buildAgentRequestBody(options, prompt); - const response = await fetch(new URL("agents", env.apiBase), { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${accessToken}`, - }, - body: JSON.stringify(requestBody), - }); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error( - `Failed to create remote environment: ${response.status} ${errorText}`, - ); + let result: AgentCreationResponse; + try { + const response = await post("agents", requestBody); + result = response.data; + } catch (error) { + if (error instanceof ApiRequestError) { + throw new Error( + `Failed to create remote environment: ${error.status} ${error.response || error.statusText}`, + ); + } + throw error; } - const result = (await response.json()) as AgentCreationResponse; - if (options.start) { printStartJson({ status: "success", @@ -216,26 +184,18 @@ function buildAgentRequestBody( return body; } -async function fetchAgentTunnel(agentId: string, accessToken: string) { - const response = await fetch( - new URL(`agents/${agentId}/tunnel`, env.apiBase), - { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${accessToken}`, - }, - }, - ); - - if (!response.ok) { - const errorText = await response.text(); - throw new Error( - `Failed to create tunnel for agent ${agentId}: ${response.status} ${errorText}`, - ); +async function fetchAgentTunnel(agentId: string) { + try { + const response = await post(`agents/${agentId}/tunnel`); + return response.data; + } catch (error) { + if (error instanceof ApiRequestError) { + throw new Error( + `Failed to create tunnel for agent ${agentId}: ${error.status} ${error.response || error.statusText}`, + ); + } + throw error; } - - return (await response.json()) as TunnelResponse; } async function launchRemoteTUI(remoteUrl: string, prompt: string | undefined) { diff --git a/extensions/cli/src/tools/index.tsx b/extensions/cli/src/tools/index.tsx index fda429563c0..36837f23375 100644 --- a/extensions/cli/src/tools/index.tsx +++ b/extensions/cli/src/tools/index.tsx @@ -23,6 +23,7 @@ import { multiEditTool } from "./multiEdit.js"; import { readFileTool } from "./readFile.js"; import { runTerminalCommandTool } from "./runTerminalCommand.js"; import { searchCodeTool } from "./searchCode.js"; +import { statusTool } from "./status.js"; import { type Tool, type ToolCall, @@ -70,6 +71,11 @@ function getDynamicTools(): Tool[] { // Service not ready yet, no dynamic tools } + // Add beta status tool if --beta-status-tool flag is present + if (process.argv.includes("--beta-status-tool")) { + dynamicTools.push(statusTool); + } + return dynamicTools; } diff --git a/extensions/cli/src/tools/status.ts b/extensions/cli/src/tools/status.ts new file mode 100644 index 00000000000..ff85dcd9372 --- /dev/null +++ b/extensions/cli/src/tools/status.ts @@ -0,0 +1,78 @@ +import { + ApiRequestError, + AuthenticationRequiredError, + post, +} from "../util/apiClient.js"; +import { logger } from "../util/logger.js"; + +import { Tool } from "./types.js"; + +/** + * Extract the agent ID from the --id command line flag + */ +function getAgentIdFromArgs(): string | undefined { + const args = process.argv; + const idIndex = args.indexOf("--id"); + if (idIndex !== -1 && idIndex + 1 < args.length) { + return args[idIndex + 1]; + } + return undefined; +} + +export const statusTool: Tool = { + name: "Status", + displayName: "Status", + description: `Set the current status of your task for the user to see + +The available statuses are: +- PLANNING: You are creating a plan before beginning implementation +- WORKING: The task is in progress +- DONE: The task is complete +- BLOCKED: You need further information from the user in order to proceed + +You should use this tool to notify the user whenever the state of your work changes. By default, the status is assumed to be "PLANNING" prior to you setting a different status.`, + parameters: { + type: "object", + required: ["status"], + properties: { + status: { + type: "string", + description: "The status value to set", + }, + }, + }, + readonly: true, + isBuiltIn: true, + run: async (args: { status: string }): Promise => { + try { + // Get agent ID from --id flag + const agentId = getAgentIdFromArgs(); + if (!agentId) { + const errorMessage = + "Agent ID is required. Please use the --id flag with cn serve."; + logger.error(errorMessage); + return `Error: ${errorMessage}`; + } + + // Call the API endpoint using shared client + await post(`agents/${agentId}/status`, { status: args.status }); + + logger.info(`Status: ${args.status}`); + return `Status set: ${args.status}`; + } catch (error) { + if (error instanceof AuthenticationRequiredError) { + logger.error(error.message); + return "Error: Authentication required"; + } + + if (error instanceof ApiRequestError) { + return `Error setting status: ${error.status} ${error.response || error.statusText}`; + } + + const errorMessage = + error instanceof Error ? error.message : String(error); + logger.error(`Error setting status: ${errorMessage}`); + return `Error setting status: ${errorMessage}`; + } + }, +}; diff --git a/extensions/cli/src/util/apiClient.test.ts b/extensions/cli/src/util/apiClient.test.ts new file mode 100644 index 00000000000..09bb12ed54c --- /dev/null +++ b/extensions/cli/src/util/apiClient.test.ts @@ -0,0 +1,283 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +import { + ApiRequestError, + AuthenticationRequiredError, + del, + get, + makeAuthenticatedRequest, + post, + put, +} from "./apiClient.js"; + +// Mock the dependencies +vi.mock("../auth/workos.js", () => ({ + loadAuthConfig: vi.fn(), + getAccessToken: vi.fn(), +})); + +vi.mock("../env.js", () => ({ + env: { + apiBase: "https://api.continue.dev", + }, +})); + +vi.mock("./logger.js", () => ({ + logger: { + debug: vi.fn(), + error: vi.fn(), + }, +})); + +// Mock fetch globally +global.fetch = vi.fn(); + +describe("apiClient", () => { + let mockLoadAuthConfig: any; + let mockGetAccessToken: any; + const mockFetch = vi.mocked(global.fetch); + + beforeEach(async () => { + vi.clearAllMocks(); + + // Get mocked functions + const authModule = await import("../auth/workos.js"); + mockLoadAuthConfig = vi.mocked(authModule.loadAuthConfig); + mockGetAccessToken = vi.mocked(authModule.getAccessToken); + + // Setup default successful authentication + mockLoadAuthConfig.mockReturnValue({ some: "config" }); + mockGetAccessToken.mockReturnValue("test-access-token"); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("makeAuthenticatedRequest", () => { + test("should make successful API request", async () => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: "OK", + headers: { get: vi.fn().mockReturnValue("application/json") }, + json: vi.fn().mockResolvedValue({ data: "test" }), + } as unknown as Response); + + const result = await makeAuthenticatedRequest("test-endpoint", { + method: "POST", + body: { key: "value" }, + }); + + expect(mockFetch).toHaveBeenCalledWith( + new URL("test-endpoint", "https://api.continue.dev"), + { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer test-access-token", + }, + body: JSON.stringify({ key: "value" }), + }, + ); + + expect(result).toEqual({ + data: { data: "test" }, + status: 200, + ok: true, + }); + }); + + test("should handle non-JSON response", async () => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: "OK", + headers: { get: vi.fn().mockReturnValue("text/plain") }, + text: vi.fn().mockResolvedValue("plain text response"), + } as unknown as Response); + + const result = await makeAuthenticatedRequest("test-endpoint"); + + expect(result.data).toBe("plain text response"); + }); + + test("should throw AuthenticationRequiredError when no auth config", async () => { + mockLoadAuthConfig.mockReturnValue(null); + + await expect(makeAuthenticatedRequest("test-endpoint")).rejects.toThrow( + AuthenticationRequiredError, + ); + }); + + test("should throw AuthenticationRequiredError when no access token", async () => { + mockGetAccessToken.mockReturnValue(null); + + await expect(makeAuthenticatedRequest("test-endpoint")).rejects.toThrow( + AuthenticationRequiredError, + ); + }); + + test("should throw ApiRequestError on API error", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 404, + statusText: "Not Found", + headers: { get: vi.fn() }, + text: vi.fn().mockResolvedValue("Resource not found"), + } as unknown as Response); + + await expect(makeAuthenticatedRequest("test-endpoint")).rejects.toThrow( + ApiRequestError, + ); + }); + + test("should handle network error", async () => { + mockFetch.mockRejectedValue(new Error("Network error")); + + await expect(makeAuthenticatedRequest("test-endpoint")).rejects.toThrow( + "Request failed: Network error", + ); + }); + + test("should handle string body", async () => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: "OK", + headers: { get: vi.fn().mockReturnValue("application/json") }, + json: vi.fn().mockResolvedValue({ success: true }), + } as unknown as Response); + + await makeAuthenticatedRequest("test-endpoint", { + method: "POST", + body: "raw string body", + }); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + body: "raw string body", + }), + ); + }); + + test("should merge custom headers", async () => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: "OK", + headers: { get: vi.fn().mockReturnValue("application/json") }, + json: vi.fn().mockResolvedValue({}), + } as unknown as Response); + + await makeAuthenticatedRequest("test-endpoint", { + headers: { "Custom-Header": "custom-value" }, + }); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: { + "Content-Type": "application/json", + Authorization: "Bearer test-access-token", + "Custom-Header": "custom-value", + }, + }), + ); + }); + }); + + describe("convenience methods", () => { + beforeEach(() => { + mockFetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: "OK", + headers: { get: vi.fn().mockReturnValue("application/json") }, + json: vi.fn().mockResolvedValue({ success: true }), + } as unknown as Response); + }); + + test("get method should make GET request", async () => { + await get("test-endpoint"); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ method: "GET" }), + ); + }); + + test("post method should make POST request", async () => { + await post("test-endpoint", { data: "test" }); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ data: "test" }), + }), + ); + }); + + test("put method should make PUT request", async () => { + await put("test-endpoint", { data: "test" }); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + method: "PUT", + body: JSON.stringify({ data: "test" }), + }), + ); + }); + + test("del method should make DELETE request", async () => { + await del("test-endpoint"); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ method: "DELETE" }), + ); + }); + }); + + describe("error classes", () => { + test("AuthenticationRequiredError should have correct properties", () => { + const error = new AuthenticationRequiredError(); + + expect(error.name).toBe("AuthenticationRequiredError"); + expect(error.message).toBe( + "Not authenticated. Please run 'cn login' first.", + ); + }); + + test("AuthenticationRequiredError should accept custom message", () => { + const customMessage = "Custom auth error"; + const error = new AuthenticationRequiredError(customMessage); + + expect(error.message).toBe(customMessage); + }); + + test("ApiRequestError should have correct properties", () => { + const error = new ApiRequestError(404, "Not Found", "Resource not found"); + + expect(error.name).toBe("ApiRequestError"); + expect(error.status).toBe(404); + expect(error.statusText).toBe("Not Found"); + expect(error.response).toBe("Resource not found"); + expect(error.message).toContain("404"); + expect(error.message).toContain("Not Found"); + expect(error.message).toContain("Resource not found"); + }); + + test("ApiRequestError should work without response text", () => { + const error = new ApiRequestError(500, "Internal Server Error"); + + expect(error.response).toBeUndefined(); + expect(error.message).toContain("500"); + expect(error.message).toContain("Internal Server Error"); + expect(error.message).not.toContain("undefined"); + }); + }); +}); diff --git a/extensions/cli/src/util/apiClient.ts b/extensions/cli/src/util/apiClient.ts new file mode 100644 index 00000000000..16119c4a003 --- /dev/null +++ b/extensions/cli/src/util/apiClient.ts @@ -0,0 +1,193 @@ +/* eslint-disable max-classes-per-file */ +import { getAccessToken, loadAuthConfig } from "../auth/workos.js"; +import { env } from "../env.js"; + +import { logger } from "./logger.js"; + +export interface ApiRequestOptions { + method?: "GET" | "POST" | "PUT" | "DELETE" | "PATCH"; + body?: Record | string; + headers?: Record; +} + +export interface ApiResponse { + data: T; + status: number; + ok: boolean; +} + +export interface ApiError extends Error { + status: number; + statusText: string; + response?: string; +} + +/** + * Authentication error thrown when user is not authenticated + */ +export class AuthenticationRequiredError extends Error { + constructor(message = "Not authenticated. Please run 'cn login' first.") { + super(message); + this.name = "AuthenticationRequiredError"; + } +} + +/** + * API error thrown when the request fails + */ +export class ApiRequestError extends Error implements ApiError { + status: number; + statusText: string; + response?: string; + + constructor(status: number, statusText: string, response?: string) { + const message = response + ? `API request failed: ${status} ${statusText} - ${response}` + : `API request failed: ${status} ${statusText}`; + super(message); + this.name = "ApiRequestError"; + this.status = status; + this.statusText = statusText; + this.response = response; + } +} + +/** + * Make an authenticated API request to the Continue API + * Handles authentication, error handling, and response parsing + */ +export async function makeAuthenticatedRequest( + endpoint: string, + options: ApiRequestOptions = {}, +): Promise> { + // Handle authentication + const authConfig = loadAuthConfig(); + if (!authConfig) { + throw new AuthenticationRequiredError(); + } + + const accessToken = getAccessToken(authConfig); + if (!accessToken) { + throw new AuthenticationRequiredError( + "No access token available. Please run 'cn login' first.", + ); + } + + // Prepare request options + const { method = "GET", body, headers = {} } = options; + + const requestOptions: RequestInit = { + method, + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${accessToken}`, + ...headers, + }, + }; + + // Add body if provided + if (body) { + requestOptions.body = + typeof body === "string" ? body : JSON.stringify(body); + } + + try { + // Make the request + const url = new URL(endpoint, env.apiBase); + logger.debug(`Making ${method} request to: ${url.toString()}`); + + const response = await fetch(url, requestOptions); + + // Handle error responses + if (!response.ok) { + const errorText = await response.text(); + logger.error(`API request failed: ${response.status} ${errorText}`); + throw new ApiRequestError( + response.status, + response.statusText, + errorText, + ); + } + + // Parse response + let data: T; + const contentType = response.headers.get("content-type"); + if (contentType && contentType.includes("application/json")) { + data = await response.json(); + } else { + // If not JSON, return the text as data + data = (await response.text()) as T; + } + + logger.debug(`API request successful: ${response.status}`); + + return { + data, + status: response.status, + ok: response.ok, + }; + } catch (error) { + // Re-throw our custom errors + if ( + error instanceof AuthenticationRequiredError || + error instanceof ApiRequestError + ) { + throw error; + } + + // Handle network/other errors + const errorMessage = error instanceof Error ? error.message : String(error); + logger.error(`Network/request error: ${errorMessage}`); + throw new Error(`Request failed: ${errorMessage}`); + } +} + +/** + * Convenience function for GET requests + */ +export async function get( + endpoint: string, + headers?: Record, +): Promise> { + return makeAuthenticatedRequest(endpoint, { method: "GET", headers }); +} + +/** + * Convenience function for POST requests + */ +export async function post( + endpoint: string, + body?: Record | string, + headers?: Record, +): Promise> { + return makeAuthenticatedRequest(endpoint, { + method: "POST", + body, + headers, + }); +} + +/** + * Convenience function for PUT requests + */ +export async function put( + endpoint: string, + body?: Record | string, + headers?: Record, +): Promise> { + return makeAuthenticatedRequest(endpoint, { + method: "PUT", + body, + headers, + }); +} + +/** + * Convenience function for DELETE requests + */ +export async function del( + endpoint: string, + headers?: Record, +): Promise> { + return makeAuthenticatedRequest(endpoint, { method: "DELETE", headers }); +}