Skip to content

Commit 837168f

Browse files
cicichen01facebook-github-bot
authored andcommitted
Use Python properties to cleanly apply rules to class fields
Summary: As titled. There are strong rules on most of the fields in influence algorithms. Make use of Python properties to make the structure cleaner and enforce rules even when client code tries to change the field. Reviewed By: yucu Differential Revision: D54663278 fbshipit-source-id: 273a82e4707c1b0a7afb8637459b4e34bf381d89
1 parent 2828cd1 commit 837168f

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,6 @@ def __init__(
186186

187187
# TODO: restore prior state
188188
self.final_fc_layer = final_fc_layer
189-
if isinstance(self.final_fc_layer, str):
190-
self.final_fc_layer = _get_module_from_name(model, self.final_fc_layer)
191-
assert isinstance(self.final_fc_layer, Module)
192189
for param in self.final_fc_layer.parameters():
193190
param.requires_grad = True
194191

@@ -203,6 +200,24 @@ def __init__(
203200
else _check_loss_fn(self, test_loss_fn, "test_loss_fn")
204201
)
205202

203+
@property
204+
def final_fc_layer(self) -> Module:
205+
return self._final_fc_layer
206+
207+
@final_fc_layer.setter
208+
def final_fc_layer(self, layer: Union[Module, str]):
209+
if isinstance(layer, str):
210+
try:
211+
self._final_fc_layer = _get_module_from_name(self.model, layer)
212+
if not isinstance(self._final_fc_layer, Module):
213+
raise Exception("No module found for final_fc_layer")
214+
except Exception as ex:
215+
raise ValueError(
216+
f'Invalid final_fc_layer str: "{layer}" provided!'
217+
) from ex
218+
else:
219+
self._final_fc_layer = layer
220+
206221
@log_usage()
207222
def influence( # type: ignore[override]
208223
self,

tests/influence/_core/test_tracin_validation.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
class TestTracinValidator(BaseTest):
1818

19-
param_list = []
20-
for reduction, constr in [
19+
param_list = [
2120
(
2221
"none",
2322
DataInfluenceConstructor(TracInCP, name="TracInCP"),
@@ -29,8 +28,7 @@ class TestTracinValidator(BaseTest):
2928
name="TracInCpFast",
3029
),
3130
),
32-
]:
33-
param_list.append((reduction, constr))
31+
]
3432

3533
@parameterized.expand(
3634
param_list,
@@ -64,3 +62,24 @@ def test_tracin_require_inputs_dataset(
6462
)
6563
with self.assertRaisesRegex(AssertionError, "required."):
6664
tracin.influence(None, k=None)
65+
66+
def test_tracincp_fast_rand_proj_inputs(self) -> None:
67+
with tempfile.TemporaryDirectory() as tmpdir:
68+
(
69+
net,
70+
train_dataset,
71+
test_samples,
72+
test_labels,
73+
) = get_random_model_and_data(tmpdir, unpack_inputs=False)
74+
75+
with self.assertRaisesRegex(
76+
ValueError, 'Invalid final_fc_layer str: "invalid_layer" provided!'
77+
):
78+
TracInCPFast(
79+
net,
80+
"invalid_layer",
81+
train_dataset,
82+
tmpdir,
83+
loss_fn=nn.MSELoss(),
84+
batch_size=1,
85+
)

0 commit comments

Comments
 (0)