Skip to content

Commit 50167ad

Browse files
vivekmigfacebook-github-bot
authored andcommitted
Ablation Device Fix (#528)
Summary: Pull Request resolved: #528 Allows output to be on a different device than input by moving output difference to input device. Reviewed By: miguelmartin75 Differential Revision: D25001273 fbshipit-source-id: a9b6d8e8bb585d5360c53272a5502f4e8f257459
1 parent c5907b5 commit 50167ad

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def attribute(
346346
eval_diff = (
347347
initial_eval - modified_eval.reshape((-1, num_outputs))
348348
).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
349+
eval_diff = eval_diff.to(total_attrib[i].device)
349350
if self.use_weights:
350351
weights[i] += current_mask.float().sum(dim=0)
351352
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(

0 commit comments

Comments
 (0)