@@ -162,26 +162,33 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
162162 """
163163 dists = self ._get_dists (inputs , masks )
164164 continuous_out , discrete_out , action_out_deprecated = None , None , None
165- deter_continuous_out , deter_discrete_out = None , None # deterministic actions
165+ deterministic_continuous_out , deterministic_discrete_out = (
166+ None ,
167+ None ,
168+ ) # deterministic actions
166169 if self .action_spec .continuous_size > 0 and dists .continuous is not None :
167170 continuous_out = dists .continuous .exported_model_output ()
168171 action_out_deprecated = continuous_out
169- deter_continuous_out = dists .continuous .deterministic_sample ()
172+ deterministic_continuous_out = dists .continuous .deterministic_sample ()
170173 if self ._clip_action_on_export :
171174 continuous_out = torch .clamp (continuous_out , - 3 , 3 ) / 3
172175 action_out_deprecated = continuous_out
173- deter_continuous_out = torch .clamp (deter_continuous_out , - 3 , 3 ) / 3
176+ deterministic_continuous_out = (
177+ torch .clamp (deterministic_continuous_out , - 3 , 3 ) / 3
178+ )
174179 if self .action_spec .discrete_size > 0 and dists .discrete is not None :
175180 discrete_out_list = [
176181 discrete_dist .exported_model_output ()
177182 for discrete_dist in dists .discrete
178183 ]
179184 discrete_out = torch .cat (discrete_out_list , dim = 1 )
180185 action_out_deprecated = torch .cat (discrete_out_list , dim = 1 )
181- deter_discrete_out_list = [
186+ deterministic_discrete_out_list = [
182187 discrete_dist .deterministic_sample () for discrete_dist in dists .discrete
183188 ]
184- deter_discrete_out = torch .cat (deter_discrete_out_list , dim = 1 )
189+ deterministic_discrete_out = torch .cat (
190+ deterministic_discrete_out_list , dim = 1
191+ )
185192
186193 # deprecated action field does not support hybrid action
187194 if self .action_spec .continuous_size > 0 and self .action_spec .discrete_size > 0 :
@@ -190,8 +197,8 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
190197 continuous_out ,
191198 discrete_out ,
192199 action_out_deprecated ,
193- deter_continuous_out ,
194- deter_discrete_out ,
200+ deterministic_continuous_out ,
201+ deterministic_discrete_out ,
195202 )
196203
197204 def forward (
0 commit comments