Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 190 additions & 43 deletions packages/inference/src/providers/replicate.ts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have to remove the `Prefer: wait` header (set in `prepareHeaders`) now that all tasks use async polling.

Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,51 @@ export interface ReplicateOutput {
output?: string | string[];
}

/**
 * States read from the `status` field of a Replicate prediction.
 * The polling logic in this file handles `"succeeded"` as the final success state and
 * `"failed"` / `"canceled"` as errors; any other state leads to another poll.
 */
type ReplicatePredictionStatus = "starting" | "processing" | "succeeded" | "failed" | "canceled" | "queued";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this gives tighter typings and better autocomplete support.

Suggested change
type ReplicatePredictionStatus = "starting" | "processing" | "succeeded" | "failed" | "canceled" | "queued";
const REPLICATE_STATUSES = [
"starting",
"processing",
"succeeded",
"failed",
"canceled",
"queued",
] as const;
type ReplicatePredictionStatus = (typeof REPLICATE_STATUSES)[number];


/**
 * Prediction payload as consumed by the polling logic: a `ReplicateOutput` extended with
 * the optional fields needed to follow a prediction that has not finished yet.
 */
interface ReplicateAsyncResponse extends ReplicateOutput {
// Prediction identifier; appended to the request URL when a fallback polling URL is built.
id?: string;
// Current prediction state; a missing status is handled as "nothing to poll".
status?: ReplicatePredictionStatus;
// Failure details; a string, or an object with a string `message`, is included in thrown errors.
error?: unknown;
urls?: {
// Preferred URL for polling the prediction.
get?: string;
};
}

/** Delay, in milliseconds, awaited before each poll of the prediction endpoint. */
const POLLING_INTERVAL_MS = 1_000;

/**
 * Flattens any `HeadersInit` shape into a plain string-to-string record.
 *
 * Note that a `Headers` instance iterates its names in lower case, whereas
 * tuple lists and plain objects keep the casing they were given.
 *
 * @param headers - A `Headers` instance, a list of `[name, value]` tuples, a plain object, or nothing.
 * @returns A fresh record holding the header entries; empty when `headers` is missing.
 */
function headersInitToRecord(headers?: HeadersInit): Record<string, string> {
	if (headers) {
		if (headers instanceof Headers) {
			return Object.fromEntries(headers.entries());
		}
		// Tuple lists are converted entry by entry; plain objects are shallow-copied.
		return Array.isArray(headers) ? Object.fromEntries(headers) : { ...headers };
	}
	return {};
}

/**
 * Extracts a human-readable message from an arbitrary error value.
 *
 * @param error - A string, an object that may carry a string `message`, or anything else.
 * @returns The non-empty string itself, the object's string `message`, or `undefined`
 * when neither is available.
 */
function getErrorMessage(error: unknown): string | undefined {
	if (typeof error === "string") {
		// An empty string carries no information and is reported as "no message".
		return error === "" ? undefined : error;
	}
	if (typeof error !== "object" || error === null) {
		return undefined;
	}
	const candidate: unknown = "message" in error ? error.message : undefined;
	return typeof candidate === "string" ? candidate : undefined;
}

/**
 * Returns a promise that resolves after a timer delay.
 *
 * @param ms - Delay passed to `setTimeout`, in milliseconds.
 */
function sleep(ms: number): Promise<void> {
	return new Promise<void>((wake) => {
		setTimeout(() => wake(), ms);
	});
}

abstract class ReplicateTask extends TaskProviderHelper {
constructor(url?: string) {
super("replicate", url || "https://api.replicate.com");
Expand Down Expand Up @@ -69,6 +114,97 @@ abstract class ReplicateTask extends TaskProviderHelper {
}
return `${baseUrl}/v1/models/${params.model}/predictions`;
}

/**
 * Resolves a Replicate response to its final form, polling the prediction
 * endpoint when the initial response describes an unfinished prediction.
 *
 * Blobs, non-object values, and objects whose `status` is missing or
 * `"succeeded"` are returned unchanged without any network request.
 *
 * @param response - Body returned by the initial prediction request.
 * @param requestUrl - URL of the initial request; used as a fallback to derive the
 * polling URL when the response carries no `urls.get`.
 * @param headers - Headers of the initial request. Only the `Authorization` value is
 * reused when polling; its name is matched case-insensitively.
 * @returns The `Blob` as-is, or the prediction payload once it is final.
 * @throws InferenceClientProviderOutputError If the prediction is `"failed"` or
 * `"canceled"`, if no polling URL can be determined, or if a poll request returns a
 * non-OK HTTP status.
 */
protected async ensureFinalResponse(
	response: ReplicateOutput | Blob | ReplicateAsyncResponse,
	requestUrl?: string,
	headers?: HeadersInit
): Promise<ReplicateOutput | Blob> {
	if (response instanceof Blob) {
		return response;
	}

	if (!response || typeof response !== "object") {
		return response as ReplicateOutput;
	}

	const status = "status" in response ? response.status : undefined;

	// A response without a status is handled as already final: there is nothing to poll.
	if (!status || status === "succeeded") {
		return response as ReplicateOutput;
	}

	if (status === "failed" || status === "canceled") {
		const message = getErrorMessage((response as ReplicateAsyncResponse).error);
		throw new InferenceClientProviderOutputError(`Replicate prediction ${status}${message ? `: ${message}` : ""}`);
	}

	const pollUrl = this.getPollUrl(response as ReplicateAsyncResponse, requestUrl);
	if (!pollUrl) {
		throw new InferenceClientProviderOutputError(
			"Received incomplete response from Replicate API: missing polling URL"
		);
	}

	// Header names are case-insensitive and a `Headers` instance yields them lower-cased,
	// so looking the token up under the exact key "Authorization" would silently drop it.
	let authorization: string | undefined;
	for (const [name, value] of Object.entries(headersInitToRecord(headers))) {
		if (value && name.toLowerCase() === "authorization") {
			authorization = value;
			break;
		}
	}

	const pollHeaders: Record<string, string> = {};
	if (authorization) {
		pollHeaders.Authorization = authorization;
	}
	pollHeaders.Accept = "application/json";

	// Poll the prediction endpoint until completion. There is no upper bound on the number
	// of polls: the loop only ends on a final status or on an error.
	while (true) {
		await sleep(POLLING_INTERVAL_MS);
		const pollResponse = await fetch(pollUrl, {
			method: "GET",
			headers: pollHeaders,
		});

		if (!pollResponse.ok) {
			throw new InferenceClientProviderOutputError(
				`Failed to poll Replicate prediction status: HTTP ${pollResponse.status}`
			);
		}

		const prediction = (await pollResponse.json()) as ReplicateAsyncResponse;
		const predictionStatus = prediction.status;

		// NOTE(review): a poll result without a status is returned as final, mirroring the
		// handling of the initial response. A reviewer asked: is it intentional to treat
		// "no status" as a successful response? Confirm before relying on it.
		if (!predictionStatus || predictionStatus === "succeeded") {
			return prediction as ReplicateOutput;
		}

		if (predictionStatus === "failed" || predictionStatus === "canceled") {
			const message = getErrorMessage(prediction.error);
			throw new InferenceClientProviderOutputError(
				`Replicate prediction ${predictionStatus}${message ? `: ${message}` : ""}`
			);
		}
	}
}

/**
 * Determines the URL to poll for a prediction.
 *
 * The URL advertised in `response.urls.get` wins. Otherwise, when both a prediction id
 * and the original request URL are known and that URL's path ends in `/predictions`
 * (a single trailing slash is ignored), the id is appended to that path.
 *
 * @param response - Prediction payload that may carry `urls.get` and/or `id`.
 * @param requestUrl - URL of the request that created the prediction.
 * @returns The polling URL, or `undefined` when none can be determined (including when
 * `requestUrl` cannot be parsed).
 */
private getPollUrl(response: ReplicateAsyncResponse, requestUrl?: string): string | undefined {
	const advertised = response.urls && typeof response.urls === "object" ? response.urls.get : undefined;
	if (typeof advertised === "string") {
		return advertised;
	}

	if (!(response.id && requestUrl)) {
		return undefined;
	}

	try {
		const target = new URL(requestUrl);
		const basePath = target.pathname.replace(/\/$/, "");
		if (!basePath.endsWith("/predictions")) {
			return undefined;
		}
		target.pathname = `${basePath}/${response.id}`;
		return target.toString();
	} catch {
		// An unparsable request URL simply means no fallback can be built.
		return undefined;
	}
}
}

export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
Expand All @@ -94,21 +230,22 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
outputType?: "url" | "blob" | "json"
): Promise<string | Blob | Record<string, unknown>> {
void url;
void headers;
const finalResponse = (await this.ensureFinalResponse(res, url, headers)) as ReplicateOutput;

if (
typeof res === "object" &&
"output" in res &&
Array.isArray(res.output) &&
res.output.length > 0 &&
typeof res.output[0] === "string"
typeof finalResponse === "object" &&
"output" in finalResponse &&
Array.isArray(finalResponse.output) &&
finalResponse.output.length > 0 &&
typeof finalResponse.output[0] === "string"
) {
if (outputType === "json") {
return { ...res };
return { ...finalResponse };
}
if (outputType === "url") {
return res.output[0];
return finalResponse.output[0];
}
const urlResponse = await fetch(res.output[0]);
const urlResponse = await fetch(finalResponse.output[0]);
return await urlResponse.blob();
}

Expand All @@ -130,17 +267,19 @@ export class ReplicateTextToSpeechTask extends ReplicateTask {
return payload;
}

override async getResponse(response: ReplicateOutput): Promise<Blob> {
if (response instanceof Blob) {
return response;
override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise<Blob> {
const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput | Blob;

if (finalResponse instanceof Blob) {
return finalResponse;
}
if (response && typeof response === "object") {
if ("output" in response) {
if (typeof response.output === "string") {
const urlResponse = await fetch(response.output);
if (finalResponse && typeof finalResponse === "object") {
if ("output" in finalResponse) {
if (typeof finalResponse.output === "string") {
const urlResponse = await fetch(finalResponse.output);
return await urlResponse.blob();
} else if (Array.isArray(response.output)) {
const urlResponse = await fetch(response.output[0]);
} else if (Array.isArray(finalResponse.output)) {
const urlResponse = await fetch(finalResponse.output[0]);
return await urlResponse.blob();
}
}
Expand All @@ -150,15 +289,16 @@ export class ReplicateTextToSpeechTask extends ReplicateTask {
}

export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper {
override async getResponse(response: ReplicateOutput): Promise<Blob> {
override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise<Blob> {
const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput;
if (
typeof response === "object" &&
!!response &&
"output" in response &&
typeof response.output === "string" &&
isUrl(response.output)
typeof finalResponse === "object" &&
!!finalResponse &&
"output" in finalResponse &&
typeof finalResponse.output === "string" &&
isUrl(finalResponse.output)
) {
const urlResponse = await fetch(response.output);
const urlResponse = await fetch(finalResponse.output);
return await urlResponse.blob();
}

Expand Down Expand Up @@ -199,11 +339,17 @@ export class ReplicateAutomaticSpeechRecognitionTask
};
}

override async getResponse(response: ReplicateOutput): Promise<AutomaticSpeechRecognitionOutput> {
if (typeof response?.output === "string") return { text: response.output };
if (Array.isArray(response?.output) && typeof response.output[0] === "string") return { text: response.output[0] };
override async getResponse(
response: ReplicateOutput | Blob,
url?: string,
headers?: HeadersInit
): Promise<AutomaticSpeechRecognitionOutput> {
const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput;
if (typeof finalResponse?.output === "string") return { text: finalResponse.output };
if (Array.isArray(finalResponse?.output) && typeof finalResponse.output[0] === "string")
return { text: finalResponse.output[0] };

const out = response?.output as
const out = finalResponse?.output as
| undefined
| {
transcription?: string;
Expand Down Expand Up @@ -254,27 +400,28 @@ export class ReplicateImageToImageTask extends ReplicateTask implements ImageToI
};
}

override async getResponse(response: ReplicateOutput): Promise<Blob> {
override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise<Blob> {
const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput;
if (
typeof response === "object" &&
!!response &&
"output" in response &&
Array.isArray(response.output) &&
response.output.length > 0 &&
typeof response.output[0] === "string"
typeof finalResponse === "object" &&
!!finalResponse &&
"output" in finalResponse &&
Array.isArray(finalResponse.output) &&
finalResponse.output.length > 0 &&
typeof finalResponse.output[0] === "string"
) {
const urlResponse = await fetch(response.output[0]);
const urlResponse = await fetch(finalResponse.output[0]);
return await urlResponse.blob();
}

if (
typeof response === "object" &&
!!response &&
"output" in response &&
typeof response.output === "string" &&
isUrl(response.output)
typeof finalResponse === "object" &&
!!finalResponse &&
"output" in finalResponse &&
typeof finalResponse.output === "string" &&
isUrl(finalResponse.output)
) {
const urlResponse = await fetch(response.output);
const urlResponse = await fetch(finalResponse.output);
return await urlResponse.blob();
}

Expand Down