 from __future__ import absolute_import
 from builtins import *  # NOQA
 from future import standard_library
+from future.utils import with_metaclass
 standard_library.install_aliases()

+from abc import ABCMeta
+from abc import abstractmethod
+from abc import abstractproperty
+
 import numpy as np
 import six.moves.cPickle as pickle

 from chainerrl.misc.prioritized import PrioritizedBuffer


-class ReplayBuffer(object):
+class AbstractReplayBuffer(with_metaclass(ABCMeta, object)):
+    """Defines a common interface of replay buffer.

-    def __init__(self, capacity=None):
-        self.memory = RandomAccessQueue(maxlen=capacity)
+    You can append transitions to the replay buffer and later sample from it.
+    Replay buffers are typically used in experience replay.
+    """

+    @abstractmethod
     def append(self, state, action, reward, next_state=None, next_action=None,
                is_state_terminal=False):
-        """Append a transition to this replay buffer
+        """Append a transition to this replay buffer.

         Args:
             state: s_t
@@ -31,13 +39,107 @@ def append(self, state, action, reward, next_state=None, next_action=None,
             next_action: a_{t+1} (can be None for off-policy algorithms)
             is_state_terminal (bool)
         """
+        raise NotImplementedError
+
+    @abstractmethod
+    def sample(self, n):
+        """Sample n unique transitions from this replay buffer.
+
+        Args:
+            n (int): Number of transitions to sample.
+        Returns:
+            Sequence of n sampled transitions.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def __len__(self):
+        """Return the number of transitions in the buffer.
+
+        Returns:
+            Number of transitions in the buffer.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, filename):
+        """Save the content of the buffer to a file.
+
+        Args:
+            filename (str): Path to a file.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self, filename):
+        """Load the content of the buffer from a file.
+
+        Args:
+            filename (str): Path to a file.
+        """
+        raise NotImplementedError
+
+
+class AbstractEpisodicReplayBuffer(AbstractReplayBuffer):
+    """Defines a common interface of episodic replay buffer.
+
+    Episodic replay buffers allow you to append and sample episodes.
+    """
+
+    @abstractmethod
+    def sample_episodes(self, n_episodes, max_len=None):
+        """Sample n_episodes unique (sub)episodes from this replay buffer.
+
+        Args:
+            n_episodes (int): Number of episodes to sample.
+            max_len (int or None): Maximum length of sampled episodes.
+                If it is smaller than the length of some episode, a
+                subsequence of that episode is sampled instead. If None,
+                full episodes are always returned.
+        Returns:
+            Sequence of n_episodes sampled episodes, each of which is a
+            sequence of transitions.
+        """
+        raise NotImplementedError
+
+    @abstractproperty
+    def n_episodes(self):
+        """Return the number of episodes in the buffer.
+
+        Returns:
+            Number of episodes in the buffer.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def stop_current_episode(self):
+        """Notify the buffer that the current episode is interrupted.
+
+        You may want to interrupt the current episode and start a new one
+        before observing a terminal state. This is typical in continuing envs.
+        In such cases, you need to call this method before appending a new
+        transition so that the buffer will treat it as an initial transition
+        of a new episode.
+
+        This method should not be called after an episode whose termination
+        is already notified by appending a transition with
+        is_state_terminal=True.
+        """
+        raise NotImplementedError
+
+
+class ReplayBuffer(AbstractReplayBuffer):
+
+    def __init__(self, capacity=None):
+        self.memory = RandomAccessQueue(maxlen=capacity)
+
+    def append(self, state, action, reward, next_state=None, next_action=None,
+               is_state_terminal=False):
         experience = dict(state=state, action=action, reward=reward,
                           next_state=next_state, next_action=next_action,
                           is_state_terminal=is_state_terminal)
         self.memory.append(experience)

     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         return self.memory.sample(n)

@@ -117,7 +219,6 @@ def __init__(self, capacity=None,
             self, alpha, beta0, betasteps, eps, normalize_by_max)

     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         sampled, probabilities = self.memory.sample(n)
         weights = self.weights_from_probabilities(probabilities)
@@ -137,7 +238,7 @@ def random_subseq(seq, subseq_len):
     return seq[i:i + subseq_len]


-class EpisodicReplayBuffer(object):
+class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):

     def __init__(self, capacity=None):
         self.current_episode = []
@@ -147,16 +248,6 @@ def __init__(self, capacity=None):

     def append(self, state, action, reward, next_state=None, next_action=None,
                is_state_terminal=False, **kwargs):
-        """Append a transition to this replay buffer
-
-        Args:
-            state: s_t
-            action: a_t
-            reward: r_t
-            next_state: s_{t+1} (can be None if terminal)
-            next_action: a_{t+1} (can be None for off-policy algorithms)
-            is_state_terminal (bool)
-        """
         experience = dict(state=state, action=action, reward=reward,
                           next_state=next_state, next_action=next_action,
                           is_state_terminal=is_state_terminal,
@@ -166,12 +257,10 @@ def append(self, state, action, reward, next_state=None, next_action=None,
             self.stop_current_episode()

     def sample(self, n):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.memory) >= n
         return self.memory.sample(n)

     def sample_episodes(self, n_episodes, max_len=None):
-        """Sample n unique samples from this replay buffer"""
         assert len(self.episodic_memory) >= n_episodes
         episodes = self.episodic_memory.sample(n_episodes)
         if max_len is not None:
@@ -180,6 +269,10 @@ def sample_episodes(self, n_episodes, max_len=None):
         return episodes

     def __len__(self):
+        return len(self.memory)
+
+    @property
+    def n_episodes(self):
         return len(self.episodic_memory)

     def save(self, filename):
@@ -313,6 +406,11 @@ def __init__(self, replay_buffer, update_func, batchsize, episodic_update,
     def update_if_necessary(self, iteration):
         if len(self.replay_buffer) < self.replay_start_size:
             return
+
+        if (self.episodic_update
+                and self.replay_buffer.n_episodes < self.batchsize):
+            return
+
         if iteration % self.update_interval != 0:
             return
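
The abstract base classes introduced above only define the contract; the concrete buffers later in the diff supply the default implementations. For reference, a minimal sketch of what a conforming buffer looks like when built on collections.deque. The class name DequeReplayBuffer and its internals are hypothetical, and the import path assumes this file is chainerrl/replay_buffer.py.

    import pickle
    from collections import deque

    import numpy as np

    # Hypothetical example; assumes the base class added in this diff is
    # importable as chainerrl.replay_buffer.AbstractReplayBuffer.
    from chainerrl.replay_buffer import AbstractReplayBuffer


    class DequeReplayBuffer(AbstractReplayBuffer):
        """Toy buffer that satisfies the AbstractReplayBuffer interface."""

        def __init__(self, capacity=None):
            self.memory = deque(maxlen=capacity)

        def append(self, state, action, reward, next_state=None,
                   next_action=None, is_state_terminal=False):
            self.memory.append(dict(
                state=state, action=action, reward=reward,
                next_state=next_state, next_action=next_action,
                is_state_terminal=is_state_terminal))

        def sample(self, n):
            # Sample n unique transitions without replacement.
            assert len(self.memory) >= n
            indices = np.random.choice(len(self.memory), size=n, replace=False)
            return [self.memory[i] for i in indices]

        def __len__(self):
            return len(self.memory)

        def save(self, filename):
            with open(filename, 'wb') as f:
                pickle.dump(list(self.memory), f)

        def load(self, filename):
            with open(filename, 'rb') as f:
                self.memory = deque(pickle.load(f), maxlen=self.memory.maxlen)

Because the base class uses ABCMeta, a subclass that forgets to implement one of the abstract methods fails with a TypeError at instantiation rather than with an AttributeError deep inside training code.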
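The stop_current_episode docstring and the new n_episodes guard in update_if_necessary suggest the intended calling pattern for episodic buffers. A minimal sketch of that loop, using a stand-in random-walk task in place of a real environment (toy_step, n_episodes_per_batch and max_episode_len are placeholders, not part of this change):

    import numpy as np

    # Hypothetical import path; assumes this file is chainerrl/replay_buffer.py.
    from chainerrl.replay_buffer import EpisodicReplayBuffer


    def toy_step(state):
        """Stand-in transition: a random walk that never terminates."""
        next_state = state + np.random.randn()
        reward = -abs(next_state)
        return next_state, reward


    buffer = EpisodicReplayBuffer(capacity=10 ** 5)
    n_episodes_per_batch = 4    # plays the role of batchsize with episodic_update
    max_episode_len = 50        # artificial time limit for the continuing task

    state = 0.0
    for t in range(1, 1001):
        next_state, reward = toy_step(state)
        buffer.append(state=state, action=None, reward=reward,
                      next_state=next_state, is_state_terminal=False)
        if t % max_episode_len == 0:
            # The task never emits a terminal state, so interrupt the episode
            # explicitly; the next append then starts a fresh episode.
            buffer.stop_current_episode()
            state = 0.0
        else:
            state = next_state

        # Same gate as the new update_if_necessary check: do not sample
        # episodes until enough whole episodes have been stored.
        if buffer.n_episodes >= n_episodes_per_batch:
            episodes = buffer.sample_episodes(n_episodes_per_batch,
                                              max_len=max_episode_len)

If an episode ends with is_state_terminal=True instead, stop_current_episode must not be called afterwards; termination has already been notified through the appended transition.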