@@ -64,6 +64,7 @@ def __init__(
         act_size: List[int],
         reparameterize: bool = False,
         tanh_squash: bool = False,
+        condition_sigma: bool = True,
         log_sigma_min: float = -20,
         log_sigma_max: float = 2,
     ):
@@ -79,7 +80,11 @@ def __init__(
         :param log_sigma_max: Maximum log standard deviation to clip by.
         """
         encoded = self._create_mu_log_sigma(
-            logits, act_size, log_sigma_min, log_sigma_max
+            logits,
+            act_size,
+            log_sigma_min,
+            log_sigma_max,
+            condition_sigma=condition_sigma,
         )
         self._sampled_policy = self._create_sampled_policy(encoded)
         if not reparameterize:
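
For orientation, a minimal call-site sketch of the new flag. Everything except the constructor arguments visible in this diff is an assumption: the `encoded_obs` placeholder, its shape, and the scope names are hypothetical, and `logits` is assumed to be the first positional argument, as the call inside `__init__` suggests.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# GaussianDistribution is the class this diff modifies; import it from that module.
# Hypothetical encoder output feeding the distribution (shape and name assumed).
encoded_obs = tf.placeholder(tf.float32, shape=[None, 128], name="encoded_obs")

# Default behaviour is unchanged: sigma is conditioned on the policy's hidden state.
with tf.variable_scope("conditioned"):
    dist = GaussianDistribution(encoded_obs, act_size=[2], tanh_squash=True)

# Opting out: a single learned log_std shared across all observations.
with tf.variable_scope("unconditioned"):
    fixed_sigma_dist = GaussianDistribution(
        encoded_obs, act_size=[2], tanh_squash=True, condition_sigma=False
    )
```

The separate variable scopes are only there so the two sketch instances do not collide on the `mu` and `log_std` variable names.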
@@ -101,6 +106,7 @@ def _create_mu_log_sigma(
         act_size: List[int],
         log_sigma_min: float,
         log_sigma_max: float,
+        condition_sigma: bool,
     ) -> "GaussianDistribution.MuSigmaTensors":

         mu = tf.layers.dense(
@@ -112,14 +118,22 @@ def _create_mu_log_sigma(
             reuse=tf.AUTO_REUSE,
         )

-        # Policy-dependent log_sigma_sq
-        log_sigma = tf.layers.dense(
-            logits,
-            act_size[0],
-            activation=None,
-            name="log_std",
-            kernel_initializer=ModelUtils.scaled_init(0.01),
-        )
+        if condition_sigma:
+            # Policy-dependent log_sigma_sq
+            log_sigma = tf.layers.dense(
+                logits,
+                act_size[0],
+                activation=None,
+                name="log_std",
+                kernel_initializer=ModelUtils.scaled_init(0.01),
+            )
+        else:
+            log_sigma = tf.get_variable(
+                "log_std",
+                [act_size[0]],
+                dtype=tf.float32,
+                initializer=tf.zeros_initializer(),
+            )
         log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
         sigma = tf.exp(log_sigma)
         return self.MuSigmaTensors(mu, log_sigma, sigma)
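
The new branch chooses between a state-conditioned standard deviation and a single trainable one. A rough standalone sketch of that distinction, outside the class (the `encoded` placeholder, its shape, and the scope names are hypothetical):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

act_dim = 2
encoded = tf.placeholder(tf.float32, shape=[None, 64], name="encoded")

with tf.variable_scope("conditioned"):
    # condition_sigma=True: one log_std per sample, predicted from the hidden state.
    log_sigma_per_state = tf.layers.dense(encoded, act_dim, activation=None, name="log_std")
    # shape [batch_size, act_dim]

with tf.variable_scope("unconditioned"):
    # condition_sigma=False: one trainable log_std per action dimension,
    # shared across all observations (in the class it broadcasts against the per-sample mu).
    log_sigma_shared = tf.get_variable(
        "log_std", [act_dim], dtype=tf.float32, initializer=tf.zeros_initializer()
    )
    # shape [act_dim]
```

With `tf.zeros_initializer()`, the unconditioned path starts at sigma = exp(0) = 1 before the clip to [log_sigma_min, log_sigma_max] is applied.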
@@ -155,8 +169,8 @@ def _do_squash_correction_for_tanh(self, probs, squashed_policy):
         """
         Adjust probabilities for squashed sample before output
         """
-        probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
-        return probs
+        adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON)
+        return adjusted_probs

     @property
     def total_log_probs(self) -> tf.Tensor:
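
For reference, the quantity this helper subtracts per action dimension is the standard change-of-variables correction for a tanh-squashed Gaussian (the same correction used in SAC-style policies); the behavioural change in this hunk is only that the caller's `probs` tensor is no longer modified in place:

$$
\log \pi(a_i \mid s) \;=\; \log \mu(u_i \mid s) \;-\; \log\bigl(1 - \tanh^{2}(u_i) + \epsilon\bigr),
\qquad a_i = \tanh(u_i),
$$

where $\epsilon$ is the module's `EPSILON` constant and any summation over action dimensions is presumably handled elsewhere in the class, for example in the `total_log_probs` property whose signature appears in the context above.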