diff --git a/ml-agents/mlagents/trainers/demo_loader.py b/ml-agents/mlagents/trainers/demo_loader.py index 5bb5485db2..5b927f7065 100644 --- a/ml-agents/mlagents/trainers/demo_loader.py +++ b/ml-agents/mlagents/trainers/demo_loader.py @@ -1,4 +1,3 @@ -import pathlib import logging import os from typing import List, Tuple @@ -28,16 +27,12 @@ def make_demo_buffer( # Create and populate buffer using experiences demo_raw_buffer = AgentBuffer() demo_processed_buffer = AgentBuffer() - for idx, current_pair_info in enumerate(pair_infos): - if idx > len(pair_infos) - 2: - break - next_pair_info = pair_infos[idx + 1] - current_brain_info = BrainInfo.from_agent_proto( - 0, [current_pair_info.agent_info], brain_params - ) - next_brain_info = BrainInfo.from_agent_proto( - 0, [next_pair_info.agent_info], brain_params - ) + brain_infos = [ + BrainInfo.from_agent_proto(0, [pair_info.agent_info], brain_params) + for pair_info in pair_infos + ] + for idx in range(len(brain_infos) - 1): + current_brain_info, next_brain_info = brain_infos[idx : idx + 2] previous_action = ( np.array(pair_infos[idx].action_info.vector_actions, dtype=np.float32) * 0 ) @@ -55,7 +50,7 @@ def make_demo_buffer( demo_raw_buffer["vector_obs"].append( current_brain_info.vector_observations[0] ) - demo_raw_buffer["actions"].append(current_pair_info.action_info.vector_actions) + demo_raw_buffer["actions"].append(pair_infos[idx].action_info.vector_actions) demo_raw_buffer["prev_action"].append(previous_action) if next_brain_info.local_done[0]: demo_raw_buffer.resequence_and_append( @@ -83,39 +78,43 @@ def demo_to_buffer( return brain_params, demo_buffer -@timed -def load_demonstration( - file_path: str -) -> Tuple[BrainParameters, List[AgentInfoActionPairProto], int]: - """ - Loads and parses a demonstration file. - :param file_path: Location of demonstration file (.demo). - :return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data. +def get_demo_files(path: str) -> List[str]: """ + Retrieves the demonstration file(s) from a path. + :param path: Path of demonstration file or directory. + :return: List of demonstration files - # First 32 bytes of file dedicated to meta-data. - INITIAL_POS = 33 - file_paths = [] - if os.path.isdir(file_path): - all_files = os.listdir(file_path) - for _file in all_files: - if _file.endswith(".demo"): - file_paths.append(os.path.join(file_path, _file)) - if not all_files: + Raises errors if |path| is invalid. + """ + if os.path.isfile(path): + if not path.endswith(".demo"): + raise ValueError("The path provided is not a '.demo' file.") + return [path] + elif os.path.isdir(path): + paths = [ + os.path.join(path, name) + for name in os.listdir(path) + if name.endswith(".demo") + ] + if not paths: raise ValueError("There are no '.demo' files in the provided directory.") - elif os.path.isfile(file_path): - file_paths.append(file_path) - file_extension = pathlib.Path(file_path).suffix - if file_extension != ".demo": - raise ValueError( - "The file is not a '.demo' file. Please provide a file with the " - "correct extension." - ) + return paths else: raise FileNotFoundError( - "The demonstration file or directory {} does not exist.".format(file_path) + f"The demonstration file or directory {path} does not exist." ) + +@timed +def load_demonstration( + path: str, +) -> Tuple[BrainParameters, List[AgentInfoActionPairProto], int]: + """ + Loads and parses a demonstration file or directory. + :param path: File or directory. + :return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data. + """ + file_paths = get_demo_files(path) brain_params = None brain_param_proto = None info_action_pairs = [] @@ -131,12 +130,14 @@ def load_demonstration( meta_data_proto = DemonstrationMetaProto() meta_data_proto.ParseFromString(data[pos : pos + next_pos]) total_expected += meta_data_proto.number_steps + # first 32 bytes of file dedicated to metadata + INITIAL_POS = 33 pos = INITIAL_POS - if obs_decoded == 1: + elif obs_decoded == 1: brain_param_proto = BrainParametersProto() brain_param_proto.ParseFromString(data[pos : pos + next_pos]) pos += next_pos - if obs_decoded > 1: + else: agent_info_action = AgentInfoActionPairProto() agent_info_action.ParseFromString(data[pos : pos + next_pos]) if brain_params is None: @@ -149,7 +150,5 @@ def load_demonstration( pos += next_pos obs_decoded += 1 if not brain_params: - raise RuntimeError( - f"No BrainParameters found in demonstration file at {file_path}." - ) + raise RuntimeError(f"No BrainParameters found in demonstration file at {path}.") return brain_params, info_action_pairs, total_expected diff --git a/ml-agents/mlagents/trainers/tests/test_demo_loader.py b/ml-agents/mlagents/trainers/tests/test_demo_loader.py index 81d454dd14..e0c139018e 100644 --- a/ml-agents/mlagents/trainers/tests/test_demo_loader.py +++ b/ml-agents/mlagents/trainers/tests/test_demo_loader.py @@ -1,10 +1,17 @@ import os +import pytest +import tempfile -from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer +from mlagents.trainers.demo_loader import ( + load_demonstration, + demo_to_buffer, + get_demo_files, +) + +path_prefix = os.path.dirname(os.path.abspath(__file__)) def test_load_demo(): - path_prefix = os.path.dirname(os.path.abspath(__file__)) brain_parameters, pair_infos, total_expected = load_demonstration( path_prefix + "/test.demo" ) @@ -17,7 +24,6 @@ def test_load_demo(): def test_load_demo_dir(): - path_prefix = os.path.dirname(os.path.abspath(__file__)) brain_parameters, pair_infos, total_expected = load_demonstration( path_prefix + "/test_demo_dir" ) @@ -27,3 +33,28 @@ def test_load_demo_dir(): _, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1) assert len(demo_buffer["actions"]) == total_expected - 1 + + +def test_edge_cases(): + # nonexistent file and directory + with pytest.raises(FileNotFoundError): + get_demo_files(os.path.join(path_prefix, "nonexistent_file.demo")) + with pytest.raises(FileNotFoundError): + get_demo_files(os.path.join(path_prefix, "nonexistent_directory")) + with tempfile.TemporaryDirectory() as tmpdirname: + # empty directory + with pytest.raises(ValueError): + get_demo_files(tmpdirname) + # invalid file + invalid_fname = tmpdirname + "/mydemo.notademo" + with open(invalid_fname, "w") as f: + f.write("I'm not a demo") + with pytest.raises(ValueError): + get_demo_files(invalid_fname) + # valid file + valid_fname = tmpdirname + "/mydemo.demo" + with open(valid_fname, "w") as f: + f.write("I'm a demo file") + assert get_demo_files(valid_fname) == [valid_fname] + # valid directory + assert get_demo_files(tmpdirname) == [valid_fname]