Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class FlowErrorTag(Enum):
NONE = "none"
MAGIC_FORMAT = "magic_format"
MAGIC_CODE_INCORRECT = "magic_code_incorrect"
DUPLICATE_EXCHANGE = "duplicate_exchange"
OTHER = "other"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from __future__ import annotations

import logging
import time

from pydantic import BaseModel
from datetime import datetime
from typing import Optional
from typing import Optional, Set

from microsoft_agents.activity import (
Activity,
Expand Down Expand Up @@ -104,6 +105,20 @@ def __init__(
self._max_attempts,
)

# Public registry of token exchange ids. This set is periodically cleared
# by a background asyncio task to avoid unbounded growth.
self.token_exchange_id_registry: Set[str] = set()

# Background task for periodically clearing the registry. The task is
# created lazily when an asyncio event loop is running.
self._clear_interval_seconds = kwargs.get(
"token_exchange_registry_clear_interval", 10
)
# Track the last time the registry was cleared (epoch seconds). The
# registry will be cleared lazily when an async entrypoint is hit and
# the interval has elapsed.
self._last_registry_clear: float = time.time()

@property
def flow_state(self) -> FlowState:
return self._flow_state.model_copy()
Expand Down Expand Up @@ -246,16 +261,26 @@ async def _continue_from_invoke_verify_state(

async def _continue_from_invoke_token_exchange(
self, activity: Activity
) -> TokenResponse:
) -> tuple[TokenResponse, FlowErrorTag]:
"""Handles the continuation of the flow from an invoke activity for token exchange."""
logger.info("Continuing OAuth flow with token exchange...")
token_exchange_request = activity.value
token_exchange_id = token_exchange_request.get("id")

if token_exchange_id in self.token_exchange_id_registry:
logger.warning(
"Token exchange request with id %s has already been processed",
token_exchange_id,
)
return None, FlowErrorTag.DUPLICATE_EXCHANGE
self.token_exchange_id_registry.add(token_exchange_id)
token_response = await self._user_token_client.user_token.exchange_token(
user_id=self._user_id,
connection_name=self._abs_oauth_connection_name,
channel_id=self._channel_id,
body=token_exchange_request,
)
return token_response
return token_response, FlowErrorTag.NONE

async def continue_flow(self, activity: Activity) -> FlowResponse:
"""Continues the OAuth flow based on the incoming activity.
Expand All @@ -269,7 +294,24 @@ async def continue_flow(self, activity: Activity) -> FlowResponse:
"""
logger.debug("Continuing auth flow...")

# Lazily clear the registry if the configured interval has elapsed.
self._maybe_clear_token_exchange_registry()

if not self._flow_state.is_active():
if (
activity.type == ActivityTypes.invoke
and activity.name == "signin/tokenExchange"
and activity.value.get("id") in self.token_exchange_id_registry
):
logger.debug(
"Token exchange request with id %s has already been processed",
activity.value.get("id"),
)
return FlowResponse(
flow_state=self._flow_state.model_copy(),
token_response=None,
flow_error_tag=FlowErrorTag.DUPLICATE_EXCHANGE,
)
logger.debug("OAuth flow is not active, cannot continue")
self._flow_state.tag = FlowStateTag.FAILURE
return FlowResponse(
Expand All @@ -288,14 +330,20 @@ async def continue_flow(self, activity: Activity) -> FlowResponse:
activity.type == ActivityTypes.invoke
and activity.name == "signin/tokenExchange"
):
token_response = await self._continue_from_invoke_token_exchange(activity)
(
token_response,
flow_error_tag,
) = await self._continue_from_invoke_token_exchange(activity)
else:
raise ValueError(f"Unknown activity type {activity.type}")

if not token_response and flow_error_tag == FlowErrorTag.NONE:
flow_error_tag = FlowErrorTag.OTHER

if flow_error_tag != FlowErrorTag.NONE:
if (
flow_error_tag != FlowErrorTag.NONE
and flow_error_tag != FlowErrorTag.DUPLICATE_EXCHANGE
):
logger.debug("Flow error occurred: %s", flow_error_tag)
self._flow_state.tag = FlowStateTag.CONTINUE
self._use_attempt()
Expand Down Expand Up @@ -340,3 +388,20 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse:

logger.debug("No active flow, beginning new flow...")
return await self.begin_flow(activity)

def _maybe_clear_token_exchange_registry(self) -> None:
"""Clear the `token_exchange_id_registry` if the configured interval
(seconds) has elapsed since the last clear. This uses the machine
epoch (time.time()) and performs lazy eviction when registry access
occurs instead of running a background task.
"""
now = time.time()

if now - self._last_registry_clear >= self._clear_interval_seconds:
if self.token_exchange_id_registry:
logger.debug(
"Clearing token_exchange_id_registry by epoch check (size=%d)",
len(self.token_exchange_id_registry),
)
self.token_exchange_id_registry.clear()
self._last_registry_clear = now
7 changes: 2 additions & 5 deletions tests/_common/storage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from abc import ABC
from typing import Any

from microsoft_agents.hosting.core.storage import (
Storage,
StoreItem,
MemoryStorage
)
from microsoft_agents.hosting.core.storage import Storage, StoreItem, MemoryStorage
from microsoft_agents.hosting.core.storage._type_aliases import JSON


class MockStoreItem(StoreItem):
"""Test implementation of StoreItem for testing purposes"""

Expand Down
81 changes: 79 additions & 2 deletions tests/hosting_core/test_oauth_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class TestOAuthFlowUtils:

def create_user_token_client(self, mocker, get_token_return=None):

user_token_client = mocker.Mock(spec=UserTokenClientBase)
Expand Down Expand Up @@ -104,7 +103,6 @@ def flow(self, sample_flow_state, user_token_client):


class TestOAuthFlow(TestOAuthFlowUtils):

def test_init_no_user_token_client(self, sample_flow_state):
with pytest.raises(ValueError):
OAuthFlow(sample_flow_state, None)
Expand Down Expand Up @@ -602,3 +600,82 @@ async def test_begin_or_continue_flow_completed_flow_state(self, mocker):
assert actual_response == expected_response
OAuthFlow.begin_flow.assert_not_called()
OAuthFlow.continue_flow.assert_not_called()

@pytest.mark.asyncio
async def test_token_exchange_dedupe_prevents_replay(
self, mocker, sample_active_flow_state, user_token_client
):
# setup
token_exchange_request = {"id": "exchange-1"}
user_token_client.user_token.exchange_token = mocker.AsyncMock(
return_value=TokenResponse(token=RES_TOKEN)
)
activity = self.create_activity(
mocker,
ActivityTypes.invoke,
name="signin/tokenExchange",
value=token_exchange_request,
)

flow = OAuthFlow(sample_active_flow_state, user_token_client)

# first request should be processed
response1 = await flow.continue_flow(activity)
user_token_client.user_token.exchange_token.assert_called_once_with(
user_id=sample_active_flow_state.user_id,
connection_name=sample_active_flow_state.connection,
channel_id=sample_active_flow_state.channel_id,
body=token_exchange_request,
)
assert response1.token_response == TokenResponse(token=RES_TOKEN)
# registry should contain the processed id
assert "exchange-1" in flow.token_exchange_id_registry

# second request with same id should be ignored (no additional call)
response2 = await flow.continue_flow(activity)
# still only called once
assert user_token_client.user_token.exchange_token.call_count == 1
assert response2.token_response == None
assert response2.flow_error_tag == FlowErrorTag.DUPLICATE_EXCHANGE

@pytest.mark.asyncio
async def test_token_exchange_registry_clears_after_interval(
self, mocker, sample_active_flow_state, user_token_client
):
# setup
token_exchange_request = {"id": "exchange-2"}
user_token_client.user_token.exchange_token = mocker.AsyncMock(
return_value=TokenResponse(token=RES_TOKEN)
)
activity = self.create_activity(
mocker,
ActivityTypes.invoke,
name="signin/tokenExchange",
value=token_exchange_request,
)

flow = OAuthFlow(sample_active_flow_state, user_token_client)

# first request should be processed
response1 = await flow.continue_flow(activity)
assert user_token_client.user_token.exchange_token.call_count == 1
assert response1.token_response == TokenResponse(token=RES_TOKEN)
# registry should contain the processed id
assert "exchange-2" in flow.token_exchange_id_registry

# simulate passage of time beyond the clear interval so the registry is cleared
import time as _time

flow._last_registry_clear = _time.time() - (flow._clear_interval_seconds + 100)

# explicitly invoke the lazy clear helper to simulate the moment when
# the registry would be cleared and assert it was removed.
flow._maybe_clear_token_exchange_registry()
flow._flow_state.tag = FlowStateTag.CONTINUE # keep it active
assert "exchange-2" not in flow.token_exchange_id_registry

# second request should now be processed again (registry was lazily cleared)

response2 = await flow.continue_flow(activity)
assert user_token_client.user_token.exchange_token.call_count == 2
assert response2.token_response == TokenResponse(token=RES_TOKEN)