Skip to content

Commit 940ae01

Browse files
authored
Merge pull request #455 from muupan/fix-same-process-idx
Fix a bug of unintentionally using same process indices
2 parents a63eb14 + 3b3c16d commit 940ae01

File tree

6 files changed

+14
-8
lines changed

6 files changed

+14
-8
lines changed

examples/ale/train_a2c_ale.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
standard_library.install_aliases() # NOQA
88

99
import argparse
10+
import functools
1011
import logging
1112

1213
import chainer
@@ -123,7 +124,7 @@ def make_env(process_idx, test):
123124

124125
def make_batch_env(test):
125126
return chainerrl.envs.MultiprocessVectorEnv(
126-
[(lambda: make_env(idx, test))
127+
[functools.partial(make_env, idx, test)
127128
for idx, env in enumerate(range(args.num_envs))])
128129

129130
sample_env = make_env(0, test=False)

examples/ale/train_dqn_batch_ale.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from future import standard_library
77
standard_library.install_aliases() # NOQA
88
import argparse
9+
import functools
910
import os
1011

1112
import chainer
@@ -161,7 +162,7 @@ def make_env(idx, test):
161162

162163
def make_batch_env(test):
163164
vec_env = chainerrl.envs.MultiprocessVectorEnv(
164-
[(lambda: make_env(idx, test))
165+
[functools.partial(make_env, idx, test)
165166
for idx, env in enumerate(range(args.num_envs))])
166167
vec_env = chainerrl.wrappers.VectorFrameStack(vec_env, 4)
167168
return vec_env

examples/gym/train_a2c_gym.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from future import standard_library
1515
standard_library.install_aliases() # NOQA
1616
import argparse
17+
import functools
1718

1819
import chainer
1920
from chainer import functions as F
@@ -153,7 +154,7 @@ def make_env(process_idx, test):
153154

154155
def make_batch_env(test):
155156
return chainerrl.envs.MultiprocessVectorEnv(
156-
[(lambda: make_env(idx, test))
157+
[functools.partial(make_env, idx, test)
157158
for idx, env in enumerate(range(args.num_envs))])
158159

159160
sample_env = make_env(process_idx=0, test=False)

examples/gym/train_ddpg_batch_gym.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from future import standard_library
66
standard_library.install_aliases() # NOQA
77
import argparse
8+
import functools
89
import sys
910

1011
import chainer
@@ -106,7 +107,7 @@ def make_env(idx, test):
106107

107108
def make_batch_env(test):
108109
return chainerrl.envs.MultiprocessVectorEnv(
109-
[(lambda: make_env(idx, test))
110+
[functools.partial(make_env, idx, test)
110111
for idx, env in enumerate(range(args.num_envs))])
111112

112113
sample_env = make_env(0, test=False)

examples/gym/train_ppo_batch_gym.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from future import standard_library
1515
standard_library.install_aliases() # NOQA
1616
import argparse
17+
import functools
1718

1819
import chainer
1920
from chainer import functions as F
@@ -96,7 +97,7 @@ def make_env(process_idx, test):
9697

9798
def make_batch_env(test):
9899
return chainerrl.envs.MultiprocessVectorEnv(
99-
[(lambda: make_env(idx, test))
100+
[functools.partial(make_env, idx, test)
100101
for idx, env in enumerate(range(args.num_envs))])
101102

102103
# Only for getting timesteps, and obs-action spaces

tests/wrappers_tests/test_vector_frame_stack.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from future import standard_library
77
standard_library.install_aliases() # NOQA
88

9+
import functools
910
import mock
1011
import unittest
1112

@@ -68,14 +69,14 @@ def make_env(idx):
6869

6970
# Wrap by FrameStack and MultiprocessVectorEnv
7071
fs_env = chainerrl.envs.MultiprocessVectorEnv(
71-
[(lambda: FrameStack(
72-
make_env(idx), k=self.k, channel_order='chw'))
72+
[functools.partial(
73+
FrameStack, make_env(idx), k=self.k, channel_order='chw')
7374
for idx, env in enumerate(range(self.num_envs))])
7475

7576
# Wrap by MultiprocessVectorEnv and VectorFrameStack
7677
vfs_env = VectorFrameStack(
7778
chainerrl.envs.MultiprocessVectorEnv(
78-
[(lambda: make_env(idx))
79+
[functools.partial(make_env, idx)
7980
for idx, env in enumerate(range(self.num_envs))]),
8081
k=self.k, stack_axis=0)
8182

0 commit comments

Comments
 (0)