From 3207718f70d47c308e6882bfa4f7fe9e9b9b059b Mon Sep 17 00:00:00 2001 From: Luis Catacora Date: Mon, 20 Oct 2025 12:38:42 -0400 Subject: [PATCH 1/2] Implement async polling for Replicate provider --- packages/inference/src/providers/replicate.ts | 373 +++++++++++++----- 1 file changed, 270 insertions(+), 103 deletions(-) diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index 75496fecee..e7dc464182 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -19,7 +19,7 @@ import { isUrl } from "../lib/isUrl.js"; import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js"; import { omit } from "../utils/omit.js"; import { - TaskProviderHelper, + TaskProviderHelper, type AutomaticSpeechRecognitionTaskHelper, type ImageToImageTaskHelper, type TextToImageTaskHelper, @@ -30,13 +30,64 @@ import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpe import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks"; import { base64FromBytes } from "../utils/base64FromBytes.js"; export interface ReplicateOutput { - output?: string | string[]; + output?: string | string[]; +} + +type ReplicatePredictionStatus = + | "starting" + | "processing" + | "succeeded" + | "failed" + | "canceled" + | "queued"; + +interface ReplicateAsyncResponse extends ReplicateOutput { + id?: string; + status?: ReplicatePredictionStatus; + error?: unknown; + urls?: { + get?: string; + }; +} + +const POLLING_INTERVAL_MS = 1_000; + +function headersInitToRecord(headers?: HeadersInit): Record { + if (!headers) { + return {}; + } + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + return { ...headers }; +} + +function getErrorMessage(error: unknown): string | undefined { + if (!error) { + return undefined; + } + if (typeof error === "string") { + return error; + } + if (typeof error === "object" && "message" in error && typeof error.message === "string") { + return error.message; + } + return undefined; +} + +async function sleep(ms: number): Promise { + await new Promise((resolve) => { + setTimeout(resolve, ms); + }); } abstract class ReplicateTask extends TaskProviderHelper { - constructor(url?: string) { - super("replicate", url || "https://api.replicate.com"); - } + constructor(url?: string) { + super("replicate", url || "https://api.replicate.com"); + } makeRoute(params: UrlParams): string { if (params.model.includes(":")) { @@ -62,18 +113,111 @@ abstract class ReplicateTask extends TaskProviderHelper { return headers; } - override makeUrl(params: UrlParams): string { - const baseUrl = this.makeBaseUrl(params); - if (params.model.includes(":")) { - return `${baseUrl}/v1/predictions`; - } - return `${baseUrl}/v1/models/${params.model}/predictions`; - } + override makeUrl(params: UrlParams): string { + const baseUrl = this.makeBaseUrl(params); + if (params.model.includes(":")) { + return `${baseUrl}/v1/predictions`; + } + return `${baseUrl}/v1/models/${params.model}/predictions`; + } + + protected async ensureFinalResponse( + response: ReplicateOutput | Blob | ReplicateAsyncResponse, + requestUrl?: string, + headers?: HeadersInit + ): Promise { + if (response instanceof Blob) { + return response; + } + + if (!response || typeof response !== "object") { + return response as ReplicateOutput; + } + + const status = "status" in response ? response.status : undefined; + + if (!status || status === "succeeded") { + return response as ReplicateOutput; + } + + if (status === "failed" || status === "canceled") { + const message = getErrorMessage((response as ReplicateAsyncResponse).error); + throw new InferenceClientProviderOutputError( + `Replicate prediction ${status}${message ? `: ${message}` : ""}` + ); + } + + const pollUrl = this.getPollUrl(response as ReplicateAsyncResponse, requestUrl); + if (!pollUrl) { + throw new InferenceClientProviderOutputError( + "Received incomplete response from Replicate API: missing polling URL" + ); + } + + const headerRecord = headersInitToRecord(headers); + const pollHeaders: Record = {}; + if (headerRecord.Authorization) { + pollHeaders.Authorization = headerRecord.Authorization; + } + pollHeaders.Accept = "application/json"; + + // Poll the prediction endpoint until completion + while (true) { + await sleep(POLLING_INTERVAL_MS); + const pollResponse = await fetch(pollUrl, { + method: "GET", + headers: pollHeaders, + }); + + if (!pollResponse.ok) { + throw new InferenceClientProviderOutputError( + `Failed to poll Replicate prediction status: HTTP ${pollResponse.status}` + ); + } + + const prediction = (await pollResponse.json()) as ReplicateAsyncResponse; + const predictionStatus = prediction.status; + + if (!predictionStatus || predictionStatus === "succeeded") { + return prediction as ReplicateOutput; + } + + if (predictionStatus === "failed" || predictionStatus === "canceled") { + const message = getErrorMessage(prediction.error); + throw new InferenceClientProviderOutputError( + `Replicate prediction ${predictionStatus}${message ? `: ${message}` : ""}` + ); + } + } + } + + private getPollUrl(response: ReplicateAsyncResponse, requestUrl?: string): string | undefined { + if (response.urls && typeof response.urls === "object" && typeof response.urls.get === "string") { + return response.urls.get; + } + + if (!response.id || !requestUrl) { + return undefined; + } + + try { + const url = new URL(requestUrl); + const pathname = url.pathname.replace(/\/$/, ""); + if (pathname.endsWith("/predictions")) { + url.pathname = `${pathname}/${response.id}`; + return url.toString(); + } + } catch { + return undefined; + } + + return undefined; + } } export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper { - override preparePayload(params: BodyParams): Record { - return { + override preparePayload(params: BodyParams): Record { + return { input: { ...omit(params.args, ["inputs", "parameters"]), ...(params.args.parameters as Record), @@ -87,33 +231,34 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma }; } - override async getResponse( - res: ReplicateOutput | Blob, - url?: string, - headers?: Record, - outputType?: "url" | "blob" | "json" - ): Promise> { - void url; - void headers; - if ( - typeof res === "object" && - "output" in res && - Array.isArray(res.output) && - res.output.length > 0 && - typeof res.output[0] === "string" - ) { - if (outputType === "json") { - return { ...res }; - } - if (outputType === "url") { - return res.output[0]; - } - const urlResponse = await fetch(res.output[0]); - return await urlResponse.blob(); - } + override async getResponse( + res: ReplicateOutput | Blob, + url?: string, + headers?: Record, + outputType?: "url" | "blob" | "json" + ): Promise> { + void url; + const finalResponse = (await this.ensureFinalResponse(res, url, headers)) as ReplicateOutput; - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API"); - } + if ( + typeof finalResponse === "object" && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" + ) { + if (outputType === "json") { + return { ...finalResponse }; + } + if (outputType === "url") { + return finalResponse.output[0]; + } + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API"); + } } export class ReplicateTextToSpeechTask extends ReplicateTask { @@ -130,40 +275,51 @@ export class ReplicateTextToSpeechTask extends ReplicateTask { return payload; } - override async getResponse(response: ReplicateOutput): Promise { - if (response instanceof Blob) { - return response; - } - if (response && typeof response === "object") { - if ("output" in response) { - if (typeof response.output === "string") { - const urlResponse = await fetch(response.output); - return await urlResponse.blob(); - } else if (Array.isArray(response.output)) { - const urlResponse = await fetch(response.output[0]); - return await urlResponse.blob(); - } - } - } - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-speech API"); - } + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput | Blob; + + if (finalResponse instanceof Blob) { + return finalResponse; + } + if (finalResponse && typeof finalResponse === "object") { + if ("output" in finalResponse) { + if (typeof finalResponse.output === "string") { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } else if (Array.isArray(finalResponse.output)) { + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } + } + } + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-speech API"); + } } export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper { - override async getResponse(response: ReplicateOutput): Promise { - if ( - typeof response === "object" && - !!response && - "output" in response && - typeof response.output === "string" && - isUrl(response.output) - ) { - const urlResponse = await fetch(response.output); - return await urlResponse.blob(); - } + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) + ) { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API"); - } + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API"); + } } export class ReplicateAutomaticSpeechRecognitionTask @@ -199,16 +355,22 @@ export class ReplicateAutomaticSpeechRecognitionTask }; } - override async getResponse(response: ReplicateOutput): Promise { - if (typeof response?.output === "string") return { text: response.output }; - if (Array.isArray(response?.output) && typeof response.output[0] === "string") return { text: response.output[0] }; + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if (typeof finalResponse?.output === "string") return { text: finalResponse.output }; + if (Array.isArray(finalResponse?.output) && typeof finalResponse.output[0] === "string") + return { text: finalResponse.output[0] }; - const out = response?.output as - | undefined - | { - transcription?: string; - translation?: string; - txt_file?: string; + const out = finalResponse?.output as + | undefined + | { + transcription?: string; + translation?: string; + txt_file?: string; }; if (out && typeof out === "object") { if (typeof out.transcription === "string") return { text: out.transcription }; @@ -254,30 +416,35 @@ export class ReplicateImageToImageTask extends ReplicateTask implements ImageToI }; } - override async getResponse(response: ReplicateOutput): Promise { - if ( - typeof response === "object" && - !!response && - "output" in response && - Array.isArray(response.output) && - response.output.length > 0 && - typeof response.output[0] === "string" - ) { - const urlResponse = await fetch(response.output[0]); - return await urlResponse.blob(); - } + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" + ) { + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } - if ( - typeof response === "object" && - !!response && - "output" in response && - typeof response.output === "string" && - isUrl(response.output) - ) { - const urlResponse = await fetch(response.output); - return await urlResponse.blob(); - } + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) + ) { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } - throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); - } + throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); + } } From 28decb35200e640acf204a3882bdf38f833a8273 Mon Sep 17 00:00:00 2001 From: Luis C Date: Mon, 20 Oct 2025 13:31:23 -0400 Subject: [PATCH 2/2] fix format --- packages/inference/src/providers/replicate.ts | 500 +++++++++--------- 1 file changed, 240 insertions(+), 260 deletions(-) diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index e7dc464182..4e93074e4e 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -19,7 +19,7 @@ import { isUrl } from "../lib/isUrl.js"; import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js"; import { omit } from "../utils/omit.js"; import { - TaskProviderHelper, + TaskProviderHelper, type AutomaticSpeechRecognitionTaskHelper, type ImageToImageTaskHelper, type TextToImageTaskHelper, @@ -30,64 +30,58 @@ import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpe import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks"; import { base64FromBytes } from "../utils/base64FromBytes.js"; export interface ReplicateOutput { - output?: string | string[]; + output?: string | string[]; } -type ReplicatePredictionStatus = - | "starting" - | "processing" - | "succeeded" - | "failed" - | "canceled" - | "queued"; +type ReplicatePredictionStatus = "starting" | "processing" | "succeeded" | "failed" | "canceled" | "queued"; interface ReplicateAsyncResponse extends ReplicateOutput { - id?: string; - status?: ReplicatePredictionStatus; - error?: unknown; - urls?: { - get?: string; - }; + id?: string; + status?: ReplicatePredictionStatus; + error?: unknown; + urls?: { + get?: string; + }; } const POLLING_INTERVAL_MS = 1_000; function headersInitToRecord(headers?: HeadersInit): Record { - if (!headers) { - return {}; - } - if (headers instanceof Headers) { - return Object.fromEntries(headers.entries()); - } - if (Array.isArray(headers)) { - return Object.fromEntries(headers); - } - return { ...headers }; + if (!headers) { + return {}; + } + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + return { ...headers }; } function getErrorMessage(error: unknown): string | undefined { - if (!error) { - return undefined; - } - if (typeof error === "string") { - return error; - } - if (typeof error === "object" && "message" in error && typeof error.message === "string") { - return error.message; - } - return undefined; + if (!error) { + return undefined; + } + if (typeof error === "string") { + return error; + } + if (typeof error === "object" && "message" in error && typeof error.message === "string") { + return error.message; + } + return undefined; } async function sleep(ms: number): Promise { - await new Promise((resolve) => { - setTimeout(resolve, ms); - }); + await new Promise((resolve) => { + setTimeout(resolve, ms); + }); } abstract class ReplicateTask extends TaskProviderHelper { - constructor(url?: string) { - super("replicate", url || "https://api.replicate.com"); - } + constructor(url?: string) { + super("replicate", url || "https://api.replicate.com"); + } makeRoute(params: UrlParams): string { if (params.model.includes(":")) { @@ -113,111 +107,109 @@ abstract class ReplicateTask extends TaskProviderHelper { return headers; } - override makeUrl(params: UrlParams): string { - const baseUrl = this.makeBaseUrl(params); - if (params.model.includes(":")) { - return `${baseUrl}/v1/predictions`; - } - return `${baseUrl}/v1/models/${params.model}/predictions`; - } - - protected async ensureFinalResponse( - response: ReplicateOutput | Blob | ReplicateAsyncResponse, - requestUrl?: string, - headers?: HeadersInit - ): Promise { - if (response instanceof Blob) { - return response; - } - - if (!response || typeof response !== "object") { - return response as ReplicateOutput; - } - - const status = "status" in response ? response.status : undefined; - - if (!status || status === "succeeded") { - return response as ReplicateOutput; - } - - if (status === "failed" || status === "canceled") { - const message = getErrorMessage((response as ReplicateAsyncResponse).error); - throw new InferenceClientProviderOutputError( - `Replicate prediction ${status}${message ? `: ${message}` : ""}` - ); - } - - const pollUrl = this.getPollUrl(response as ReplicateAsyncResponse, requestUrl); - if (!pollUrl) { - throw new InferenceClientProviderOutputError( - "Received incomplete response from Replicate API: missing polling URL" - ); - } - - const headerRecord = headersInitToRecord(headers); - const pollHeaders: Record = {}; - if (headerRecord.Authorization) { - pollHeaders.Authorization = headerRecord.Authorization; - } - pollHeaders.Accept = "application/json"; - - // Poll the prediction endpoint until completion - while (true) { - await sleep(POLLING_INTERVAL_MS); - const pollResponse = await fetch(pollUrl, { - method: "GET", - headers: pollHeaders, - }); - - if (!pollResponse.ok) { - throw new InferenceClientProviderOutputError( - `Failed to poll Replicate prediction status: HTTP ${pollResponse.status}` - ); - } - - const prediction = (await pollResponse.json()) as ReplicateAsyncResponse; - const predictionStatus = prediction.status; - - if (!predictionStatus || predictionStatus === "succeeded") { - return prediction as ReplicateOutput; - } - - if (predictionStatus === "failed" || predictionStatus === "canceled") { - const message = getErrorMessage(prediction.error); - throw new InferenceClientProviderOutputError( - `Replicate prediction ${predictionStatus}${message ? `: ${message}` : ""}` - ); - } - } - } - - private getPollUrl(response: ReplicateAsyncResponse, requestUrl?: string): string | undefined { - if (response.urls && typeof response.urls === "object" && typeof response.urls.get === "string") { - return response.urls.get; - } - - if (!response.id || !requestUrl) { - return undefined; - } - - try { - const url = new URL(requestUrl); - const pathname = url.pathname.replace(/\/$/, ""); - if (pathname.endsWith("/predictions")) { - url.pathname = `${pathname}/${response.id}`; - return url.toString(); - } - } catch { - return undefined; - } - - return undefined; - } + override makeUrl(params: UrlParams): string { + const baseUrl = this.makeBaseUrl(params); + if (params.model.includes(":")) { + return `${baseUrl}/v1/predictions`; + } + return `${baseUrl}/v1/models/${params.model}/predictions`; + } + + protected async ensureFinalResponse( + response: ReplicateOutput | Blob | ReplicateAsyncResponse, + requestUrl?: string, + headers?: HeadersInit + ): Promise { + if (response instanceof Blob) { + return response; + } + + if (!response || typeof response !== "object") { + return response as ReplicateOutput; + } + + const status = "status" in response ? response.status : undefined; + + if (!status || status === "succeeded") { + return response as ReplicateOutput; + } + + if (status === "failed" || status === "canceled") { + const message = getErrorMessage((response as ReplicateAsyncResponse).error); + throw new InferenceClientProviderOutputError(`Replicate prediction ${status}${message ? `: ${message}` : ""}`); + } + + const pollUrl = this.getPollUrl(response as ReplicateAsyncResponse, requestUrl); + if (!pollUrl) { + throw new InferenceClientProviderOutputError( + "Received incomplete response from Replicate API: missing polling URL" + ); + } + + const headerRecord = headersInitToRecord(headers); + const pollHeaders: Record = {}; + if (headerRecord.Authorization) { + pollHeaders.Authorization = headerRecord.Authorization; + } + pollHeaders.Accept = "application/json"; + + // Poll the prediction endpoint until completion + while (true) { + await sleep(POLLING_INTERVAL_MS); + const pollResponse = await fetch(pollUrl, { + method: "GET", + headers: pollHeaders, + }); + + if (!pollResponse.ok) { + throw new InferenceClientProviderOutputError( + `Failed to poll Replicate prediction status: HTTP ${pollResponse.status}` + ); + } + + const prediction = (await pollResponse.json()) as ReplicateAsyncResponse; + const predictionStatus = prediction.status; + + if (!predictionStatus || predictionStatus === "succeeded") { + return prediction as ReplicateOutput; + } + + if (predictionStatus === "failed" || predictionStatus === "canceled") { + const message = getErrorMessage(prediction.error); + throw new InferenceClientProviderOutputError( + `Replicate prediction ${predictionStatus}${message ? `: ${message}` : ""}` + ); + } + } + } + + private getPollUrl(response: ReplicateAsyncResponse, requestUrl?: string): string | undefined { + if (response.urls && typeof response.urls === "object" && typeof response.urls.get === "string") { + return response.urls.get; + } + + if (!response.id || !requestUrl) { + return undefined; + } + + try { + const url = new URL(requestUrl); + const pathname = url.pathname.replace(/\/$/, ""); + if (pathname.endsWith("/predictions")) { + url.pathname = `${pathname}/${response.id}`; + return url.toString(); + } + } catch { + return undefined; + } + + return undefined; + } } export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper { - override preparePayload(params: BodyParams): Record { - return { + override preparePayload(params: BodyParams): Record { + return { input: { ...omit(params.args, ["inputs", "parameters"]), ...(params.args.parameters as Record), @@ -231,34 +223,34 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma }; } - override async getResponse( - res: ReplicateOutput | Blob, - url?: string, - headers?: Record, - outputType?: "url" | "blob" | "json" - ): Promise> { - void url; - const finalResponse = (await this.ensureFinalResponse(res, url, headers)) as ReplicateOutput; - - if ( - typeof finalResponse === "object" && - "output" in finalResponse && - Array.isArray(finalResponse.output) && - finalResponse.output.length > 0 && - typeof finalResponse.output[0] === "string" - ) { - if (outputType === "json") { - return { ...finalResponse }; - } - if (outputType === "url") { - return finalResponse.output[0]; - } - const urlResponse = await fetch(finalResponse.output[0]); - return await urlResponse.blob(); - } - - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API"); - } + override async getResponse( + res: ReplicateOutput | Blob, + url?: string, + headers?: Record, + outputType?: "url" | "blob" | "json" + ): Promise> { + void url; + const finalResponse = (await this.ensureFinalResponse(res, url, headers)) as ReplicateOutput; + + if ( + typeof finalResponse === "object" && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" + ) { + if (outputType === "json") { + return { ...finalResponse }; + } + if (outputType === "url") { + return finalResponse.output[0]; + } + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-image API"); + } } export class ReplicateTextToSpeechTask extends ReplicateTask { @@ -275,51 +267,43 @@ export class ReplicateTextToSpeechTask extends ReplicateTask { return payload; } - override async getResponse( - response: ReplicateOutput | Blob, - url?: string, - headers?: HeadersInit - ): Promise { - const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput | Blob; - - if (finalResponse instanceof Blob) { - return finalResponse; - } - if (finalResponse && typeof finalResponse === "object") { - if ("output" in finalResponse) { - if (typeof finalResponse.output === "string") { - const urlResponse = await fetch(finalResponse.output); - return await urlResponse.blob(); - } else if (Array.isArray(finalResponse.output)) { - const urlResponse = await fetch(finalResponse.output[0]); - return await urlResponse.blob(); - } - } - } - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-speech API"); - } + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput | Blob; + + if (finalResponse instanceof Blob) { + return finalResponse; + } + if (finalResponse && typeof finalResponse === "object") { + if ("output" in finalResponse) { + if (typeof finalResponse.output === "string") { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } else if (Array.isArray(finalResponse.output)) { + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } + } + } + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-speech API"); + } } export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper { - override async getResponse( - response: ReplicateOutput | Blob, - url?: string, - headers?: HeadersInit - ): Promise { - const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; - if ( - typeof finalResponse === "object" && - !!finalResponse && - "output" in finalResponse && - typeof finalResponse.output === "string" && - isUrl(finalResponse.output) - ) { - const urlResponse = await fetch(finalResponse.output); - return await urlResponse.blob(); - } - - throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API"); - } + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) + ) { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API"); + } } export class ReplicateAutomaticSpeechRecognitionTask @@ -355,22 +339,22 @@ export class ReplicateAutomaticSpeechRecognitionTask }; } - override async getResponse( - response: ReplicateOutput | Blob, - url?: string, - headers?: HeadersInit - ): Promise { - const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; - if (typeof finalResponse?.output === "string") return { text: finalResponse.output }; - if (Array.isArray(finalResponse?.output) && typeof finalResponse.output[0] === "string") - return { text: finalResponse.output[0] }; - - const out = finalResponse?.output as - | undefined - | { - transcription?: string; - translation?: string; - txt_file?: string; + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if (typeof finalResponse?.output === "string") return { text: finalResponse.output }; + if (Array.isArray(finalResponse?.output) && typeof finalResponse.output[0] === "string") + return { text: finalResponse.output[0] }; + + const out = finalResponse?.output as + | undefined + | { + transcription?: string; + translation?: string; + txt_file?: string; }; if (out && typeof out === "object") { if (typeof out.transcription === "string") return { text: out.transcription }; @@ -416,35 +400,31 @@ export class ReplicateImageToImageTask extends ReplicateTask implements ImageToI }; } - override async getResponse( - response: ReplicateOutput | Blob, - url?: string, - headers?: HeadersInit - ): Promise { - const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; - if ( - typeof finalResponse === "object" && - !!finalResponse && - "output" in finalResponse && - Array.isArray(finalResponse.output) && - finalResponse.output.length > 0 && - typeof finalResponse.output[0] === "string" - ) { - const urlResponse = await fetch(finalResponse.output[0]); - return await urlResponse.blob(); - } - - if ( - typeof finalResponse === "object" && - !!finalResponse && - "output" in finalResponse && - typeof finalResponse.output === "string" && - isUrl(finalResponse.output) - ) { - const urlResponse = await fetch(finalResponse.output); - return await urlResponse.blob(); - } - - throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); - } + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" + ) { + const urlResponse = await fetch(finalResponse.output[0]); + return await urlResponse.blob(); + } + + if ( + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) + ) { + const urlResponse = await fetch(finalResponse.output); + return await urlResponse.blob(); + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); + } }