Conversation

ProGamerGov (Contributor) commented Nov 11, 2020

I've added the first batch of unit tests. I'm not sure whether they need to be more in-depth or cover more cases. Other unit tests are currently on hold due to potential changes to how the objectives, the optimization system, and ImageTensor will work. Once @greentfrapp finishes those changes, we'll write tests for them.

  • Added unit tests for all of the optim transform classes and functions.

  • Added unit tests for FFTImage.

  • Added unit tests for PixelImage.

  • Added unit tests for loading and running pretrained InceptionV1 model.

  • Fixed a number of the optim transforms that hadn't been tested thoroughly yet.

  • Added a warning about the tqdm package being required. It's essentially the same warning we added for PIL / Pillow.

  • Moved some of the general model functions & classes to their own file. These functions and classes can be used with other models as well, so it makes sense to separate them from the Inception 5h model.

  • Added unit tests for all of the _utils/models functions & classes.

  • Added function to get all hookable model layers.

  • Added RedirectedReLU example to Torchvision notebook tutorial.

  • Added a ton of missing type hints.

  • Added functions for calculating the Karhunen-Loève Transform (KLT) matrix of custom image datasets (see the sketch after this list). The current ToRGB transform uses an ImageNet KLT matrix, but users may not always be using models trained on ImageNet. Added new tests for the new functions.

  • Re-added Ludwig's Conv2dSame with tests and type hints.
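
As a rough illustration of the KLT item above, here is a minimal sketch of estimating a 3x3 color decorrelation matrix from a custom dataset. The helper name, the DataLoader assumption, and the Lucid-style max-norm scaling are assumptions here, not the PR's exact implementation.

import torch

def dataset_klt_matrix(loader: torch.utils.data.DataLoader) -> torch.Tensor:
    # Hypothetical helper: estimate a 3x3 KLT / color decorrelation matrix
    # from RGB images, similar in spirit to the ImageNet matrix used by ToRGB.
    pixels = []
    for images, _ in loader:  # assumes batches of (B, 3, H, W) float images
        pixels.append(images.permute(0, 2, 3, 1).reshape(-1, 3))
    x = torch.cat(pixels, dim=0)
    x = x - x.mean(dim=0, keepdim=True)  # center each color channel
    cov = x.T @ x / (x.shape[0] - 1)     # 3x3 channel covariance
    U, S, _ = torch.linalg.svd(cov)
    klt = U * S.sqrt()                   # decorrelation / recoloring matrix
    return klt / klt.norm(dim=0).max()   # Lucid-style normalization (assumption)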

ProGamerGov (Author) commented Nov 11, 2020

==================================== ERRORS ====================================
_______________ ERROR collecting tests/optim/test_transforms.py ________________
ImportError while importing test module '/home/circleci/project/tests/optim/test_transforms.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/local/lib/python3.6/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/optim/test_transforms.py:7: in <module>
    from captum.optim._param.image import transform
captum/optim/__init__.py:3: in <module>
    from captum.optim._core import objectives  # noqa: F401
captum/optim/_core/objectives.py:9: in <module>
    from tqdm.auto import tqdm
E   ModuleNotFoundError: No module named 'tqdm'

So, I added a try-except block to deal with the missing tqdm import, but it will have to be handled more thoroughly for the objective and optimization tests.
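
Roughly, the guard follows the usual optional-dependency pattern (a sketch; the PR's actual code and warning text may differ):

try:
    from tqdm.auto import tqdm

    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    # tqdm is optional; only warn (or fall back to a plain loop) when a
    # progress bar is actually requested.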

ProGamerGov (Author) commented Nov 11, 2020

@NarineK Is it possible to add tqdm and Pillow (if it's not already included) to the test suite? We might be able to find a workaround for tqdm, but Pillow would be required for testing real image loading. @greentfrapp

This PR doesn't require PIL / Pillow or tqdm, so this is more about future test PRs.

ProGamerGov (Author) commented Nov 13, 2020

There's a slight problem with indexing tensors that have named dimensions. This may affect future PRs until it's fixed.

pytorch/pytorch#47931

ProGamerGov (Author) commented:

@NarineK There shouldn't be any more major changes to this PR right now, so you can review and merge it!

ProGamerGov changed the title from "Optim-WIP: Preliminary transform unit tests" to "Optim-WIP: Preliminary transform unit tests & more" on Nov 13, 2020
NarineK (Contributor) commented Nov 17, 2020

> @NarineK There shouldn't be any more major changes to this PR right now, so you can review and merge it!

Thank you very much for working on this, @ProGamerGov! I'll make a pass this week.

ProGamerGov (Author) commented Nov 20, 2020

@NarineK I added type hints to all of the optim forward() and __init__ functions. I also noticed that most of the __init__ & forward() functions in Captum were missing some or all of their type hints, despite everything else being type hinted. I'm not sure why that is.

I made a pull request (#535) to add the missing return type hints to all of the __init__ functions in the master branch. The forward() functions can't be mass-fixed as easily.

ProGamerGov force-pushed the tests branch 2 times, most recently from da78e9a to f43a11d, on November 27, 2020 01:12
NarineK (Contributor) left a review:

Thank you for addressing the comments @ProGamerGov!
I made a couple more nit comments.

A workaround when there is no gradient flow from an initial random input.
ReLU layers will block the gradient flow during backpropagation when their
input is less than 0. This means that it can be impossible to visualize a
target without allowing negative values to pass through ReLU layers during
NarineK (Contributor):

Thank you for adding the description, @ProGamerGov! Is this replacement necessary only if the target layer is a ReLU, or is it necessary in general for all ReLU layers in the model?

ProGamerGov (Author), Dec 3, 2020:

It is necessary for all ReLU layers in the model, regardless of what the target layer is.
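
To make the "all ReLU layers" point concrete, here's a minimal sketch of recursively swapping every nn.ReLU in a model. RedirectedReluLayer is a hypothetical stand-in for the PR's redirected ReLU module, and the helper name is made up.

import torch
import torch.nn as nn

class RedirectedReluLayer(nn.Module):
    # Placeholder: the real layer acts like ReLU in forward() but redirects
    # (scales instead of zeroing) the gradient in backward().
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

def redirect_relu_layers(model: nn.Module) -> None:
    # Replace every nn.ReLU in the model, not just the one at the target layer.
    for name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, name, RedirectedReluLayer())
        else:
            redirect_relu_layers(child)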

NarineK (Contributor):

Thank you for the explanation, @ProGamerGov! Is this related to the dead ReLU problem that Ludwig mentioned in his original PR?

ProGamerGov (Author):

Yeah, it's a solution to the dead ReLU problem!

def backward(self, grad_output: torch.Tensor) -> torch.Tensor:
    (input_tensor,) = self.saved_tensors
    grad_input = grad_output.clone()
    grad_input[input_tensor < 0] = grad_input[input_tensor < 0] * 1e-1
NarineK (Contributor):

nit: why is it necessary to multiply by 0.1 here?

ProGamerGov (Author), Dec 3, 2020:

I'm not 100% sure why it's done, so I've reached out to @greentfrapp to ask exactly why the gradient scaling is required.

I also ran some tests and found that removing the scaling seemed to let features that shouldn't be there show up in the visualizations (the tests are from the InceptionV1 tutorial):

[Attached images: example visualizations showing unwanted artifacts when the gradient scaling is removed]

So, it looks like gradient scaling is required if we remove the normal backward ReLU code, but I'm not sure why scaling by 0.1 was chosen over other values.

ProGamerGov (Author):

So, @greentfrapp just chose 0.1 as a way of handling the gradient flow without messing things up. In Lucid, Ludwig let everything through with the same weight / scale for the first 16 optimization steps.

I think @greentfrapp is working on a more advanced version of RedirectedReLU that more closely resembles the one in Lucid, but for now the current version is sufficient.
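
For reference, a minimal sketch of what such a redirected ReLU could look like as a custom autograd Function; this is an illustration, not the PR's exact implementation (and #552 refines it further):

import torch

class RedirectedReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (input_tensor,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Where a plain ReLU would zero the gradient, pass a scaled-down
        # version through instead so a "dead" start can still recover.
        grad_input[input_tensor < 0] = grad_input[input_tensor < 0] * 1e-1
        return grad_input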

NarineK (Contributor):

Thank you for the explanation, @ProGamerGov! Perhaps we can write a sentence of documentation about it so that we remember.
In the example above, does each image correspond to a step of the optimization process without using the 0.1 multiplier?

ProGamerGov (Author), Dec 9, 2020:

@NarineK The images are the same ones as in the InceptionV1 tutorial notebook; I just reran all of the cells in the notebook without multiplying by 0.1 in RedirectedReLU. They aren't different steps, they're different targets and parameters from the tutorial notebook.

@greentfrapp has put together a better, more advanced version of RedirectedReLU in his pull request (#552), so I'll leave the RedirectedReLU version in this pull request as it currently is for merging.


from captum.optim._param.image import transform
from tests.helpers.basic import BaseTest

NarineK (Contributor):

I was thinking that if those transforms already exist in NumPy or SciPy, then we could compare our implementation with theirs, but we can do that later.

rr_loss = rr_layer(x * 1).mean()
rr_loss.backward()

assertTensorAlmostEqual(self, t_grad_input[0], t_grad_output[0], 0)
NarineK (Contributor):

Why do we need t_grad_input and t_grad_output as arrays if we are using only the first element?

ProGamerGov (Author), Dec 3, 2020:

Lists can be appended to from inside the hook function, whereas plain variables can't be reassigned from that scope, I think.
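
For context, the pattern looks roughly like this (a sketch; the exact hook registration API used in the tests may differ). Appending mutates a list the hook closes over, so the captured values are still visible after backward(), whereas rebinding a plain local variable inside the hook would not be.

import torch
import torch.nn as nn

grads_in, grads_out = [], []  # lists survive the hook's scope

def grad_hook(module, grad_input, grad_output):
    # append() mutates the outer lists in place, so the gradients are
    # still available once backward() has finished
    grads_in.append(grad_input)
    grads_out.append(grad_output)

layer = nn.ReLU()
handle = layer.register_full_backward_hook(grad_hook)
x = torch.randn(1, 3, 4, 4, requires_grad=True)
layer(x).mean().backward()
handle.remove()
print(grads_in[0][0].shape, grads_out[0][0].shape)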

ProGamerGov requested a review from NarineK on December 3, 2020 18:57
ProGamerGov (Author) commented Dec 3, 2020

@NarineK I added NumPy comparison tests for 4 of the transforms! The remaining transforms either aren't easily recreated in NumPy or don't need NumPy versions.
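
The comparison tests follow the usual pattern of running the op in torch, recreating it with NumPy, and asserting that the results agree. A generic sketch of that pattern (not one of the PR's actual tests):

import numpy as np
import torch
import torch.nn.functional as F

def test_reflect_pad_matches_numpy() -> None:
    # Generic example of the torch-vs-NumPy comparison pattern.
    x = torch.arange(16.0).reshape(1, 1, 4, 4)
    padded = F.pad(x, (2, 2, 2, 2), mode="reflect")
    expected = np.pad(x.numpy(), ((0, 0), (0, 0), (2, 2), (2, 2)), mode="reflect")
    np.testing.assert_allclose(padded.numpy(), expected)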

ProGamerGov force-pushed the tests branch 4 times, most recently from f11fdb4 to aabbd62, on December 4, 2020 19:41
ProGamerGov force-pushed the tests branch 2 times, most recently from f8da94d to 9a79c78, on December 5, 2020 00:17
ProGamerGov (Author) commented Dec 9, 2020

@NarineK @greentfrapp and I had to make some changes to the FFTImage classes to handle the PyTorch fft deprecations, but the PR should be ready for merging once you've finished reviewing the code!
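
For context, the deprecation is the move from the old torch.rfft / torch.irfft functions to the torch.fft module; roughly (a sketch, not the PR's exact change):

import torch

x = torch.randn(3, 224, 224)

# Old, deprecated style (pre-1.8): torch.rfft(x, signal_ndim=2)
# New torch.fft module equivalent:
freqs = torch.fft.rfftn(x, dim=(-2, -1))
spatial = torch.fft.irfftn(freqs, s=x.shape[-2:], dim=(-2, -1))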

NarineK (Contributor) commented Dec 9, 2020

Thank you, @ProGamerGov! I'm merging now so as not to make other PRs wait. I will take another, more detailed look later.

NarineK merged commit e319690 into meta-pytorch:optim-wip on Dec 9, 2020