Skip to content

Commit 064aa3b

Browse files
authored
Fix tmp dir bug (huggingface#285)
1 parent 4960efc commit 064aa3b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

shark/torch_mlir_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from torch_mlir.ir import StringAttr
1616
import torch_mlir
1717
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
18+
import tempfile
19+
from shark.parser import shark_args
1820

1921

2022
def get_module_name_for_asm_dump(module):
@@ -62,6 +64,8 @@ def get_torch_mlir_module(
6264
if jit_trace:
6365
ignore_traced_shapes = True
6466

67+
tempfile.tempdir = shark_args.repro_dir
68+
6569
module = torch_mlir.compile(
6670
module,
6771
input,

0 commit comments

Comments
 (0)