diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py index 676136fff4..32cc384b86 100644 --- a/ml-agents/mlagents/trainers/torch/encoders.py +++ b/ml-agents/mlagents/trainers/torch/encoders.py @@ -257,8 +257,9 @@ def __init__( layers.append(ResNetBlock(channel)) last_channel = channel layers.append(Swish()) + self.final_flat_size = n_channels[-1] * height * width self.dense = linear_layer( - n_channels[-1] * height * width, + self.final_flat_size, output_size, kernel_init=Initialization.KaimingHeNormal, kernel_gain=1.41, # Use ReLU gain @@ -268,7 +269,6 @@ def __init__( def forward(self, visual_obs: torch.Tensor) -> torch.Tensor: if not exporting_to_onnx.is_exporting(): visual_obs = visual_obs.permute([0, 3, 1, 2]) - batch_size = visual_obs.shape[0] hidden = self.sequential(visual_obs) - before_out = hidden.reshape(batch_size, -1) + before_out = hidden.reshape(-1, self.final_flat_size) return torch.relu(self.dense(before_out))