diff --git a/ml-agents/mlagents/trainers/demo_loader.py b/ml-agents/mlagents/trainers/demo_loader.py index ede8078c8c..9941b50ace 100644 --- a/ml-agents/mlagents/trainers/demo_loader.py +++ b/ml-agents/mlagents/trainers/demo_loader.py @@ -1,4 +1,3 @@ -import pathlib import logging import os from typing import List, Tuple @@ -91,6 +90,33 @@ def demo_to_buffer( return brain_params, demo_buffer +def get_demo_files(path: str) -> List[str]: + """ + Retrieves the demonstration file(s) from a path. + :param path: Path of demonstration file or directory. + :return: List of demonstration files + + Raises errors if |path| is invalid. + """ + if os.path.isfile(path): + if not path.endswith(".demo"): + raise ValueError("The path provided is not a '.demo' file.") + return [path] + elif os.path.isdir(path): + paths = [ + os.path.join(path, name) + for name in os.listdir(path) + if name.endswith(".demo") + ] + if not paths: + raise ValueError("There are no '.demo' files in the provided directory.") + return paths + else: + raise FileNotFoundError( + f"The demonstration file or directory {path} does not exist." + ) + + @timed def load_demonstration( file_path: str @@ -103,27 +129,7 @@ def load_demonstration( # First 32 bytes of file dedicated to meta-data. INITIAL_POS = 33 - file_paths = [] - if os.path.isdir(file_path): - all_files = os.listdir(file_path) - for _file in all_files: - if _file.endswith(".demo"): - file_paths.append(os.path.join(file_path, _file)) - if not all_files: - raise ValueError("There are no '.demo' files in the provided directory.") - elif os.path.isfile(file_path): - file_paths.append(file_path) - file_extension = pathlib.Path(file_path).suffix - if file_extension != ".demo": - raise ValueError( - "The file is not a '.demo' file. Please provide a file with the " - "correct extension." - ) - else: - raise FileNotFoundError( - "The demonstration file or directory {} does not exist.".format(file_path) - ) - + file_paths = get_demo_files(file_path) group_spec = None brain_param_proto = None info_action_pairs = [] diff --git a/ml-agents/mlagents/trainers/tests/test_demo_loader.py b/ml-agents/mlagents/trainers/tests/test_demo_loader.py index 3dead40304..475e526985 100644 --- a/ml-agents/mlagents/trainers/tests/test_demo_loader.py +++ b/ml-agents/mlagents/trainers/tests/test_demo_loader.py @@ -1,7 +1,13 @@ import os import numpy as np +import pytest +import tempfile -from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer +from mlagents.trainers.demo_loader import ( + load_demonstration, + demo_to_buffer, + get_demo_files, +) def test_load_demo(): @@ -26,3 +32,32 @@ def test_load_demo_dir(): _, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1) assert len(demo_buffer["actions"]) == total_expected - 1 + + +def test_edge_cases(): + path_prefix = os.path.dirname(os.path.abspath(__file__)) + # nonexistent file and directory + with pytest.raises(FileNotFoundError): + get_demo_files(os.path.join(path_prefix, "nonexistent_file.demo")) + with pytest.raises(FileNotFoundError): + get_demo_files(os.path.join(path_prefix, "nonexistent_directory")) + with tempfile.TemporaryDirectory() as tmpdirname: + # empty directory + with pytest.raises(ValueError): + get_demo_files(tmpdirname) + # invalid file + invalid_fname = os.path.join(tmpdirname, "mydemo.notademo") + with open(invalid_fname, "w") as f: + f.write("I'm not a demo") + with pytest.raises(ValueError): + get_demo_files(invalid_fname) + # invalid directory + with pytest.raises(ValueError): + get_demo_files(tmpdirname) + # valid file + valid_fname = os.path.join(tmpdirname, "mydemo.demo") + with open(valid_fname, "w") as f: + f.write("I'm a demo file") + assert get_demo_files(valid_fname) == [valid_fname] + # valid directory + assert get_demo_files(tmpdirname) == [valid_fname]