Skip to content

Commit 0b16a7e

Browse files
kevtanChris Elion
authored andcommitted
Refactor file logic in demo_loader and add unit tests. (#3241)
1 parent ea1b435 commit 0b16a7e

File tree

2 files changed

+64
-23
lines changed

2 files changed

+64
-23
lines changed

ml-agents/mlagents/trainers/demo_loader.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pathlib
21
import logging
32
import os
43
from typing import List, Tuple
@@ -91,6 +90,33 @@ def demo_to_buffer(
9190
return brain_params, demo_buffer
9291

9392

93+
def get_demo_files(path: str) -> List[str]:
94+
"""
95+
Retrieves the demonstration file(s) from a path.
96+
:param path: Path of demonstration file or directory.
97+
:return: List of demonstration files
98+
99+
Raises errors if |path| is invalid.
100+
"""
101+
if os.path.isfile(path):
102+
if not path.endswith(".demo"):
103+
raise ValueError("The path provided is not a '.demo' file.")
104+
return [path]
105+
elif os.path.isdir(path):
106+
paths = [
107+
os.path.join(path, name)
108+
for name in os.listdir(path)
109+
if name.endswith(".demo")
110+
]
111+
if not paths:
112+
raise ValueError("There are no '.demo' files in the provided directory.")
113+
return paths
114+
else:
115+
raise FileNotFoundError(
116+
f"The demonstration file or directory {path} does not exist."
117+
)
118+
119+
94120
@timed
95121
def load_demonstration(
96122
file_path: str
@@ -103,27 +129,7 @@ def load_demonstration(
103129

104130
# First 32 bytes of file dedicated to meta-data.
105131
INITIAL_POS = 33
106-
file_paths = []
107-
if os.path.isdir(file_path):
108-
all_files = os.listdir(file_path)
109-
for _file in all_files:
110-
if _file.endswith(".demo"):
111-
file_paths.append(os.path.join(file_path, _file))
112-
if not all_files:
113-
raise ValueError("There are no '.demo' files in the provided directory.")
114-
elif os.path.isfile(file_path):
115-
file_paths.append(file_path)
116-
file_extension = pathlib.Path(file_path).suffix
117-
if file_extension != ".demo":
118-
raise ValueError(
119-
"The file is not a '.demo' file. Please provide a file with the "
120-
"correct extension."
121-
)
122-
else:
123-
raise FileNotFoundError(
124-
"The demonstration file or directory {} does not exist.".format(file_path)
125-
)
126-
132+
file_paths = get_demo_files(file_path)
127133
group_spec = None
128134
brain_param_proto = None
129135
info_action_pairs = []

ml-agents/mlagents/trainers/tests/test_demo_loader.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import os
22
import numpy as np
3+
import pytest
4+
import tempfile
35

4-
from mlagents.trainers.demo_loader import load_demonstration, demo_to_buffer
6+
from mlagents.trainers.demo_loader import (
7+
load_demonstration,
8+
demo_to_buffer,
9+
get_demo_files,
10+
)
511

612

713
def test_load_demo():
@@ -26,3 +32,32 @@ def test_load_demo_dir():
2632

2733
_, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1)
2834
assert len(demo_buffer["actions"]) == total_expected - 1
35+
36+
37+
def test_edge_cases():
38+
path_prefix = os.path.dirname(os.path.abspath(__file__))
39+
# nonexistent file and directory
40+
with pytest.raises(FileNotFoundError):
41+
get_demo_files(os.path.join(path_prefix, "nonexistent_file.demo"))
42+
with pytest.raises(FileNotFoundError):
43+
get_demo_files(os.path.join(path_prefix, "nonexistent_directory"))
44+
with tempfile.TemporaryDirectory() as tmpdirname:
45+
# empty directory
46+
with pytest.raises(ValueError):
47+
get_demo_files(tmpdirname)
48+
# invalid file
49+
invalid_fname = os.path.join(tmpdirname, "mydemo.notademo")
50+
with open(invalid_fname, "w") as f:
51+
f.write("I'm not a demo")
52+
with pytest.raises(ValueError):
53+
get_demo_files(invalid_fname)
54+
# invalid directory
55+
with pytest.raises(ValueError):
56+
get_demo_files(tmpdirname)
57+
# valid file
58+
valid_fname = os.path.join(tmpdirname, "mydemo.demo")
59+
with open(valid_fname, "w") as f:
60+
f.write("I'm a demo file")
61+
assert get_demo_files(valid_fname) == [valid_fname]
62+
# valid directory
63+
assert get_demo_files(tmpdirname) == [valid_fname]

0 commit comments

Comments
 (0)