diff --git a/tests/attr/layer/test_grad_cam.py b/tests/attr/layer/test_grad_cam.py index c089921bc6..66fc1a62a7 100644 --- a/tests/attr/layer/test_grad_cam.py +++ b/tests/attr/layer/test_grad_cam.py @@ -6,7 +6,8 @@ import torch from captum._utils.typing import TensorLikeList from captum.attr._core.layer.grad_cam import LayerGradCam -from tests.helpers.basic import assertTensorTuplesAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorTuplesAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, diff --git a/tests/attr/layer/test_layer_lrp.py b/tests/attr/layer/test_layer_lrp.py index b76e186250..f6ad16e96d 100644 --- a/tests/attr/layer/test_layer_lrp.py +++ b/tests/attr/layer/test_layer_lrp.py @@ -8,7 +8,8 @@ from captum.attr import LayerLRP from captum.attr._utils.lrp_rules import Alpha1_Beta0_Rule, EpsilonRule, GammaRule -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel from torch import Tensor diff --git a/tests/attr/neuron/test_neuron_ablation.py b/tests/attr/neuron/test_neuron_ablation.py index 6556e95702..f1ec506845 100644 --- a/tests/attr/neuron/test_neuron_ablation.py +++ b/tests/attr/neuron/test_neuron_ablation.py @@ -10,7 +10,8 @@ TensorOrTupleOfTensorsGeneric, ) from captum.attr._core.neuron.neuron_feature_ablation import NeuronFeatureAblation -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, diff --git a/tests/attr/neuron/test_neuron_conductance.py b/tests/attr/neuron/test_neuron_conductance.py index 64400d7591..ed63d9b906 100644 --- a/tests/attr/neuron/test_neuron_conductance.py +++ b/tests/attr/neuron/test_neuron_conductance.py @@ -7,7 +7,8 @@ from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric from captum.attr._core.layer.layer_conductance import LayerConductance from captum.attr._core.neuron.neuron_conductance import NeuronConductance -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_MultiLayer, diff --git a/tests/attr/neuron/test_neuron_deeplift.py b/tests/attr/neuron/test_neuron_deeplift.py index 8d1435847c..bd1b56c58b 100644 --- a/tests/attr/neuron/test_neuron_deeplift.py +++ b/tests/attr/neuron/test_neuron_deeplift.py @@ -11,7 +11,8 @@ _create_inps_and_base_for_deeplift_neuron_layer_testing, _create_inps_and_base_for_deepliftshap_neuron_layer_testing, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet, BasicModel_ConvNet_MaxPool3d, diff --git a/tests/attr/neuron/test_neuron_gradient_shap.py b/tests/attr/neuron/test_neuron_gradient_shap.py index f5d2920a0b..c44cc4a51f 100644 --- a/tests/attr/neuron/test_neuron_gradient_shap.py +++ b/tests/attr/neuron/test_neuron_gradient_shap.py @@ -6,7 +6,8 @@ from captum.attr._core.neuron.neuron_integrated_gradients import ( NeuronIntegratedGradients, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_MultiLayer from tests.helpers.classification_models import SoftmaxModel from torch import Tensor diff --git a/tests/attr/test_baselines.py b/tests/attr/test_baselines.py index 0bdb240da1..19f31a24a0 100644 --- a/tests/attr/test_baselines.py +++ b/tests/attr/test_baselines.py @@ -3,7 +3,7 @@ from captum.attr._utils.baselines import ProductBaselines # from parameterized import parameterized -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class TestProductBaselines(BaseTest): diff --git a/tests/attr/test_class_summarizer.py b/tests/attr/test_class_summarizer.py index 78403ece11..4c771af330 100644 --- a/tests/attr/test_class_summarizer.py +++ b/tests/attr/test_class_summarizer.py @@ -3,7 +3,7 @@ import torch from captum.attr import ClassSummarizer, CommonStats -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_common.py b/tests/attr/test_common.py index c2c987e4c1..7b0f1308c1 100644 --- a/tests/attr/test_common.py +++ b/tests/attr/test_common.py @@ -3,7 +3,7 @@ import torch from captum.attr._core.noise_tunnel import SUPPORTED_NOISE_TUNNEL_TYPES from captum.attr._utils.common import _validate_input, _validate_noise_tunnel_type -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_deconvolution.py b/tests/attr/test_deconvolution.py index 9b54b7b9d4..73aa01abdc 100644 --- a/tests/attr/test_deconvolution.py +++ b/tests/attr/test_deconvolution.py @@ -11,7 +11,8 @@ from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( NeuronDeconvolution, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch.nn import Module diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index 91ff63d259..80f8efda42 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -10,7 +10,8 @@ from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.attribution import Attribution -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel, BasicModel_ConvNet_One_Conv, diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index 4a1a2fc144..e432cd1a2c 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -3,7 +3,8 @@ import torch from captum.attr._core.feature_permutation import _permute_feature, FeaturePermutation -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModelWithSparseInputs from torch import Tensor diff --git a/tests/attr/test_gradient_shap.py b/tests/attr/test_gradient_shap.py index 56b0c544b7..7ffd92c1ae 100644 --- a/tests/attr/test_gradient_shap.py +++ b/tests/attr/test_gradient_shap.py @@ -8,7 +8,8 @@ from captum.attr._core.gradient_shap import GradientShap from captum.attr._core.integrated_gradients import IntegratedGradients from numpy import ndarray -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicLinearModel, BasicModel2 from tests.helpers.classification_models import SoftmaxModel diff --git a/tests/attr/test_guided_backprop.py b/tests/attr/test_guided_backprop.py index 01380a8c93..dc3b1b9b5a 100644 --- a/tests/attr/test_guided_backprop.py +++ b/tests/attr/test_guided_backprop.py @@ -9,7 +9,8 @@ from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( NeuronGuidedBackprop, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch.nn import Module diff --git a/tests/attr/test_guided_grad_cam.py b/tests/attr/test_guided_grad_cam.py index 8b33e583b6..c74fab0b9c 100644 --- a/tests/attr/test_guided_grad_cam.py +++ b/tests/attr/test_guided_grad_cam.py @@ -6,7 +6,8 @@ import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._core.guided_grad_cam import GuidedGradCam -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_hook_removal.py b/tests/attr/test_hook_removal.py index ce0d0b3316..29e10a5bb1 100644 --- a/tests/attr/test_hook_removal.py +++ b/tests/attr/test_hook_removal.py @@ -14,7 +14,8 @@ should_create_generated_test, ) from tests.attr.helpers.test_config import config -from tests.helpers.basic import BaseTest, deep_copy_args +from tests.helpers import BaseTest +from tests.helpers.basic import deep_copy_args from torch.nn import Module """ diff --git a/tests/attr/test_input_x_gradient.py b/tests/attr/test_input_x_gradient.py index 631424cfd1..056628b5db 100644 --- a/tests/attr/test_input_x_gradient.py +++ b/tests/attr/test_input_x_gradient.py @@ -6,7 +6,8 @@ from captum.attr._core.input_x_gradient import InputXGradient from captum.attr._core.noise_tunnel import NoiseTunnel from tests.attr.test_saliency import _get_basic_config, _get_multiargs_basic_config -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.classification_models import SoftmaxModel from torch import Tensor from torch.nn import Module diff --git a/tests/attr/test_integrated_gradients_basic.py b/tests/attr/test_integrated_gradients_basic.py index 00e2c0fa5e..4b0ce2aa81 100644 --- a/tests/attr/test_integrated_gradients_basic.py +++ b/tests/attr/test_integrated_gradients_basic.py @@ -9,7 +9,8 @@ from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._utils.common import _tensorize_baseline -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel, BasicModel2, diff --git a/tests/attr/test_integrated_gradients_classification.py b/tests/attr/test_integrated_gradients_classification.py index 8fdd7401d2..5ebac6ce6e 100644 --- a/tests/attr/test_integrated_gradients_classification.py +++ b/tests/attr/test_integrated_gradients_classification.py @@ -6,7 +6,8 @@ from captum._utils.typing import BaselineType, Tensor from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._core.noise_tunnel import NoiseTunnel -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.classification_models import SigmoidModel, SoftmaxModel from torch.nn import Module diff --git a/tests/attr/test_interpretable_input.py b/tests/attr/test_interpretable_input.py index 8dc256cb36..78271d0e2d 100644 --- a/tests/attr/test_interpretable_input.py +++ b/tests/attr/test_interpretable_input.py @@ -3,7 +3,8 @@ import torch from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index fdc911002e..73dece498d 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -13,7 +13,8 @@ from captum.attr._core.shapley_value import ShapleyValueSampling from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput from parameterized import parameterized, parameterized_class -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch import nn, Tensor diff --git a/tests/attr/test_lrp.py b/tests/attr/test_lrp.py index bec5bada89..75b4efbe23 100644 --- a/tests/attr/test_lrp.py +++ b/tests/attr/test_lrp.py @@ -10,7 +10,8 @@ GammaRule, IdentityRule, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer, diff --git a/tests/attr/test_occlusion.py b/tests/attr/test_occlusion.py index fd3071bccf..5f45154ec3 100644 --- a/tests/attr/test_occlusion.py +++ b/tests/attr/test_occlusion.py @@ -12,7 +12,8 @@ TensorOrTupleOfTensorsGeneric, ) from captum.attr._core.occlusion import Occlusion -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel3, BasicModel_ConvNet_One_Conv, diff --git a/tests/attr/test_stat.py b/tests/attr/test_stat.py index 30c5e336b4..0d2d793307 100644 --- a/tests/attr/test_stat.py +++ b/tests/attr/test_stat.py @@ -4,7 +4,8 @@ import torch from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual def get_values(n: int = 100, lo=None, hi=None, integers: bool = False): diff --git a/tests/attr/test_summarizer.py b/tests/attr/test_summarizer.py index 67dc2e53e5..db440ddd80 100644 --- a/tests/attr/test_summarizer.py +++ b/tests/attr/test_summarizer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import torch from captum.attr import CommonStats, Summarizer -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class Test(BaseTest): diff --git a/tests/attr/test_utils_batching.py b/tests/attr/test_utils_batching.py index 30c99e1d8d..61d767331b 100644 --- a/tests/attr/test_utils_batching.py +++ b/tests/attr/test_utils_batching.py @@ -6,7 +6,8 @@ _batched_operator, _tuple_splice_range, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual class Test(BaseTest): diff --git a/tests/concept/test_concept.py b/tests/concept/test_concept.py index 2efb336a5a..01074bd261 100644 --- a/tests/concept/test_concept.py +++ b/tests/concept/test_concept.py @@ -5,7 +5,7 @@ import torch from captum.concept._core.concept import Concept from captum.concept._utils.data_iterator import dataset_to_dataloader -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest from torch.utils.data import IterableDataset diff --git a/tests/concept/test_tcav.py b/tests/concept/test_tcav.py index 247f365a5c..cf2bf11bbf 100644 --- a/tests/concept/test_tcav.py +++ b/tests/concept/test_tcav.py @@ -26,7 +26,8 @@ from captum.concept._utils.classifier import Classifier from captum.concept._utils.common import concepts_to_str from captum.concept._utils.data_iterator import dataset_to_dataloader -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel_ConvNet from torch import Tensor from torch.utils.data import DataLoader, IterableDataset diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index e69de29bb2..746a30571f 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 + +try: + from tests.helpers.fb.internal_base import FbBaseTest as BaseTest + + __all__ = [ + "BaseTest", + ] + +except ImportError: + from tests.helpers.basic import BaseTest diff --git a/tests/helpers/basic.py b/tests/helpers/basic.py index 047036fdbf..06a0b7ec51 100644 --- a/tests/helpers/basic.py +++ b/tests/helpers/basic.py @@ -2,6 +2,7 @@ import copy import random import unittest + from typing import Callable import numpy as np diff --git a/tests/influence/_core/test_arnoldi_influence.py b/tests/influence/_core/test_arnoldi_influence.py index e6b4358f7d..875c604baf 100644 --- a/tests/influence/_core/test_arnoldi_influence.py +++ b/tests/influence/_core/test_arnoldi_influence.py @@ -17,7 +17,8 @@ _unflatten_params_factory, ) from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_dataloader.py b/tests/influence/_core/test_dataloader.py index 564ccfa4df..aecf1ae885 100644 --- a/tests/influence/_core/test_dataloader.py +++ b/tests/influence/_core/test_dataloader.py @@ -8,7 +8,8 @@ TracInCPFastRandProj, ) from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_naive_influence.py b/tests/influence/_core/test_naive_influence.py index 0b3265657b..b48a1ffaad 100644 --- a/tests/influence/_core/test_naive_influence.py +++ b/tests/influence/_core/test_naive_influence.py @@ -13,11 +13,8 @@ _unflatten_params_factory, ) from parameterized import parameterized -from tests.helpers.basic import ( - assertTensorAlmostEqual, - assertTensorTuplesAlmostEqual, - BaseTest, -) +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_similarity_influence.py b/tests/influence/_core/test_similarity_influence.py index 395762a5b2..de5c0c8b80 100644 --- a/tests/influence/_core/test_similarity_influence.py +++ b/tests/influence/_core/test_similarity_influence.py @@ -8,7 +8,8 @@ euclidean_distance, SimilarityInfluence, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import Dataset diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index e779c8c2e1..8370a8ca20 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -12,7 +12,8 @@ TracInCPFastRandProj, ) from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_tracin_k_most_influential.py b/tests/influence/_core/test_tracin_k_most_influential.py index 1044743b94..8d4b38c368 100644 --- a/tests/influence/_core/test_tracin_k_most_influential.py +++ b/tests/influence/_core/test_tracin_k_most_influential.py @@ -6,7 +6,8 @@ from captum.influence._core.tracincp import TracInCP from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 4665b5ad1a..e6de3a0638 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -12,7 +12,8 @@ TracInCPFastRandProj, ) from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _isSorted, _wrap_model_in_dataparallel, diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index c6a5844a72..7af3a3d61d 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -8,7 +8,8 @@ from captum.influence._core.tracincp import TracInCP, TracInCPBase from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _format_batch_into_tuple, build_test_name_func, diff --git a/tests/influence/_core/test_tracin_show_progress.py b/tests/influence/_core/test_tracin_show_progress.py index cd0e8a46d7..e5d31903e0 100644 --- a/tests/influence/_core/test_tracin_show_progress.py +++ b/tests/influence/_core/test_tracin_show_progress.py @@ -7,7 +7,7 @@ from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast from parameterized import parameterized -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest from tests.influence._utils.common import ( build_test_name_func, DataInfluenceConstructor, diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index 888a47142a..f24e56d7e1 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -6,7 +6,7 @@ from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast from parameterized import parameterized -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest from tests.influence._utils.common import ( build_test_name_func, DataInfluenceConstructor, diff --git a/tests/influence/_core/test_tracin_xor.py b/tests/influence/_core/test_tracin_xor.py index 5c5a0bb760..9f583245cb 100644 --- a/tests/influence/_core/test_tracin_xor.py +++ b/tests/influence/_core/test_tracin_xor.py @@ -8,7 +8,8 @@ import torch.nn.functional as F from captum.influence._core.tracincp import TracInCP from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.influence._utils.common import ( _wrap_model_in_dataparallel, BasicLinearNet, diff --git a/tests/insights/test_contribution.py b/tests/insights/test_contribution.py index b6928f187e..cf8f2b8aff 100644 --- a/tests/insights/test_contribution.py +++ b/tests/insights/test_contribution.py @@ -8,7 +8,7 @@ from captum.insights import AttributionVisualizer, Batch from captum.insights.attr_vis.app import FilterConfig from captum.insights.attr_vis.features import BaseFeature, FeatureOutput, ImageFeature -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class RealFeature(BaseFeature): diff --git a/tests/insights/test_features.py b/tests/insights/test_features.py index 2f2e07cc06..917249189f 100644 --- a/tests/insights/test_features.py +++ b/tests/insights/test_features.py @@ -10,7 +10,7 @@ TextFeature, ) from matplotlib.figure import Figure -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class TestTextFeature(BaseTest): diff --git a/tests/metrics/test_infidelity.py b/tests/metrics/test_infidelity.py index ba131d869f..3a0da03553 100644 --- a/tests/metrics/test_infidelity.py +++ b/tests/metrics/test_infidelity.py @@ -12,7 +12,8 @@ Saliency, ) from captum.metrics import infidelity, infidelity_perturb_func_decorator -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel2, BasicModel4_MultiArgs, diff --git a/tests/metrics/test_sensitivity.py b/tests/metrics/test_sensitivity.py index 6d67a25c3b..2152e0d7ef 100644 --- a/tests/metrics/test_sensitivity.py +++ b/tests/metrics/test_sensitivity.py @@ -13,7 +13,8 @@ ) from captum.metrics import sensitivity_max from captum.metrics._core.sensitivity import default_perturb_func -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel2, BasicModel4_MultiArgs, diff --git a/tests/module/test_binary_concrete_stochastic_gates.py b/tests/module/test_binary_concrete_stochastic_gates.py index 2c5e605ed8..f3493becc2 100644 --- a/tests/module/test_binary_concrete_stochastic_gates.py +++ b/tests/module/test_binary_concrete_stochastic_gates.py @@ -5,7 +5,8 @@ import torch from captum.module.binary_concrete_stochastic_gates import BinaryConcreteStochasticGates from parameterized import parameterized_class -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual @parameterized_class( diff --git a/tests/module/test_gaussian_stochastic_gates.py b/tests/module/test_gaussian_stochastic_gates.py index 03df56c51f..580b1e7dd6 100644 --- a/tests/module/test_gaussian_stochastic_gates.py +++ b/tests/module/test_gaussian_stochastic_gates.py @@ -6,7 +6,8 @@ import torch from captum.module.gaussian_stochastic_gates import GaussianStochasticGates from parameterized import parameterized_class -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual @parameterized_class( diff --git a/tests/robust/test_FGSM.py b/tests/robust/test_FGSM.py index 19dffdacf1..4db5619de3 100644 --- a/tests/robust/test_FGSM.py +++ b/tests/robust/test_FGSM.py @@ -4,7 +4,8 @@ import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.robust import FGSM -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel, BasicModel2, BasicModel_MultiLayer from torch import Tensor from torch.nn import CrossEntropyLoss diff --git a/tests/robust/test_PGD.py b/tests/robust/test_PGD.py index 7e39ca99d9..f22c167907 100644 --- a/tests/robust/test_PGD.py +++ b/tests/robust/test_PGD.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 import torch from captum.robust import PGD -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel, BasicModel2, BasicModel_MultiLayer from torch.nn import CrossEntropyLoss diff --git a/tests/robust/test_attack_comparator.py b/tests/robust/test_attack_comparator.py index 7585ad8f9c..b2b2d3701a 100644 --- a/tests/robust/test_attack_comparator.py +++ b/tests/robust/test_attack_comparator.py @@ -4,7 +4,8 @@ import torch from captum.robust import AttackComparator, FGSM -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel, BasicModel_MultiLayer from torch import Tensor diff --git a/tests/robust/test_min_param_perturbation.py b/tests/robust/test_min_param_perturbation.py index beae331920..8c8a8893ec 100644 --- a/tests/robust/test_min_param_perturbation.py +++ b/tests/robust/test_min_param_perturbation.py @@ -3,7 +3,8 @@ import torch from captum.robust import MinParamPerturbation -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicModel, BasicModel_MultiLayer from torch import Tensor diff --git a/tests/utils/test_av.py b/tests/utils/test_av.py index 301f04ecb9..855831bc35 100644 --- a/tests/utils/test_av.py +++ b/tests/utils/test_av.py @@ -5,7 +5,8 @@ import torch from captum._utils.av import AV -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicLinearReLULinear from torch.utils.data import DataLoader, Dataset diff --git a/tests/utils/test_gradient.py b/tests/utils/test_gradient.py index 2776708b26..57cfe2c554 100644 --- a/tests/utils/test_gradient.py +++ b/tests/utils/test_gradient.py @@ -9,7 +9,8 @@ compute_layer_gradients_and_eval, undo_gradient_requirements, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel, BasicModel2, diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 46af61b58a..3f38a4f45f 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 import torch -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual class HelpersTest(BaseTest): diff --git a/tests/utils/test_jacobian.py b/tests/utils/test_jacobian.py index 9537c11b72..05e1a77cac 100644 --- a/tests/utils/test_jacobian.py +++ b/tests/utils/test_jacobian.py @@ -6,7 +6,8 @@ _compute_jacobian_wrt_params, _compute_jacobian_wrt_params_with_sample_wise_trick, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import BasicLinearModel2, BasicLinearModel_Multilayer diff --git a/tests/utils/test_linear_model.py b/tests/utils/test_linear_model.py index e937057690..cc1bd23029 100644 --- a/tests/utils/test_linear_model.py +++ b/tests/utils/test_linear_model.py @@ -8,7 +8,8 @@ SGDLinearRegression, SGDRidge, ) -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor diff --git a/tests/utils/test_progress.py b/tests/utils/test_progress.py index c2fc55c7c7..32446d186a 100644 --- a/tests/utils/test_progress.py +++ b/tests/utils/test_progress.py @@ -5,7 +5,7 @@ import unittest.mock from captum._utils.progress import NullProgress, progress -from tests.helpers.basic import BaseTest +from tests.helpers import BaseTest class Test(BaseTest): diff --git a/tests/utils/test_sample_gradient.py b/tests/utils/test_sample_gradient.py index 2e4bdbf379..9ecdcb445a 100644 --- a/tests/utils/test_sample_gradient.py +++ b/tests/utils/test_sample_gradient.py @@ -11,7 +11,8 @@ SUPPORTED_MODULES, ) from packaging import version -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from tests.helpers.basic_models import ( BasicModel_ConvNet_One_Conv, BasicModel_ConvNetWithPaddingDilation,