Commit 3b1aa94
update tracin influence API (#1072)
Summary:
Pull Request resolved: #1072
This diff changes the API for implementations of `TracInCPBase` as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments representing test data of the `influence` method are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch, or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`.
We make this change for 2 reasons
- it unifies the assumptions made of the test data and the assumptions made of the training data
- for some implementations, we want to allow the test data to be represented by a dataloader. with the old API, there was no clean way to allow both a single as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1.
For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`.
Other changes:
- changes to make documentation. for example, documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of test batch.
- the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed
- the usage of `influence` in all the tests is changed to match new API.
- signature of helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match new representation of batches.
Reviewed By: cyrjano
Differential Revision: D41324297
fbshipit-source-id: c5834f74e301b4ccbbc2cc0b9f331455ff04a4b21 parent 43c7d23 commit 3b1aa94
File tree
14 files changed
+529
-608
lines changed- captum/influence
- _core
- _utils
- tests/influence
- _core
- _utils
- tutorials
14 files changed
+529
-608
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
15 | | - | |
16 | | - | |
17 | | - | |
| 15 | + | |
18 | 16 | | |
19 | 17 | | |
20 | 18 | | |
21 | | - | |
| 19 | + | |
22 | 20 | | |
23 | 21 | | |
24 | 22 | | |
25 | 23 | | |
26 | 24 | | |
27 | 25 | | |
28 | 26 | | |
29 | | - | |
| 27 | + | |
30 | 28 | | |
31 | 29 | | |
32 | 30 | | |
| |||
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
191 | 191 | | |
192 | 192 | | |
193 | 193 | | |
194 | | - | |
195 | 194 | | |
196 | 195 | | |
197 | 196 | | |
| |||
206 | 205 | | |
207 | 206 | | |
208 | 207 | | |
209 | | - | |
| 208 | + | |
210 | 209 | | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
216 | 214 | | |
217 | 215 | | |
218 | 216 | | |
| |||
274 | 272 | | |
275 | 273 | | |
276 | 274 | | |
277 | | - | |
| 275 | + | |
278 | 276 | | |
279 | 277 | | |
280 | 278 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
| |||
76 | 77 | | |
77 | 78 | | |
78 | 79 | | |
79 | | - | |
| 80 | + | |
| 81 | + | |
80 | 82 | | |
81 | 83 | | |
82 | 84 | | |
| |||
88 | 90 | | |
89 | 91 | | |
90 | 92 | | |
91 | | - | |
| 93 | + | |
| 94 | + | |
92 | 95 | | |
93 | 96 | | |
94 | 97 | | |
| |||
Lines changed: 5 additions & 16 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
15 | 16 | | |
16 | 17 | | |
17 | 18 | | |
| |||
224 | 225 | | |
225 | 226 | | |
226 | 227 | | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
227 | 231 | | |
228 | | - | |
| 232 | + | |
229 | 233 | | |
230 | 234 | | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
242 | | - | |
243 | | - | |
244 | | - | |
245 | | - | |
246 | 235 | | |
247 | 236 | | |
248 | 237 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
107 | 108 | | |
108 | 109 | | |
109 | 110 | | |
110 | | - | |
| 111 | + | |
| 112 | + | |
111 | 113 | | |
112 | 114 | | |
113 | 115 | | |
114 | | - | |
115 | | - | |
| 116 | + | |
116 | 117 | | |
117 | 118 | | |
118 | | - | |
119 | 119 | | |
120 | 120 | | |
121 | 121 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
183 | 183 | | |
184 | 184 | | |
185 | 185 | | |
186 | | - | |
| 186 | + | |
187 | 187 | | |
188 | | - | |
| 188 | + | |
189 | 189 | | |
190 | 190 | | |
191 | 191 | | |
192 | 192 | | |
193 | 193 | | |
194 | 194 | | |
195 | 195 | | |
196 | | - | |
| 196 | + | |
197 | 197 | | |
198 | | - | |
| 198 | + | |
199 | 199 | | |
200 | 200 | | |
201 | 201 | | |
| |||
226 | 226 | | |
227 | 227 | | |
228 | 228 | | |
229 | | - | |
| 229 | + | |
230 | 230 | | |
231 | | - | |
| 231 | + | |
232 | 232 | | |
233 | 233 | | |
234 | 234 | | |
235 | 235 | | |
236 | 236 | | |
237 | | - | |
| 237 | + | |
238 | 238 | | |
239 | | - | |
| 239 | + | |
240 | 240 | | |
241 | 241 | | |
242 | 242 | | |
| |||
288 | 288 | | |
289 | 289 | | |
290 | 290 | | |
291 | | - | |
| 291 | + | |
292 | 292 | | |
293 | 293 | | |
294 | 294 | | |
| |||
382 | 382 | | |
383 | 383 | | |
384 | 384 | | |
385 | | - | |
| 385 | + | |
386 | 386 | | |
387 | | - | |
| 387 | + | |
388 | 388 | | |
389 | 389 | | |
390 | 390 | | |
| |||
415 | 415 | | |
416 | 416 | | |
417 | 417 | | |
418 | | - | |
| 418 | + | |
419 | 419 | | |
420 | | - | |
| 420 | + | |
421 | 421 | | |
422 | 422 | | |
423 | 423 | | |
| |||
496 | 496 | | |
497 | 497 | | |
498 | 498 | | |
499 | | - | |
| 499 | + | |
500 | 500 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
108 | 109 | | |
109 | 110 | | |
110 | 111 | | |
111 | | - | |
112 | | - | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
113 | 115 | | |
114 | | - | |
115 | 116 | | |
116 | 117 | | |
117 | 118 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
178 | 178 | | |
179 | 179 | | |
180 | 180 | | |
181 | | - | |
182 | | - | |
| 181 | + | |
183 | 182 | | |
184 | 183 | | |
185 | 184 | | |
| |||
196 | 195 | | |
197 | 196 | | |
198 | 197 | | |
199 | | - | |
200 | | - | |
| 198 | + | |
201 | 199 | | |
202 | 200 | | |
203 | 201 | | |
| |||
218 | 216 | | |
219 | 217 | | |
220 | 218 | | |
221 | | - | |
222 | | - | |
| 219 | + | |
223 | 220 | | |
224 | 221 | | |
225 | 222 | | |
| |||
0 commit comments