Skip to content

Commit c211032

Browse files
ProGamerGovfacebook-github-bot
authored andcommitted
Fix version check bug (#940)
Summary: By default: `"1.8.0" > "1.10.0"` will be equal to True, despite 1.10 being a later version that 1.8.0. This PR fixes this issue. Pull Request resolved: #940 Reviewed By: NarineK Differential Revision: D36336547 Pulled By: vivekmig fbshipit-source-id: 84f277eb1e6897a8378ce9eb8c9eab3285ad8494
1 parent a702728 commit c211032

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/utils/test_sample_gradient.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
from captum._utils.sample_gradient import SampleGradientWrapper, SUPPORTED_MODULES
8+
from packaging import version
89
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
910
from tests.helpers.basic_models import (
1011
BasicModel_ConvNet_One_Conv,
@@ -37,7 +38,7 @@ def test_sample_grads_conv_mean_multi_inp(self) -> None:
3738
self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x))
3839

3940
def test_sample_grads_modified_conv_mean(self) -> None:
40-
if torch.__version__ < "1.8":
41+
if version.parse(torch.__version__) < version.parse("1.8.0"):
4142
raise unittest.SkipTest(
4243
"Skipping sample gradient test with 3D linear module"
4344
"since torch version < 1.8"
@@ -50,7 +51,7 @@ def test_sample_grads_modified_conv_mean(self) -> None:
5051
)
5152

5253
def test_sample_grads_modified_conv_sum(self) -> None:
53-
if torch.__version__ < "1.8":
54+
if version.parse(torch.__version__) < version.parse("1.8.0"):
5455
raise unittest.SkipTest(
5556
"Skipping sample gradient test with 3D linear module"
5657
"since torch version < 1.8"

0 commit comments

Comments
 (0)