77
88class TestOptim (unittest .TestCase ):
99 def test_SGD (self ):
10- optim = SGD (torch .nn .Linear (10 , 3 ).parameters ())
10+ optim = SGD (model_params = torch .nn .Linear (10 , 3 ).parameters ())
1111 self .assertTrue ("lr" in optim .__dict__ ["settings" ])
1212 self .assertTrue ("momentum" in optim .__dict__ ["settings" ])
1313 res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
@@ -22,13 +22,18 @@ def test_SGD(self):
2222 self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.002 )
2323 self .assertEqual (optim .__dict__ ["settings" ]["momentum" ], 0.989 )
2424
25- with self .assertRaises (RuntimeError ):
25+ optim = SGD (0.001 )
26+ self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.001 )
27+ res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
28+ self .assertTrue (isinstance (res , torch .optim .SGD ))
29+
30+ with self .assertRaises (TypeError ):
2631 _ = SGD ("???" )
27- with self .assertRaises (RuntimeError ):
32+ with self .assertRaises (TypeError ):
2833 _ = SGD (0.001 , lr = 0.002 )
2934
3035 def test_Adam (self ):
31- optim = Adam (torch .nn .Linear (10 , 3 ).parameters ())
36+ optim = Adam (model_params = torch .nn .Linear (10 , 3 ).parameters ())
3237 self .assertTrue ("lr" in optim .__dict__ ["settings" ])
3338 self .assertTrue ("weight_decay" in optim .__dict__ ["settings" ])
3439 res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
@@ -42,3 +47,8 @@ def test_Adam(self):
4247 optim = Adam (lr = 0.002 , weight_decay = 0.989 )
4348 self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.002 )
4449 self .assertEqual (optim .__dict__ ["settings" ]["weight_decay" ], 0.989 )
50+
51+ optim = Adam (0.001 )
52+ self .assertEqual (optim .__dict__ ["settings" ]["lr" ], 0.001 )
53+ res = optim .construct_from_pytorch (torch .nn .Linear (10 , 3 ).parameters ())
54+ self .assertTrue (isinstance (res , torch .optim .Adam ))
0 commit comments