Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ml-agents-envs/mlagents_envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Dict, List, Optional, Any

import mlagents_envs
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import SideChannel, IncomingMessage

from mlagents_envs.base_env import (
BaseEnv,
Expand Down Expand Up @@ -498,7 +498,8 @@ def _parse_side_channel_message(
"sending side channel data properly.".format(channel_id)
)
if channel_id in side_channels:
side_channels[channel_id].on_message_received(message_data)
incoming_message = IncomingMessage(message_data)
side_channels[channel_id].on_message_received(incoming_message)
else:
logger.warning(
"Unknown side channel data received. Channel type "
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import (
SideChannel,
OutgoingMessage,
IncomingMessage,
)
from mlagents_envs.exception import UnityCommunicationException
import struct
import uuid
from typing import NamedTuple

Expand Down Expand Up @@ -31,7 +34,7 @@ class EngineConfigurationChannel(SideChannel):
def __init__(self) -> None:
super().__init__(uuid.UUID("e951342c-4f7e-11ea-b238-784f4387d1f7"))

def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
Expand Down Expand Up @@ -65,18 +68,16 @@ def set_configuration_parameters(
:param target_frame_rate: Instructs simulation to try to render at a
specified frame rate. Default -1.
"""
data = bytearray()
data += struct.pack("<i", width)
data += struct.pack("<i", height)
data += struct.pack("<i", quality_level)
data += struct.pack("<f", time_scale)
data += struct.pack("<i", target_frame_rate)
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.write_int32(width)
msg.write_int32(height)
msg.write_int32(quality_level)
msg.write_float32(time_scale)
msg.write_int32(target_frame_rate)
super().queue_message_to_send(msg)

def set_configuration(self, config: EngineConfig) -> None:
"""
Sets the engine configuration. Takes as input an EngineConfig.
"""
data = bytearray()
data += struct.pack("<iiifi", *config)
super().queue_message_to_send(data)
self.set_configuration_parameters(**config._asdict())
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from mlagents_envs.side_channel.side_channel import SideChannel
import struct
from mlagents_envs.side_channel.side_channel import (
SideChannel,
IncomingMessage,
OutgoingMessage,
)
import uuid
from typing import Dict, Tuple, Optional, List
from typing import Dict, Optional, List


class FloatPropertiesChannel(SideChannel):
Expand All @@ -17,15 +20,14 @@ def __init__(self, channel_id: uuid.UUID = None) -> None:
channel_id = uuid.UUID(("60ccf7d0-4f7e-11ea-b238-784f4387d1f7"))
super().__init__(channel_id)

def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
SideChannel.
Note that Python should never receive an engine configuration from
Unity
"""
k, v = self.deserialize_float_prop(data)
k = msg.read_string()
v = msg.read_float32()
self._float_properties[k] = v

def set_property(self, key: str, value: float) -> None:
Expand All @@ -35,7 +37,10 @@ def set_property(self, key: str, value: float) -> None:
:param value: The float value of the property.
"""
self._float_properties[key] = value
super().queue_message_to_send(self.serialize_float_prop(key, value))
msg = OutgoingMessage()
msg.write_string(key)
msg.write_float32(value)
super().queue_message_to_send(msg)

def get_property(self, key: str) -> Optional[float]:
"""
Expand All @@ -59,22 +64,3 @@ def get_property_dict_copy(self) -> Dict[str, float]:
:return:
"""
return dict(self._float_properties)

@staticmethod
def serialize_float_prop(key: str, value: float) -> bytearray:
result = bytearray()
encoded_key = key.encode("ascii")
result += struct.pack("<i", len(encoded_key))
result += encoded_key
result += struct.pack("<f", value)
return result

@staticmethod
def deserialize_float_prop(data: bytes) -> Tuple[str, float]:
offset = 0
encoded_key_len = struct.unpack_from("<i", data, offset)[0]
offset = offset + 4
key = data[offset : offset + encoded_key_len].decode("ascii")
offset = offset + encoded_key_len
value = struct.unpack_from("<f", data, offset)[0]
return key, value
14 changes: 10 additions & 4 deletions ml-agents-envs/mlagents_envs/side_channel/raw_bytes_channel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import (
SideChannel,
IncomingMessage,
OutgoingMessage,
)
from typing import List
import uuid

Expand All @@ -13,13 +17,13 @@ def __init__(self, channel_id: uuid.UUID):
self._received_messages: List[bytes] = []
super().__init__(channel_id)

def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
SideChannel.
"""
self._received_messages.append(data)
self._received_messages.append(msg.get_raw_bytes())

def get_and_clear_received_messages(self) -> List[bytes]:
"""
Expand All @@ -34,4 +38,6 @@ def send_raw_data(self, data: bytearray) -> None:
Queues a message to be sent by the environment at the next call to
step.
"""
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.set_raw_bytes(data)
super().queue_message_to_send(msg)
59 changes: 56 additions & 3 deletions ml-agents-envs/mlagents_envs/side_channel/side_channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from abc import ABC, abstractmethod
from typing import List
import uuid
import struct

import logging

logger = logging.getLogger(__name__)


class SideChannel(ABC):
Expand All @@ -16,15 +21,15 @@ def __init__(self, channel_id):
self._channel_id: uuid.UUID = channel_id
self.message_queue: List[bytearray] = []

def queue_message_to_send(self, data: bytearray) -> None:
def queue_message_to_send(self, msg: "OutgoingMessage") -> None:
"""
Queues a message to be sent by the environment at the next call to
step.
"""
self.message_queue.append(data)
self.message_queue.append(msg.buffer)

@abstractmethod
def on_message_received(self, data: bytes) -> None:
def on_message_received(self, msg: "IncomingMessage") -> None:
"""
Is called by the environment to the side channel. Can be called
multiple times per step if multiple messages are meant for that
Expand All @@ -39,3 +44,51 @@ def channel_id(self) -> uuid.UUID:
processed in the environment.
"""
return self._channel_id


class OutgoingMessage:
def __init__(self):
self.buffer = bytearray()

def write_int32(self, i: int) -> None:
self.buffer += struct.pack("<i", i)

def write_float32(self, f: float) -> None:
self.buffer += struct.pack("<f", f)

def write_string(self, s: str) -> None:
encoded_key = s.encode("ascii")
self.write_int32(len(encoded_key))
self.buffer += encoded_key

def set_raw_bytes(self, buffer: bytearray) -> None:
if self.buffer:
logger.warning(
"Called set_raw_bytes but the message already has been written to. This will overwrite data."
)
self.buffer = bytearray(buffer)


class IncomingMessage:
def __init__(self, buffer: bytes, offset: int = 0):
self.buffer = buffer
self.offset = offset

def read_int32(self) -> int:
val = struct.unpack_from("<i", self.buffer, self.offset)[0]
self.offset += 4
return val

def read_float32(self) -> float:
val = struct.unpack_from("<f", self.buffer, self.offset)[0]
self.offset += 4
return val

def read_string(self) -> str:
encoded_str_len = self.read_int32()
val = self.buffer[self.offset : self.offset + encoded_str_len].decode("ascii")
self.offset += encoded_str_len
return val

def get_raw_bytes(self) -> bytes:
return bytearray(self.buffer)
17 changes: 10 additions & 7 deletions ml-agents-envs/mlagents_envs/tests/test_side_channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import struct
import uuid
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel import (
SideChannel,
IncomingMessage,
OutgoingMessage,
)
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents_envs.environment import UnityEnvironment
Expand All @@ -11,14 +14,14 @@ def __init__(self):
self.list_int = []
super().__init__(uuid.UUID("a85ba5c0-4f87-11ea-a517-784f4387d1f7"))

def on_message_received(self, data):
val = struct.unpack_from("<i", data, 0)[0]
def on_message_received(self, msg: IncomingMessage) -> None:
val = msg.read_int32()
self.list_int += [val]

def send_int(self, value):
data = bytearray()
data += struct.pack("<i", value)
super().queue_message_to_send(data)
msg = OutgoingMessage()
msg.write_int32(value)
super().queue_message_to_send(msg)


def test_int_channel():
Expand Down