Skip to content

Make copy_param support scalar parameters #410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions chainerrl/misc/copy_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def copy_param(target_link, source_link):
'not initialized.\nPlease try to forward dummy input '
'beforehand to determine parameter shape of the model.'.format(
param_name))
target_params[param_name].array[:] = param.array
target_params[param_name].array[...] = param.array

# Copy Batch Normalization's statistics
target_links = dict(target_link.namedlinks())
for link_name, link in source_link.namedlinks():
if isinstance(link, L.BatchNormalization):
target_bn = target_links[link_name]
target_bn.avg_mean[:] = link.avg_mean
target_bn.avg_var[:] = link.avg_var
target_bn.avg_mean[...] = link.avg_mean
target_bn.avg_var[...] = link.avg_var


def soft_copy_param(target_link, source_link, tau):
Expand All @@ -40,25 +40,25 @@ def soft_copy_param(target_link, source_link, tau):
'not initialized.\nPlease try to forward dummy input '
'beforehand to determine parameter shape of the model.'.format(
param_name))
target_params[param_name].array[:] *= (1 - tau)
target_params[param_name].array[:] += tau * param.array
target_params[param_name].array[...] *= (1 - tau)
target_params[param_name].array[...] += tau * param.array

# Soft-copy Batch Normalization's statistics
target_links = dict(target_link.namedlinks())
for link_name, link in source_link.namedlinks():
if isinstance(link, L.BatchNormalization):
target_bn = target_links[link_name]
target_bn.avg_mean[:] *= (1 - tau)
target_bn.avg_mean[:] += tau * link.avg_mean
target_bn.avg_var[:] *= (1 - tau)
target_bn.avg_var[:] += tau * link.avg_var
target_bn.avg_mean[...] *= (1 - tau)
target_bn.avg_mean[...] += tau * link.avg_mean
target_bn.avg_var[...] *= (1 - tau)
target_bn.avg_var[...] += tau * link.avg_var


def copy_grad(target_link, source_link):
    """Copy gradients of a link to another link.

    Every gradient array of ``source_link`` is written in place into the
    gradient array of the same-named parameter of ``target_link``; the
    target's gradient arrays are overwritten, not replaced.

    Args:
        target_link: Object whose ``namedparams()`` yields
            ``(name, param)`` pairs; each ``param.grad`` is overwritten.
        source_link: Object whose ``namedparams()`` yields
            ``(name, param)`` pairs; each ``param.grad`` is read.

    Raises:
        KeyError: If ``source_link`` has a parameter name that
            ``target_link`` does not have.
    """
    target_params = dict(target_link.namedparams())
    for param_name, param in source_link.namedparams():
        # Index with Ellipsis rather than ``[:]``: a slice needs at least
        # one axis and raises IndexError on 0-dimensional (scalar) arrays,
        # whereas ``[...]`` addresses the whole array for any rank.
        target_params[param_name].grad[...] = param.grad


def synchronize_parameters(src, dst, method, tau=None):
Expand Down
34 changes: 34 additions & 0 deletions tests/misc_tests/test_copy_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ def test_copy_param(self):
self.assertEqual(a_out_new, b_out)
self.assertEqual(b_out_new, b_out)

def test_copy_param_scalar(self):
    """copy_param must also work for 0-dimensional (scalar) parameters."""
    target = chainer.Chain()
    with target.init_scope():
        target.p = chainer.Parameter(np.array(1))

    source = chainer.Chain()
    with source.init_scope():
        source.p = chainer.Parameter(np.array(2))

    # Precondition: the two scalar parameters start out different.
    self.assertNotEqual(target.p.array, source.p.array)

    # Overwrite the target's parameters with the source's.
    copy_param.copy_param(target, source)

    self.assertEqual(target.p.array, source.p.array)

def test_copy_param_type_check(self):
a = L.Linear(None, 5)
b = L.Linear(1, 5)
Expand Down Expand Up @@ -59,6 +74,25 @@ def test_soft_copy_param(self):
np.testing.assert_almost_equal(a.W.array, np.full(a.W.shape, 0.595))
np.testing.assert_almost_equal(b.W.array, np.full(b.W.shape, 1.0))

def test_soft_copy_param_scalar(self):
    """soft_copy_param must also work for 0-dimensional parameters."""
    target = chainer.Chain()
    with target.init_scope():
        target.p = chainer.Parameter(np.array(0.5))

    source = chainer.Chain()
    with source.init_scope():
        source.p = chainer.Parameter(np.array(1))

    # Each call performs: target = (1 - tau) * target + tau * source,
    # so starting from 0.5 the target moves to 0.55 and then to 0.595
    # while the source stays untouched at 1.0.
    for expected_target in (0.55, 0.595):
        copy_param.soft_copy_param(
            target_link=target, source_link=source, tau=0.1)

        np.testing.assert_almost_equal(target.p.array, expected_target)
        np.testing.assert_almost_equal(source.p.array, 1.0)

def test_soft_copy_param_type_check(self):
a = L.Linear(None, 5)
b = L.Linear(1, 5)
Expand Down