Skip to content

Commit 1b1d160

Browse files
Acme Contributor and copybara-github
authored and committed
Factor out a run_episode method which returns per-episode logged data.
A common use case is to call this method (instead of run()), so the caller has some flexibility to consume the logged data after every episode. PiperOrigin-RevId: 326652451 Change-Id: Idad1687b736c1e8312f8f28b4fdb66ce290e1f75
1 parent 509ff85 commit 1b1d160

File tree

2 files changed

+76
-50
lines changed

2 files changed

+76
-50
lines changed

acme/environment_loop.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,52 @@ def __init__(
6060
self._counter = counter or counting.Counter()
6161
self._logger = logger or loggers.make_default_logger(label)
6262

63+
def run_episode(self) -> loggers.LoggingData:
  """Run a single episode and return the logged results.

  The loop alternates between querying the actor for an action given the
  latest observation and stepping the environment with that action, until
  the environment signals episode termination.

  Returns:
    An instance of `loggers.LoggingData` holding the episode length, the
    episode return, the steps-per-second rate, and the counter values.
  """
  episode_start = time.time()
  step_count = 0
  total_reward = 0

  # Start a fresh episode and give the actor its first observation.
  timestep = self._environment.reset()
  self._actor.observe_first(timestep)

  while not timestep.last():
    # Query the policy, then advance the environment by one step.
    action = self._actor.select_action(timestep.observation)
    timestep = self._environment.step(action)

    # Feed the transition back to the actor and let it update itself.
    self._actor.observe(action, next_timestep=timestep)
    self._actor.update()

    # Book-keeping for the episode statistics.
    step_count += 1
    total_reward += timestep.reward

  # Bump the shared counters once per completed episode.
  counts = self._counter.increment(episodes=1, steps=step_count)

  # Assemble the per-episode log data, merged with the counter values.
  result = {
      'episode_length': step_count,
      'episode_return': total_reward,
      'steps_per_second': step_count / (time.time() - episode_start),
  }
  result.update(counts)
  return result
108+
63109
def run(self,
64110
num_episodes: Optional[int] = None,
65111
num_steps: Optional[int] = None):
@@ -69,12 +115,10 @@ def run(self,
69115
least `num_steps` steps (the last episode is always run until completion,
70116
so the total number of steps may be slightly more than `num_steps`).
71117
At least one of these two arguments has to be None.
72-
Each episode is itself a loop which interacts first with the environment to
73-
get an observation and then give that observation to the agent in order to
74-
retrieve an action. Upon termination of an episode a new episode will be
75-
started. If the number of episodes and the number of steps are not given
76-
then this will interact with the environment infinitely.
77-
If both num_episodes and num_steps are `None` (default), runs without limit.
118+
119+
Upon termination of an episode a new episode will be started. If the number
120+
of episodes and the number of steps are not given then this will interact
121+
with the environment infinitely.
78122
79123
Args:
80124
num_episodes: number of episodes to run the loop for.
@@ -93,43 +137,9 @@ def should_terminate(episode_count: int, step_count: int) -> bool:
93137

94138
episode_count, step_count = 0, 0
95139
while not should_terminate(episode_count, step_count):
96-
# Reset any counts and start the environment.
97-
start_time = time.time()
98-
episode_steps = 0
99-
episode_return = 0
100-
timestep = self._environment.reset()
101-
102-
# Make the first observation.
103-
self._actor.observe_first(timestep)
104-
105-
# Run an episode.
106-
while not timestep.last():
107-
# Generate an action from the agent's policy and step the environment.
108-
action = self._actor.select_action(timestep.observation)
109-
timestep = self._environment.step(action)
110-
111-
# Have the agent observe the timestep and let the actor update itself.
112-
self._actor.observe(action, next_timestep=timestep)
113-
self._actor.update()
114-
115-
# Book-keeping.
116-
episode_steps += 1
117-
episode_return += timestep.reward
118-
119-
# Record counts.
120-
counts = self._counter.increment(episodes=1, steps=episode_steps)
121-
122-
# Collect the results and combine with counts.
123-
steps_per_second = episode_steps / (time.time() - start_time)
124-
result = {
125-
'episode_length': episode_steps,
126-
'episode_return': episode_return,
127-
'steps_per_second': steps_per_second,
128-
}
129-
result.update(counts)
140+
result = self.run_episode()
130141
episode_count += 1
131-
step_count += episode_steps
132-
142+
step_count += result['episode_length']
133143
# Log the given results.
134144
self._logger.write(result)
135145

acme/environment_loop_test.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,34 @@
2121
from acme import specs
2222
from acme.testing import fakes
2323

24+
EPISODE_LENGTH = 10
25+
2426

2527
class EnvironmentLoopTest(absltest.TestCase):
2628

27-
def test_environment_loop(self):
29+
def setUp(self):
  """Build a fake environment/actor pair wired into an EnvironmentLoop."""
  super().setUp()
  env = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
  spec = specs.make_environment_spec(env)
  self.actor = fakes.Actor(spec)
  self.loop = environment_loop.EnvironmentLoop(env, self.actor)
35+
36+
def test_one_episode(self):
  """run_episode returns the logged data for exactly one episode."""
  result = self.loop.run_episode()
  # assertDictContainsSubset is deprecated since Python 3.2 and removed in
  # Python 3.12; assert on the key directly instead.
  self.assertEqual(result['episode_length'], EPISODE_LENGTH)
  self.assertIn('episode_return', result)
  self.assertIn('steps_per_second', result)
42+
def test_run_episodes(self):
  """Running N episodes performs EPISODE_LENGTH updates per episode."""
  num_episodes = 10
  self.loop.run(num_episodes=num_episodes)
  self.assertEqual(self.actor.num_updates, num_episodes * EPISODE_LENGTH)
46+
47+
def test_run_steps(self):
  """A step budget just over one episode forces a second full episode."""
  # The loop always finishes the episode in progress, so asking for more
  # steps than a single episode provides runs exactly two full episodes.
  self.loop.run(num_steps=EPISODE_LENGTH + 5)
  self.assertEqual(self.actor.num_updates, 2 * EPISODE_LENGTH)
3652

3753

3854
if __name__ == '__main__':

0 commit comments

Comments
 (0)