diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index 946acf7..d7c53e8 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -45,4 +45,5 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "requests": ("https://docs.python-requests.org/en/master", None), } diff --git a/elastic_transport/_models.py b/elastic_transport/_models.py index d886219..df70157 100644 --- a/elastic_transport/_models.py +++ b/elastic_transport/_models.py @@ -232,7 +232,7 @@ class NodeConfig: connections_per_node: int = 10 #: Number of seconds to wait before a request should timeout. - request_timeout: Optional[int] = 10 + request_timeout: Optional[float] = 10.0 #: Set to ``True`` to enable HTTP compression #: of request and response bodies via gzip. @@ -278,10 +278,10 @@ class NodeConfig: #: issued when using ``verify_certs=False``. ssl_show_warn: bool = True - # Extras that can be set to anything, typically used - # for annotating this node with additional information for - # future decisions like sniffing, instance roles, etc. - # Third-party keys should start with an underscore and prefix. + #: Extras that can be set to anything, typically used + #: for annotating this node with additional information for + #: future decisions like sniffing, instance roles, etc. + #: Third-party keys should start with an underscore and prefix. _extras: Dict[str, Any] = field(default_factory=dict, hash=False) def replace(self, **kwargs: Any) -> "NodeConfig": diff --git a/elastic_transport/_node/_http_requests.py b/elastic_transport/_node/_http_requests.py index 6716f4b..89978f6 100644 --- a/elastic_transport/_node/_http_requests.py +++ b/elastic_transport/_node/_http_requests.py @@ -37,6 +37,7 @@ try: import requests from requests.adapters import HTTPAdapter + from requests.auth import AuthBase _REQUESTS_AVAILABLE = True _REQUESTS_META_VERSION = client_meta_version(requests.__version__) @@ -79,7 +80,12 @@ def init_poolmanager( class RequestsHttpNode(BaseNode): - """Synchronous node using the ``requests`` library communicating via HTTP""" + """Synchronous node using the ``requests`` library communicating via HTTP. + + Supports setting :attr:`requests.Session.auth` via the + :attr:`elastic_transport.NodeConfig._extras` + using the ``requests.session.auth`` key. + """ _CLIENT_META_HTTP_CLIENT = ("rq", _REQUESTS_META_VERSION) @@ -96,6 +102,16 @@ def __init__(self, config: NodeConfig): self.session.headers.clear() # Empty out all the default session headers self.session.verify = config.verify_certs + # Requests supports setting 'session.auth' via _extras['requests.session.auth'] = ... + try: + requests_session_auth: Optional[AuthBase] = config._extras.pop( + "requests.session.auth", None + ) + except AttributeError: + requests_session_auth = None + if requests_session_auth is not None: + self.session.auth = requests_session_auth + # Client certificates if config.client_cert: if config.client_key: diff --git a/tests/node/test_http_requests.py b/tests/node/test_http_requests.py index 9bfaaf6..2ff7bc0 100644 --- a/tests/node/test_http_requests.py +++ b/tests/node/test_http_requests.py @@ -22,13 +22,14 @@ import pytest import requests from mock import Mock, patch +from requests.auth import HTTPBasicAuth from elastic_transport import NodeConfig, RequestsHttpNode from elastic_transport._node._base import DEFAULT_USER_AGENT class TestRequestsHttpNode: - def _get_mode_node(self, node_config, response_body=b"{}"): + def _get_mock_node(self, node_config, response_body=b"{}"): node = RequestsHttpNode(node_config) def _dummy_send(*args, **kwargs): @@ -69,7 +70,7 @@ def test_ssl_context(self): assert adapter.poolmanager.connection_pool_kw["ssl_context"] is ctx def test_merge_headers(self): - node = self._get_mode_node( + node = self._get_mock_node( NodeConfig("http", "localhost", 80, headers={"h1": "v1", "h2": "v2"}) ) req = self._get_request(node, "GET", "/", headers={"h2": "v2p", "h3": "v3"}) @@ -78,7 +79,7 @@ def test_merge_headers(self): assert req.headers["h3"] == "v3" def test_default_headers(self): - node = self._get_mode_node(NodeConfig("http", "localhost", 80)) + node = self._get_mock_node(NodeConfig("http", "localhost", 80)) req = self._get_request(node, "GET", "/") assert req.headers == { "connection": "keep-alive", @@ -86,7 +87,7 @@ def test_default_headers(self): } def test_no_http_compression(self): - node = self._get_mode_node( + node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=False) ) assert not node.config.http_compress @@ -108,7 +109,7 @@ def test_no_http_compression(self): @pytest.mark.parametrize("empty_body", [None, b""]) def test_http_compression(self, empty_body): - node = self._get_mode_node( + node = self._get_mock_node( NodeConfig("http", "localhost", 80, http_compress=True) ) assert node.config.http_compress is True @@ -135,7 +136,7 @@ def test_http_compression(self, empty_body): @pytest.mark.parametrize("request_timeout", [None, 15]) def test_timeout_override_default(self, request_timeout): - node = self._get_mode_node( + node = self._get_mock_node( NodeConfig("http", "localhost", 80, request_timeout=request_timeout) ) assert node.config.request_timeout == request_timeout @@ -214,8 +215,23 @@ def test_ca_certs_is_used_as_session_verify(self): def test_surrogatepass_into_bytes(self): data = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" - con = self._get_mode_node( + node = self._get_mock_node( NodeConfig("http", "localhost", 80), response_body=data ) - _, data = con.perform_request("GET", "/") + _, data = node.perform_request("GET", "/") assert b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" == data + + @pytest.mark.parametrize("_extras", [None, {}, {"requests.session.auth": None}]) + def test_requests_no_session_auth(self, _extras): + node = self._get_mock_node(NodeConfig("http", "localhost", 80, _extras=_extras)) + assert node.session.auth is None + + def test_requests_custom_auth(self): + auth = HTTPBasicAuth("username", "password") + node = self._get_mock_node( + NodeConfig("http", "localhost", 80, _extras={"requests.session.auth": auth}) + ) + assert node.session.auth is auth + node.perform_request("GET", "/") + (request,), _ = node.session.send.call_args + assert request.headers["authorization"] == "Basic dXNlcm5hbWU6cGFzc3dvcmQ="