Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/unitytrainers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ def create_continuous_state_encoder(self, h_size, activation, num_layers):
kernel_initializer=c_layers.variance_scaling_initializer(1.0))
return hidden

def create_visual_encoder(self, h_size, activation, num_layers):
def create_visual_encoder(self, image_input, h_size, activation, num_layers):
"""
Builds a set of visual (CNN) encoders.
:param image_input: The placeholder for the image input to use.
:param h_size: Hidden layer size.
:param activation: What type of activation function to use for layers.
:param num_layers: number of hidden layers to create.
:return: List of hidden layer tensors.
"""
conv1 = tf.layers.conv2d(self.visual_in[-1], 16, kernel_size=[8, 8], strides=[4, 4],
conv1 = tf.layers.conv2d(image_input, 16, kernel_size=[8, 8], strides=[4, 4],
activation=tf.nn.elu)
conv2 = tf.layers.conv2d(conv1, 32, kernel_size=[4, 4], strides=[2, 2],
activation=tf.nn.elu)
Expand Down Expand Up @@ -136,7 +137,7 @@ def create_new_obs(self, num_streams, h_size, num_layers):
hidden_state, hidden_visual = None, None
if brain.number_visual_observations > 0:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_visual_encoder(h_size, activation_fn, num_layers)
encoded_visual = self.create_visual_encoder(self.visual_in[j], h_size, activation_fn, num_layers)
visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
Expand Down