@@ -75,21 +75,25 @@ def _torch_mlir(self, is_dynamic, tracing_required):
7575 self .module , self .inputs , is_dynamic , tracing_required
7676 )
7777
78- def _tf_mlir (self , func_name ):
78+ def _tf_mlir (self , func_name , save_dir = "./shark_tmp/" ):
7979 from iree .compiler import tf as tfc
8080
8181 return tfc .compile_module (
82- self .module , exported_names = [func_name ], import_only = True
82+ self .module ,
83+ exported_names = [func_name ],
84+ import_only = True ,
85+ output_file = save_dir ,
8386 )
8487
85- def _tflite_mlir (self , func_name ):
88+ def _tflite_mlir (self , func_name , save_dir = "./shark_tmp/" ):
8689 from iree .compiler import tflite as tflitec
8790 from shark .iree_utils ._common import IREE_TARGET_MAP
8891
8992 self .mlir_model = tflitec .compile_file (
9093 self .raw_model_file , # in tflite, it is a path to .tflite file, not a tflite interpreter
9194 input_type = "tosa" ,
9295 import_only = True ,
96+ output_file = save_dir ,
9397 )
9498 return self .mlir_model
9599
@@ -99,6 +103,7 @@ def import_mlir(
99103 is_dynamic = False ,
100104 tracing_required = False ,
101105 func_name = "forward" ,
106+ save_dir = "./shark_tmp/" ,
102107 ):
103108 if self .frontend in ["torch" , "pytorch" ]:
104109 if self .inputs == None :
@@ -108,10 +113,10 @@ def import_mlir(
108113 sys .exit (1 )
109114 return self ._torch_mlir (is_dynamic , tracing_required ), func_name
110115 if self .frontend in ["tf" , "tensorflow" ]:
111- return self ._tf_mlir (func_name ), func_name
116+ return self ._tf_mlir (func_name , save_dir ), func_name
112117 if self .frontend in ["tflite" , "tf-lite" ]:
113118 func_name = "main"
114- return self ._tflite_mlir (func_name ), func_name
119+ return self ._tflite_mlir (func_name , save_dir ), func_name
115120
116121 # Converts the frontend specific tensors into np array.
117122 def convert_to_numpy (self , array_tuple : tuple ):
@@ -130,20 +135,22 @@ def save_data(
130135 outputs_name = "golden_out.npz"
131136 func_file_name = "function_name"
132137 model_name_mlir = model_name + "_" + self .frontend + ".mlir"
133- inputs = [x .cpu ().detach () for x in inputs ]
138+ try :
139+ inputs = [x .cpu ().detach () for x in inputs ]
140+ except AttributeError :
141+ try :
142+ inputs = [x .numpy () for x in inputs ]
143+ except AttributeError :
144+ inputs = [x for x in inputs ]
134145 np .savez (os .path .join (dir , inputs_name ), * inputs )
135146 np .savez (os .path .join (dir , outputs_name ), * outputs )
136147 np .save (os .path .join (dir , func_file_name ), np .array (func_name ))
137148
138149 mlir_str = mlir_data
139150 if self .frontend == "torch" :
140151 mlir_str = mlir_data .operation .get_asm ()
141- elif self .frontend == "tf" :
142- mlir_str = mlir_data .decode ("latin-1" )
143- elif self .frontend == "tflite" :
144- mlir_str = mlir_data .decode ("latin-1" )
145- with open (os .path .join (dir , model_name_mlir ), "w" ) as mlir_file :
146- mlir_file .write (mlir_str )
152+ with open (os .path .join (dir , model_name_mlir ), "w" ) as mlir_file :
153+ mlir_file .write (mlir_str )
147154
148155 return
149156
@@ -160,9 +167,13 @@ def import_debug(
160167 f"There is no input provided: { self .inputs } , please provide inputs or simply run import_mlir."
161168 )
162169 sys .exit (1 )
163-
170+ model_name_mlir = model_name + "_" + self .frontend + ".mlir"
171+ artifact_path = os .path .join (dir , model_name_mlir )
164172 imported_mlir = self .import_mlir (
165- is_dynamic , tracing_required , func_name
173+ is_dynamic ,
174+ tracing_required ,
175+ func_name ,
176+ save_dir = artifact_path ,
166177 )
167178 # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
168179 # TODO: Check for multiple outputs.
0 commit comments