Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyodide-e2e/src/tests/import.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,13 @@ describe("import_transformers_js", () => {
);
const version270 = await pyodide.runPythonAsync(`transformers.env.version`);
expect(version270).toBe("2.7.0");

await pyodide.runPythonAsync(
`transformers = await import_transformers_js("https://cdn.jsdelivr.net/npm/@xenova/[email protected]")`,
);
const version2177 = await pyodide.runPythonAsync(
`transformers.env.version`,
);
expect(version2177).toBe("2.17.1");
});
});
18 changes: 14 additions & 4 deletions transformers_js_py/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,24 @@ def wrap_or_unwrap_proxy_object(obj):
return obj


async def import_transformers_js(version: str = "latest"):
async def import_transformers_js(version_or_url: str = "latest"):
loadTransformersJsFn = pyodide.code.run_js(
"""
async (version) => {
async (versionOrUrl) => {
function getTransformersJsUrl() {
try {
return new URL(versionOrUrl);
} catch {
const version = versionOrUrl;
return new URL('https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version);
}
}

const isBrowserMainThread = typeof window !== 'undefined';
const isWorker = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope;
const isBrowser = isBrowserMainThread || isWorker;
const transformers = await import(isBrowser ? 'https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version : '@xenova/transformers');

const transformers = await import(isBrowser ? getTransformersJsUrl() : '@xenova/transformers');

transformers.env.allowLocalModels = false;

Expand All @@ -219,5 +229,5 @@ async def import_transformers_js(version: str = "latest"):
""" # noqa: E501
)
global _TRANSFORMERS_JS
_TRANSFORMERS_JS = await loadTransformersJsFn(version)
_TRANSFORMERS_JS = await loadTransformersJsFn(version_or_url)
return TjsModuleProxy(_TRANSFORMERS_JS)