Commit 51a2762
Merge pull request #155 from muupan/fix-episodic-buffer-len
Fix episodic buffer __len__
2 parents 52d2af6 + 3dc1ce5 commit 51a2762

4 files changed: +140 -32 lines

chainerrl/agents/pcl.py

Lines changed: 3 additions & 0 deletions
@@ -273,6 +273,9 @@ def update_from_replay(self):
         if len(self.replay_buffer) < self.replay_start_size:
             return
 
+        if self.replay_buffer.n_episodes < self.batchsize:
+            return
+
         if self.process_idx == 0:
             self.logger.debug('update_from_replay')
 
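The added guard mirrors the existing replay_start_size check: PCL samples whole episodes, so an update is only possible once the buffer holds at least batchsize complete episodes. Below is a minimal sketch of the same precondition outside the agent, assuming chainerrl with this commit installed; the batchsize and replay_start_size values are illustrative, not taken from the commit.

# Sketch of the guard added above; threshold values are made up for the example.
from chainerrl import replay_buffer

rbuf = replay_buffer.EpisodicReplayBuffer(capacity=None)
batchsize = 4           # episodes sampled per update, like PCL's batchsize
replay_start_size = 10  # transitions required before any update

def can_update_from_replay():
    if len(rbuf) < replay_start_size:
        return False  # too few transitions overall
    if rbuf.n_episodes < batchsize:
        return False  # too few complete episodes to sample a batch from
    return True

print(can_update_from_replay())  # False: the buffer is still empty
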
chainerrl/replay_buffer.py

Lines changed: 117 additions & 19 deletions
@@ -4,8 +4,13 @@
 from __future__ import absolute_import
 from builtins import *  # NOQA
 from future import standard_library
+from future.utils import with_metaclass
 standard_library.install_aliases()
 
+from abc import ABCMeta
+from abc import abstractmethod
+from abc import abstractproperty
+
 import numpy as np
 import six.moves.cPickle as pickle
 
@@ -14,14 +19,17 @@
 from chainerrl.misc.prioritized import PrioritizedBuffer
 
 
-class ReplayBuffer(object):
+class AbstractReplayBuffer(with_metaclass(ABCMeta, object)):
+    """Defines a common interface of replay buffer.
 
-    def __init__(self, capacity=None):
-        self.memory = RandomAccessQueue(maxlen=capacity)
+    You can append transitions to the replay buffer and later sample from it.
+    Replay buffers are typically used in experience replay.
+    """
 
+    @abstractmethod
     def append(self, state, action, reward, next_state=None, next_action=None,
                is_state_terminal=False):
-        """Append a transition to this replay buffer
+        """Append a transition to this replay buffer.
 
         Args:
             state: s_t
@@ -31,13 +39,107 @@ def append(self, state, action, reward, next_state=None, next_action=None,
             next_action: a_{t+1} (can be None for off-policy algorithms)
             is_state_terminal (bool)
         """
+        raise NotImplementedError
+
+    @abstractmethod
+    def sample(self, n):
+        """Sample n unique transitions from this replay buffer.
+
+        Args:
+            n (int): Number of transitions to sample.
+        Returns:
+            Sequence of n sampled transitions.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def __len__(self):
+        """Return the number of transitions in the buffer.
+
+        Returns:
+            Number of transitions in the buffer.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, filename):
+        """Save the content of the buffer to a file.
+
+        Args:
+            filename (str): Path to a file.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self, filename):
+        """Load the content of the buffer from a file.
+
+        Args:
+            filename (str): Path to a file.
+        """
+        raise NotImplementedError
+
+
+class AbstractEpisodicReplayBuffer(AbstractReplayBuffer):
+    """Defines a common interface of episodic replay buffer.
+
+    Episodic replay buffers allow you to append and sample episodes.
+    """
+
+    @abstractmethod
+    def sample_episodes(self, n_episodes, max_len=None):
+        """Sample n unique (sub)episodes from this replay buffer.
+
+        Args:
+            n_episodes (int): Number of episodes to sample.
+            max_len (int or None): Maximum length of sampled episodes. If it is
+                smaller than the length of some episode, a subsequence of the
+                episode is sampled instead. If None, full episodes are always
+                returned.
+        Returns:
+            Sequence of n sampled episodes, each of which is a sequence of
+            transitions.
+        """
+        raise NotImplementedError
+
+    @abstractproperty
+    def n_episodes(self):
+        """Return the number of episodes in the buffer.
+
+        Returns:
+            Number of episodes in the buffer.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def stop_current_episode(self):
+        """Notify the buffer that the current episode is interrupted.
+
+        You may want to interrupt the current episode and start a new one
+        before observing a terminal state. This is typical in continuing envs.
+        In such cases, you need to call this method before appending a new
+        transition so that the buffer will treat it as an initial transition of
+        a new episode.
+
+        This method should not be called after an episode whose termination is
+        already notified by appending a transition with is_state_terminal=True.
+        """
+        raise NotImplementedError
+
+
+class ReplayBuffer(AbstractReplayBuffer):
+
+    def __init__(self, capacity=None):
+        self.memory = RandomAccessQueue(maxlen=capacity)
+
+    def append(self, state, action, reward, next_state=None, next_action=None,
+               is_state_terminal=False):
         experience = dict(state=state, action=action, reward=reward,
                           next_state=next_state, next_action=next_action,
                           is_state_terminal=is_state_terminal)
         self.memory.append(experience)
 
     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         return self.memory.sample(n)
 
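For reference, the new base classes rely on future.utils.with_metaclass so the ABC machinery works on both Python 2 and 3: a concrete buffer cannot be instantiated until every @abstractmethod is overridden. A hypothetical toy subclass (not part of this commit) illustrating that contract:

# Toy example of the ABC contract used above; Base and ListBuffer are
# hypothetical names, not part of chainerrl.
from abc import ABCMeta
from abc import abstractmethod

from future.utils import with_metaclass


class Base(with_metaclass(ABCMeta, object)):

    @abstractmethod
    def append(self, **transition):
        raise NotImplementedError


class ListBuffer(Base):

    def __init__(self):
        self.memory = []

    def append(self, **transition):
        self.memory.append(transition)


ListBuffer().append(state=0, action=1, reward=0.0)  # OK: append is overridden
try:
    Base()
except TypeError as e:
    print(e)  # cannot instantiate an abstract class
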
@@ -117,7 +219,6 @@ def __init__(self, capacity=None,
             self, alpha, beta0, betasteps, eps, normalize_by_max)
 
     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         sampled, probabilities = self.memory.sample(n)
         weights = self.weights_from_probabilities(probabilities)
@@ -137,7 +238,7 @@ def random_subseq(seq, subseq_len):
         return seq[i:i + subseq_len]
 
 
-class EpisodicReplayBuffer(object):
+class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):
 
     def __init__(self, capacity=None):
         self.current_episode = []
@@ -147,16 +248,6 @@ def __init__(self, capacity=None):
 
     def append(self, state, action, reward, next_state=None, next_action=None,
                is_state_terminal=False, **kwargs):
-        """Append a transition to this replay buffer
-
-        Args:
-            state: s_t
-            action: a_t
-            reward: r_t
-            next_state: s_{t+1} (can be None if terminal)
-            next_action: a_{t+1} (can be None for off-policy algorithms)
-            is_state_terminal (bool)
-        """
         experience = dict(state=state, action=action, reward=reward,
                           next_state=next_state, next_action=next_action,
                           is_state_terminal=is_state_terminal,
@@ -166,12 +257,10 @@ def append(self, state, action, reward, next_state=None, next_action=None,
             self.stop_current_episode()
 
     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         return self.memory.sample(n)
 
     def sample_episodes(self, n_episodes, max_len=None):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.episodic_memory) >= n_episodes
         episodes = self.episodic_memory.sample(n_episodes)
         if max_len is not None:
@@ -180,6 +269,10 @@ def sample_episodes(self, n_episodes, max_len=None):
         return episodes
 
     def __len__(self):
+        return len(self.memory)
+
+    @property
+    def n_episodes(self):
         return len(self.episodic_memory)
 
     def save(self, filename):
@@ -313,6 +406,11 @@ def __init__(self, replay_buffer, update_func, batchsize, episodic_update,
     def update_if_necessary(self, iteration):
         if len(self.replay_buffer) < self.replay_start_size:
             return
+
+        if (self.episodic_update
+                and self.replay_buffer.n_episodes < self.batchsize):
+            return
+
         if iteration % self.update_interval != 0:
             return
 
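Taken together, these changes implement the fix named in the commit title: EpisodicReplayBuffer.__len__ previously returned the number of stored episodes, and now returns the number of stored transitions, with the episode count exposed separately via the new n_episodes property. A runnable sketch of the post-commit behavior, assuming chainerrl at this revision; the transition values are arbitrary:

# After this commit: len() counts transitions, n_episodes counts episodes.
from chainerrl import replay_buffer

rbuf = replay_buffer.EpisodicReplayBuffer(capacity=None)
for _ in range(3):        # 3 episodes ...
    for t in range(5):    # ... of 5 transitions each
        rbuf.append(state=t, action=0, reward=1.0,
                    next_state=t + 1, next_action=0,
                    is_state_terminal=(t == 4))

assert len(rbuf) == 15        # transitions (this returned 3 before the fix)
assert rbuf.n_episodes == 3   # complete episodes
subepisodes = rbuf.sample_episodes(2, max_len=3)
assert len(subepisodes) == 2  # each subepisode has at most 3 transitions
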
examples/gym/train_dqn_gym.py

Lines changed: 2 additions & 7 deletions
@@ -54,7 +54,7 @@ def main():
     parser.add_argument('--steps', type=int, default=10 ** 5)
     parser.add_argument('--prioritized-replay', action='store_true')
     parser.add_argument('--episodic-replay', action='store_true')
-    parser.add_argument('--replay-start-size', type=int, default=None)
+    parser.add_argument('--replay-start-size', type=int, default=1000)
     parser.add_argument('--target-update-interval', type=int, default=10 ** 2)
     parser.add_argument('--target-update-method', type=str, default='hard')
     parser.add_argument('--soft-update-tau', type=float, default=1e-2)
@@ -130,11 +130,8 @@ def make_env(for_eval):
     if args.episodic_replay:
         if args.minibatch_size is None:
             args.minibatch_size = 4
-        if args.replay_start_size is None:
-            args.replay_start_size = 10
         if args.prioritized_replay:
-            betasteps = \
-                (args.steps - timestep_limit * args.replay_start_size) \
+            betasteps = (args.steps - args.replay_start_size) \
                 // args.update_interval
             rbuf = replay_buffer.PrioritizedEpisodicReplayBuffer(
                 rbuf_capacity, betasteps=betasteps)
@@ -143,8 +140,6 @@ def make_env(for_eval):
     else:
         if args.minibatch_size is None:
             args.minibatch_size = 32
-        if args.replay_start_size is None:
-            args.replay_start_size = 1000
        if args.prioritized_replay:
             betasteps = (args.steps - args.replay_start_size) \
                 // args.update_interval
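With --replay-start-size now defaulting to 1000 regardless of replay mode, both branches compute betasteps from the same formula. A worked example with the script's defaults; the update_interval value below is an assumption for illustration, not read from the script:

# betasteps after this change, using steps=10**5 and the new default
# replay_start_size=1000; update_interval=1 is assumed.
steps = 10 ** 5
replay_start_size = 1000
update_interval = 1

betasteps = (steps - replay_start_size) // update_interval
assert betasteps == 99000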

tests/test_replay_buffer.py

Lines changed: 18 additions & 6 deletions
@@ -113,6 +113,9 @@ def subtest_append_and_sample(self, capacity):
         for trans in transs:
             rbuf.append(**trans)
 
+        self.assertEqual(len(rbuf), 90)
+        self.assertEqual(rbuf.n_episodes, 9)
+
         for k in [10, 30, 90]:
             s = rbuf.sample(k)
             self.assertEqual(len(s), k)
@@ -130,13 +133,13 @@ def subtest_append_and_sample(self, capacity):
 
     def test_save_and_load(self):
         for capacity in [100, None]:
-            self.subtest_append_and_sample(capacity)
+            self.subtest_save_and_load(capacity)
 
     def subtest_save_and_load(self, capacity):
 
         tempdir = tempfile.mkdtemp()
 
-        rbuf = replay_buffer.ReplayBuffer(capacity)
+        rbuf = replay_buffer.EpisodicReplayBuffer(capacity)
 
         transs = [dict(state=n, action=n+10, reward=n+20,
                        next_state=n+1, next_action=n+11,
@@ -153,12 +156,15 @@ def subtest_save_and_load(self, capacity):
         rbuf.append(**transs[4])
         rbuf.stop_current_episode()
 
+        self.assertEqual(len(rbuf), 5)
+        self.assertEqual(rbuf.n_episodes, 2)
+
         # Save
         filename = os.path.join(tempdir, 'rbuf.pkl')
         rbuf.save(filename)
 
         # Initialize rbuf
-        rbuf = replay_buffer.ReplayBuffer(capacity)
+        rbuf = replay_buffer.EpisodicReplayBuffer(capacity)
 
         # Of course it has no transition yet
         self.assertEqual(len(rbuf), 0)
@@ -168,22 +174,26 @@ def subtest_save_and_load(self, capacity):
 
         # Sampled transitions are exactly what I added!
         s5 = rbuf.sample(5)
-        self.assertEqual(len(s5) == 5)
+        self.assertEqual(len(s5), 5)
         for t in s5:
             n = t['state']
             self.assertIn(n, range(5))
             self.assertEqual(t, transs[n])
 
         # And sampled episodes are exactly what I added!
         s2e = rbuf.sample_episodes(2)
-        self.assertEqual(len(s2e) == 2)
+        self.assertEqual(len(s2e), 2)
         if s2e[0][0]['state'] == 0:
             self.assertEqual(s2e[0], [transs[0], transs[1]])
             self.assertEqual(s2e[1], [transs[2], transs[3], transs[4]])
         else:
             self.assertEqual(s2e[0], [transs[2], transs[3], transs[4]])
             self.assertEqual(s2e[1], [transs[0], transs[1]])
 
+        # Sizes are correct!
+        self.assertEqual(len(rbuf), 5)
+        self.assertEqual(rbuf.n_episodes, 2)
+
 
 class TestPrioritizedReplayBuffer(unittest.TestCase):
 
@@ -354,7 +364,9 @@ def test_append_and_sample(self):
                   for i in range(n)]
         for trans in transs:
             rbuf.append(**trans)
-        self.assertEqual(len(rbuf), 9)
+
+        self.assertEqual(len(rbuf), 90)
+        self.assertEqual(rbuf.n_episodes, 9)
 
         for k in [10, 30, 90]:
             s = rbuf.sample(k)
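The updated assertions pin down the new semantics: the fixture holds 90 transitions forming 9 episodes, so len(rbuf) is now 90 where it was 9 before the fix, and the episode count moves to n_episodes. A sketch of a fixture satisfying these assertions; the exact construction in the test file is not shown in this diff:

# 90 transitions grouped into 9 ten-step episodes; every 10th step is
# terminal. Field values are illustrative.
transs = [dict(state=i, action=100 + i, reward=200 + i,
               next_state=i + 1, next_action=101 + i,
               is_state_terminal=(i % 10 == 9))
          for i in range(90)]
assert sum(t['is_state_terminal'] for t in transs) == 9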
