Skip to content

Commit 2613c88

Browse files
authored
Update import_transformers_js to accept a URL not only a version string (#164)
* Update import_transformers_js to accept a URL not only a version string * Add a test case * Format
1 parent cfdb941 commit 2613c88

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

pyodide-e2e/src/tests/import.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,13 @@ describe("import_transformers_js", () => {
3838
);
3939
const version270 = await pyodide.runPythonAsync(`transformers.env.version`);
4040
expect(version270).toBe("2.7.0");
41+
42+
await pyodide.runPythonAsync(
43+
`transformers = await import_transformers_js("https://cdn.jsdelivr.net/npm/@xenova/[email protected]")`,
44+
);
45+
const version2177 = await pyodide.runPythonAsync(
46+
`transformers.env.version`,
47+
);
48+
expect(version2177).toBe("2.17.1");
4149
});
4250
});

transformers_js_py/proxies.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,24 @@ def wrap_or_unwrap_proxy_object(obj):
201201
return obj
202202

203203

204-
async def import_transformers_js(version: str = "latest"):
204+
async def import_transformers_js(version_or_url: str = "latest"):
205205
loadTransformersJsFn = pyodide.code.run_js(
206206
"""
207-
async (version) => {
207+
async (versionOrUrl) => {
208+
function getTransformersJsUrl() {
209+
try {
210+
return new URL(versionOrUrl);
211+
} catch {
212+
const version = versionOrUrl;
213+
return new URL('https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version);
214+
}
215+
}
216+
208217
const isBrowserMainThread = typeof window !== 'undefined';
209218
const isWorker = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope;
210219
const isBrowser = isBrowserMainThread || isWorker;
211-
const transformers = await import(isBrowser ? 'https://cdn.jsdelivr.net/npm/@xenova/transformers@' + version : '@xenova/transformers');
220+
221+
const transformers = await import(isBrowser ? getTransformersJsUrl() : '@xenova/transformers');
212222
213223
transformers.env.allowLocalModels = false;
214224
@@ -219,5 +229,5 @@ async def import_transformers_js(version: str = "latest"):
219229
""" # noqa: E501
220230
)
221231
global _TRANSFORMERS_JS
222-
_TRANSFORMERS_JS = await loadTransformersJsFn(version)
232+
_TRANSFORMERS_JS = await loadTransformersJsFn(version_or_url)
223233
return TjsModuleProxy(_TRANSFORMERS_JS)

0 commit comments

Comments
 (0)