diff --git a/.gitignore b/.gitignore index 1c6a56b13..dc3be0ba6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ __pycache__ releasePassword apiPassword venv/ -env/ +./env/ .env .DS_Store bin/ diff --git a/CHANGELOG.md b/CHANGELOG.md index f0e85a675..d4ffabeac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] -- Upgrades `pip` and `setuptools` in CI publish job -- Also upgrades `poetry` and it's dependency - `clikit` ## [0.29.0] - 2025-03-03 +- Adds option to disable `tldextract` HTTP calls by setting `SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP=1` +- Upgrades `pip` and `setuptools` in CI publish job + - Also upgrades `poetry` and its dependency - `clikit` - Migrates unit tests to use a containerized core - Updates `Makefile` to use a Docker `compose` setup step - Migrates unit tests from CircleCI to Github Actions diff --git a/dev-requirements.txt b/dev-requirements.txt index bf7fd0e84..c4e1b09c9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,6 +9,7 @@ flask-cors==5.0.0 nest-asyncio==1.6.0 pdoc3==0.11.0 pre-commit==3.5.0 +pyfakefs==5.7.4 pylint==3.2.7 pyright==1.1.393 python-dotenv==1.0.1 diff --git a/setup.py b/setup.py index c6e335ea4..ed233051c 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ "PyJWT[crypto]>=2.5.0,<3.0.0", "httpx>=0.15.0,<1.0.0", "pycryptodome<3.21.0", - "tldextract<5.1.3", + "tldextract<6.0.0", "asgiref>=3.4.1,<4", "typing_extensions>=4.1.1,<5.0.0", "Deprecated<1.3.0", diff --git a/supertokens_python/env/__init__.py b/supertokens_python/env/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/supertokens_python/env/base.py b/supertokens_python/env/base.py new file mode 100644 index 000000000..e7a2a22c5 --- /dev/null +++ b/supertokens_python/env/base.py @@ -0,0 +1,12 @@ +from os import environ + +from supertokens_python.env.utils import str_to_bool + + +def FLAG_tldextract_disable_http(): 
+ """ + Disable HTTP calls from `tldextract`. + """ + val = environ.get("SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP", "0") + + return str_to_bool(val) diff --git a/supertokens_python/env/utils.py b/supertokens_python/env/utils.py new file mode 100644 index 000000000..74af424fd --- /dev/null +++ b/supertokens_python/env/utils.py @@ -0,0 +1,5 @@ +def str_to_bool(val: str) -> bool: + """ + Convert ENV values to boolean + """ + return val.lower() in ("true", "t", "1") diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index cffc0a9fb..aa002dd17 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -35,8 +35,9 @@ from urllib.parse import urlparse from httpx import HTTPStatusError, Response -from tldextract import extract # type: ignore +from tldextract import TLDExtract +from supertokens_python.env.base import FLAG_tldextract_disable_http from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework @@ -288,7 +289,16 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str: if hostname.startswith("localhost") or is_an_ip_address(hostname): return "localhost" - parsed_url: Any = extract(hostname, include_psl_private_domains=True) + extract = TLDExtract(fallback_to_snapshot=True, include_psl_private_domains=True) + # Explicitly disable HTTP calls, use snapshot bundled into library + if FLAG_tldextract_disable_http(): + extract = TLDExtract( + suffix_list_urls=(), # Ensures no HTTP calls + fallback_to_snapshot=True, + include_psl_private_domains=True, + ) + + parsed_url: Any = extract(hostname) if parsed_url.domain == "": # type: ignore # We need to do this because of https://github.com/supertokens/supertokens-python/issues/394 if hostname.endswith(".amazonaws.com") and parsed_url.suffix == hostname: diff --git a/tests/test_utils.py b/tests/test_utils.py index 
a3f0ccd43..e5321ccf3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,10 @@ +import os import threading +from contextlib import ExitStack from typing import Any, Dict, List, Union +from unittest.mock import patch -import pytest +from pytest import mark, param, raises from supertokens_python.utils import ( RWMutex, get_top_level_domain_for_same_site_resolution, @@ -9,10 +12,10 @@ is_version_gte, ) -from tests.utils import is_subset +from tests.utils import is_subset, outputs -@pytest.mark.parametrize( +@mark.parametrize( "version,min_minor_version,is_gte", [ ( @@ -72,7 +75,7 @@ def test_util_is_version_gte(version: str, min_minor_version: str, is_gte: bool) HOUR = 60 * MINUTE -@pytest.mark.parametrize( +@mark.parametrize( "ms,out", [ (1 * SECOND, "1 second"), @@ -91,7 +94,7 @@ def test_humanize_time(ms: int, out: str): assert humanize_time(ms) == out -@pytest.mark.parametrize( +@mark.parametrize( "d1,d2,result", [ ({"a": {"b": [1, 2]}, "c": 1}, {"c": 1}, True), @@ -176,7 +179,7 @@ def balance_is_valid(): assert actual_balance == expected_balance, "Incorrect account balance" -@pytest.mark.parametrize( +@mark.parametrize( "url,res", [ ("http://localhost:3001", "localhost"), @@ -196,3 +199,41 @@ def balance_is_valid(): ) def test_tld_for_same_site(url: str, res: str): assert get_top_level_domain_for_same_site_resolution(url) == res + + +@mark.parametrize( + ["internet_disabled", "env_val", "expectation"], + [ + param(True, "False", raises(RuntimeError), id="Internet disabled, flag unset"), + param(True, "True", outputs("google.com"), id="Internet disabled, flag set"), + param(False, "False", outputs("google.com"), id="Internet enabled, flag unset"), + param(False, "True", outputs("google.com"), id="Internet enabled, flag set"), + ], +) +def test_tldextract_http_toggle( + internet_disabled: bool, + env_val: str, + expectation: Any, + # pyfakefs fixture, mocks the filesystem + # Mocking `tldextract`'s cache path does not work in repeated tests + fs: 
Any, +): + import socket + + # Disable sockets, will raise errors on HTTP calls + socket_patch = patch.object(socket, "socket", side_effect=RuntimeError) + environ_patch = patch.dict( + os.environ, + {"SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP": env_val}, + ) + + stack = ExitStack() + stack.enter_context(environ_patch) + if internet_disabled: + stack.enter_context(socket_patch) + + # if `expectation` is raises, checks for raise + # if `outputs`, value used in `assert` statement + with stack, expectation as expected_output: + output = get_top_level_domain_for_same_site_resolution("https://google.com") + assert output == expected_output diff --git a/tests/utils.py b/tests/utils.py index 39499515a..890925589 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,7 @@ # Import AsyncMock import sys +from contextlib import contextmanager from datetime import datetime from functools import lru_cache from http.cookies import SimpleCookie @@ -487,3 +488,17 @@ async def create_users( await manually_create_or_update_user( "public", user["provider"], user["userId"], user["email"], True, None ) + + +@contextmanager +def outputs(val: Any): + """ + Outputs a value to assert. + + Usage: + @mark.parametrize(["input", "expectation"], [(1, outputs(1)), (0, raises(Exception))]) + def test_fn(input, expectation): + with expectation as expected_output: + assert 1 / input == expected_output + """ + yield val