captum/attr/_core/feature_ablation.py (1 addition & 1 deletion)
@@ -554,7 +554,7 @@ def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
         )
 
     def _get_feature_counts(self, inputs, feature_mask, **kwargs):
-        """ return the numbers of input features """
+        """return the numbers of input features"""
         if not feature_mask:
             return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs)
captum/attr/_core/occlusion.py (1 addition & 1 deletion)
@@ -375,5 +375,5 @@ def _get_feature_range_and_mask(
         return 0, feature_max, None
 
     def _get_feature_counts(self, inputs, feature_mask, **kwargs):
-        """ return the numbers of possible input features """
+        """return the numbers of possible input features"""
         return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
captum/attr/_core/shapley_value.py (2 additions & 2 deletions)
@@ -478,7 +478,7 @@ def _perturbation_generator(
         )
 
     def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
-        """ return the total number of forward evaluations needed """
+        """return the total number of forward evaluations needed"""
         return math.ceil(total_features / perturbations_per_eval) * n_samples
 
 
@@ -740,7 +740,7 @@ def attribute(
         )
 
     def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
-        """ return the total number of forward evaluations needed """
+        """return the total number of forward evaluations needed"""
        return math.ceil(total_features / perturbations_per_eval) * math.factorial(
            total_features
        )
captum/metrics/_core/infidelity.py (2 additions & 2 deletions)
@@ -69,7 +69,7 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
     def default_perturb_func(
         inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
     ):
-        r""""""
+        r""" """
         inputs_perturbed = (
             pertub_func(inputs, baselines)
             if baselines is not None
@@ -398,7 +398,7 @@ def _generate_perturbations(
         """
 
         def call_perturb_func():
-            r""""""
+            r""" """
             baselines_pert = None
             inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
             if len(inputs_expanded) == 1:
tutorials/LRP_TorchVision_Tutorial.ipynb (188 additions & 0 deletions)
@@ -0,0 +1,188 @@
{
"cells": [
{
"source": [
"# LRP Tutorial for Pretrained VGG16 Model"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"This notebook demonstrates how to apply the Layer-Wise Relevance Propagation (LRP) algorithm on a pre-trained VGG16 model using a sample image. The relevance of each pixel is visualized by overlaying them on the example image. Further details regarding the operating principles of LRP can be found at [heatmapping.org](http://heatmapping.org/) and [here](https://www.springerprofessional.de/layer-wise-relevance-propagation-an-overview/17153814).\n",
"\n",
"The tutorial uses the same sample image and rule configuration as in [this](https://git.tu-berlin.de/gmontavon/lrp-tutorial) PyTorch implementation.\n",
"\n",
"\n",
"\n",
"\n",
"Note: Before running this tutorial, please install the torchvision and PIL packages.\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from PIL import Image\n",
"from torchvision import models, transforms\n",
"\n",
"from captum.attr import LRP\n",
"from captum.attr import visualization as viz\n",
"from captum.attr._utils.lrp_rules import EpsilonRule, GammaRule"
]
},
{
"source": [
"Loads the sample image and performs the appropriate normalizing steps."
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"img = Image.open('img/lrp/castle.jpg')\n",
"\n",
"transform = transforms.Compose(\n",
" [\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(\n",
" mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
" ), # needed for application of ResNet, VGGNet, ...\n",
" ]\n",
")\n",
"\n",
"X = torch.unsqueeze(transform(img), 0)"
],
"cell_type": "code",
"metadata": {},
"execution_count": null,
"outputs": []
},
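{
"source": [
"As a quick check (not part of the original tutorial), the transformed input should now carry a leading batch dimension:"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expect torch.Size([1, 3, H, W]) for an H-by-W input image.\n",
"print(X.shape)"
]
},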
{
"source": [
"Loads pre-trained VGG16 model and sets it to eval mode."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = models.vgg16(pretrained=True)\n",
"model.eval()"
]
},
{
"source": [
"Direct generation of LRP attribution. The default Epsilon-Rule is used for every layer. As one see in the generated output image, this does not bring many new insights as the heatmap does not focus well on the image content."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lrp = LRP(model)\n",
"attribution = lrp.attribute(X, target=483, verbose=True) # castle -> 483\n",
"attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()\n",
"\n",
"_ = viz.visualize_image_attr(\n",
" attribution,\n",
" img,\n",
" method='blended_heat_map',\n",
" sign='all',\n",
" show_colorbar=True,\n",
" title='Overlayed LRP',\n",
")"
]
},
{
"source": [
"But one can assign different rules to every layer. This is a crucial step to get expressive heatmaps. In the literature, one can find recommendations on when to use which layer. Currently implemented in captum are LRP-Epsilon, LRP-0, LRP-Gamma, LRP-Alpha-Beta, and the Identity-Rule.\n",
"\n",
"In the next steps, a list of all layers is generated and a rule is assigned to each one."
],
"cell_type": "markdown",
"metadata": {}
},
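{
"source": [
"For reference, in the notation of the LRP overview paper linked above, the Epsilon-Rule redistributes the relevance $R_k$ of a neuron $k$ onto its inputs $j$ as\n",
"\n",
"$$R_j = \\sum_k \\frac{a_j w_{jk}}{\\epsilon + \\sum_{0,j} a_j w_{jk}} R_k,$$\n",
"\n",
"where $a_j$ denotes the activations and $w_{jk}$ the weights. Setting $\\epsilon = 0$ recovers LRP-0, while the Gamma-Rule uses $w_{jk} + \\gamma w_{jk}^{+}$ in place of $w_{jk}$ to favor positive contributions.\n"
],
"cell_type": "markdown",
"metadata": {}
},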
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"layers = list(model._modules['features']) + list(model._modules['classifier'])\n",
"number_layers = len(layers)\n",
"\n",
"for idx_layer in range(1, number_layers)[::-1]:\n",
" if idx_layer <= 16:\n",
" setattr(layers[idx_layer], 'rule', GammaRule())\n",
" if 17 <= idx_layer <= 30:\n",
" setattr(layers[idx_layer], 'rule', EpsilonRule())\n",
" if idx_layer >= 31:\n",
" setattr(layers[idx_layer], 'rule', EpsilonRule(epsilon=0)) # LRP-0"
]
},
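{
"source": [
"As a quick sanity check (not part of the original tutorial), one can list the rule assigned to each layer; layers without an explicit rule fall back to Captum's default Epsilon-Rule, as noted above."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print each layer together with its assigned LRP rule, if any.\n",
"for idx, layer in enumerate(layers):\n",
"    rule = getattr(layer, 'rule', None)\n",
"    print(idx, type(layer).__name__, type(rule).__name__ if rule is not None else 'default')"
]
},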
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lrp = LRP(model)\n",
"attribution = lrp.attribute(X, target=483, verbose=True) # castle -> 483\n",
"attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()\n",
"\n",
"_ = viz.visualize_image_attr(\n",
" attribution,\n",
" img,\n",
" method='blended_heat_map',\n",
" sign='all',\n",
" show_colorbar=True,\n",
" title='Overlayed LRP',\n",
")"
]
},
{
"source": [
"With the verbose parameter, one can check the correct application of the rules in the generated output. As one can see in the generated output image, the heatmap shows clearly positive attributions for the silhouette of the castle. In contrast, the road traffic sign and the lantern are contributing negatively to the class 'castle'."
],
"cell_type": "markdown",
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.4 64-bit ('captum_dev': conda)"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
},
"interpreter": {
"hash": "02a67cea564318411a474a51be58cb0b6272c7dedcc002599b7f160613d5e69a"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added tutorials/img/lrp/castle.jpg