Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions ml-agents/mlagents/trainers/demo_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
import logging
import os
from typing import List, Tuple
Expand Down Expand Up @@ -91,6 +90,33 @@ def demo_to_buffer(
return brain_params, demo_buffer


def get_demo_files(path: str) -> List[str]:
"""
Retrieves the demonstration file(s) from a path.
:param path: Path of demonstration file or directory.
:return: List of demonstration files

Raises errors if |path| is invalid.
"""
if os.path.isfile(path):
if not path.endswith(".demo"):
raise ValueError("The path provided is not a '.demo' file.")
return [path]
elif os.path.isdir(path):
paths = [
os.path.join(path, name)
for name in os.listdir(path)
if name.endswith(".demo")
]
if not paths:
raise ValueError("There are no '.demo' files in the provided directory.")
return paths
else:
raise FileNotFoundError(
f"The demonstration file or directory {path} does not exist."
)


@timed
def load_demonstration(
file_path: str
Expand All @@ -103,27 +129,7 @@ def load_demonstration(

# First 32 bytes of file dedicated to meta-data.
INITIAL_POS = 33
file_paths = []
if os.path.isdir(file_path):
all_files = os.listdir(file_path)
for _file in all_files:
if _file.endswith(".demo"):
file_paths.append(os.path.join(file_path, _file))
if not all_files:
raise ValueError("There are no '.demo' files in the provided directory.")
elif os.path.isfile(file_path):
file_paths.append(file_path)
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":
raise ValueError(
"The file is not a '.demo' file. Please provide a file with the "
"correct extension."
)
else:
raise FileNotFoundError(
"The demonstration file or directory {} does not exist.".format(file_path)
)

file_paths = get_demo_files(file_path)
group_spec = None
brain_param_proto = None
info_action_pairs = []
Expand Down
37 changes: 36 additions & 1 deletion ml-agents/mlagents/trainers/tests/test_demo_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
import numpy as np
import pytest
import tempfile

from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer
from mlagents.trainers.demo_loader import (
load_demonstration,
demo_to_buffer,
get_demo_files,
)


def test_load_demo():
Expand All @@ -26,3 +32,32 @@ def test_load_demo_dir():

_, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1)
assert len(demo_buffer["actions"]) == total_expected - 1


def test_edge_cases():
path_prefix = os.path.dirname(os.path.abspath(__file__))
# nonexistent file and directory
with pytest.raises(FileNotFoundError):
get_demo_files(os.path.join(path_prefix, "nonexistent_file.demo"))
with pytest.raises(FileNotFoundError):
get_demo_files(os.path.join(path_prefix, "nonexistent_directory"))
with tempfile.TemporaryDirectory() as tmpdirname:
# empty directory
with pytest.raises(ValueError):
get_demo_files(tmpdirname)
# invalid file
invalid_fname = os.path.join(tmpdirname, "mydemo.notademo")
with open(invalid_fname, "w") as f:
f.write("I'm not a demo")
with pytest.raises(ValueError):
get_demo_files(invalid_fname)
# invalid directory
with pytest.raises(ValueError):
get_demo_files(tmpdirname)
# valid file
valid_fname = os.path.join(tmpdirname, "mydemo.demo")
with open(valid_fname, "w") as f:
f.write("I'm a demo file")
assert get_demo_files(valid_fname) == [valid_fname]
# valid directory
assert get_demo_files(tmpdirname) == [valid_fname]