Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class MyAgentFactory(AgentFactory):
return create_react_agent(
llm,
[magic_number_tool],
messages_modifier="""You are a helpful assistant.""",
prompt="""You are a helpful assistant.""",
)

def create_llm(self, dto: CreateLLMDto) -> Runnable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def invoke(self, messages: List[OpenAIChatMessage]) -> dict:

return self.invoke_adapter.to_chat_completion_object(result).dict()

async def ainvoke(self, messages: List[OpenAIChatMessage]) -> dict:
    """Async counterpart of ``invoke``.

    Converts OpenAI-style chat *messages* into the agent's input
    format, awaits the agent, and adapts the agent's result into an
    OpenAI chat-completion-shaped ``dict``.
    """
    # Renamed from ``input`` to avoid shadowing the builtin.
    agent_input = self.__to_input(messages)

    result = await self.agent.ainvoke(
        input=agent_input,
    )

    return self.invoke_adapter.to_chat_completion_object(result).dict()

def __to_input(self, messages: List[OpenAIChatMessage]):
if isinstance(self.agent, CompiledGraph):
return self.__to_react_agent_input(messages)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, ClassVar
from langchain_core.messages import BaseMessage
from langchain_anthropic import ChatAnthropic

Expand All @@ -9,7 +9,7 @@

class AnthropicOpenAICompatibleChatModel(ChatAnthropic):

adapter = AnthropicOpenAICompatibleChatModelAdapter()
adapter: ClassVar[AnthropicOpenAICompatibleChatModelAdapter] = AnthropicOpenAICompatibleChatModelAdapter()

def _stream(self, messages: List[List[BaseMessage]], **kwargs):
transformed_messages = self.adapter.to_openai_format_messages(messages)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import List
from typing import List, ClassVar
from langchain_core.messages import BaseMessage
from langchain_llamacpp_chat_model import LlamaChatModel
from langchain_community.chat_models.llamacpp import ChatLlamaCpp

from langchain_openai_api_bridge.chat_model_adapter.llamacpp.llamacpp_openai_compatible_chat_model_adapter import (
LlamacppOpenAICompatibleChatModelAdapter,
)


class LLamacppOpenAICompatibleChatModel(LlamaChatModel):
class LLamacppOpenAICompatibleChatModel(ChatLlamaCpp):

adapter = LlamacppOpenAICompatibleChatModelAdapter()
adapter: ClassVar[LlamacppOpenAICompatibleChatModelAdapter] = LlamacppOpenAICompatibleChatModelAdapter()

def _stream(self, messages: List[List[BaseMessage]], **kwargs):
transformed_messages = self.adapter.to_openai_format_messages(messages)
Expand Down
2 changes: 1 addition & 1 deletion langchain_openai_api_bridge/core/base_agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
class BaseAgentFactory(ABC):
    """Abstract factory that produces a runnable agent for a request.

    Implementations receive a ``CreateAgentDto`` describing the
    requested model/thread configuration and return the agent to run.
    The hook is async so implementations may await setup work.
    """

    @abstractmethod
    async def acreate_agent(self, dto: CreateAgentDto) -> Runnable:
        # NOTE: the pasted diff left the old sync ``create_agent``
        # signature interleaved here; only the async form is kept.
        pass
16 changes: 12 additions & 4 deletions langchain_openai_api_bridge/core/function_agent_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable
import inspect
from typing import Callable, Union, Awaitable
from langchain_core.runnables import Runnable
from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
Expand All @@ -8,9 +9,16 @@ class FunctionAgentFactory(BaseAgentFactory):

def __init__(
    self,
    fn: Union[
        Callable[[CreateAgentDto], Runnable],
        Callable[[CreateAgentDto], Awaitable[Runnable]],
    ],
) -> None:
    """Wrap *fn*, a sync or async factory of agents.

    Whether *fn* is a coroutine function is detected once here, so
    ``acreate_agent`` knows whether its result must be awaited.
    """
    self.fn = fn
    self.is_async = inspect.iscoroutinefunction(fn)

async def acreate_agent(self, dto: CreateAgentDto) -> Runnable:
    """Return the agent built by the wrapped factory for *dto*.

    Awaits the wrapped callable only when it was detected as a
    coroutine function in ``__init__``.
    """
    if self.is_async:
        return await self.fn(dto)
    else:
        return self.fn(dto)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Awaitable
from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
from langchain_openai_api_bridge.core.function_agent_factory import FunctionAgentFactory
Expand All @@ -12,7 +12,9 @@ def __init__(
self,
agent_factory_provider: Union[
Callable[[], BaseAgentFactory],
Callable[[], Awaitable[BaseAgentFactory]],
Callable[[CreateAgentDto], Runnable],
Callable[[CreateAgentDto], Awaitable[Runnable]],
BaseAgentFactory,
],
tiny_di_container: Optional[TinyDIContainer] = None,
Expand All @@ -30,7 +32,9 @@ def __init__(
def __is_callable_runnable_provider(
agent_factory_provider: Union[
Callable[[], BaseAgentFactory],
Callable[[], Awaitable[BaseAgentFactory]],
Callable[[CreateAgentDto], Runnable],
Callable[[CreateAgentDto], Awaitable[Runnable]],
BaseAgentFactory,
],
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def assistant_create_thread_runs(

api_key = get_bearer_token(authorization)
agent_factory = tiny_di_container.resolve(InternalAgentFactory)
agent = agent_factory.create_agent(
agent = await agent_factory.acreate_agent(
thread_run_dto=thread_run_dto, api_key=api_key
)

Expand Down
5 changes: 3 additions & 2 deletions langchain_openai_api_bridge/fastapi/chat_completion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def assistant_retreive_thread_messages(
temperature=request.temperature,
)

agent = agent_factory.create_agent(dto=create_agent_dto)
agent = await agent_factory.acreate_agent(dto=create_agent_dto)

adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, event_adapter=event_adapter)

Expand All @@ -41,7 +41,8 @@ async def assistant_retreive_thread_messages(
stream = adapter.astream(request.messages)
return response_factory.to_streaming_response(stream)
else:
return JSONResponse(content=adapter.invoke(request.messages))
content = await adapter.ainvoke(request.messages)
return JSONResponse(content=content)

return chat_completion_router

Expand Down
4 changes: 2 additions & 2 deletions langchain_openai_api_bridge/fastapi/internal_agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ class InternalAgentFactory:
def __init__(self, agent_factory: BaseAgentFactory) -> None:
    # Store the injected factory; agent creation is delegated to it
    # in acreate_agent.
    self.agent_factory = agent_factory

async def acreate_agent(self, thread_run_dto: ThreadRunsDto, api_key: str) -> Runnable:
    """Build a ``CreateAgentDto`` from a thread run and delegate.

    Copies the run's model/thread/temperature/assistant fields plus
    the caller's *api_key* into the DTO, then awaits the wrapped
    factory's ``acreate_agent``.
    """
    create_agent_dto = CreateAgentDto(
        model=thread_run_dto.model,
        thread_id=thread_run_dto.thread_id,
        api_key=api_key,
        temperature=thread_run_dto.temperature,
        assistant_id=thread_run_dto.assistant_id,
    )
    # The pasted diff also left the old sync ``return`` line here;
    # only the awaited call is kept.
    return await self.agent_factory.acreate_agent(dto=create_agent_dto)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Awaitable
from fastapi import FastAPI

from langchain_openai_api_bridge.assistant.adapter.container import (
Expand Down Expand Up @@ -45,7 +45,9 @@ def __init__(
app: FastAPI,
agent_factory_provider: Union[
Callable[[], BaseAgentFactory],
Callable[[], Awaitable[BaseAgentFactory]],
Callable[[CreateAgentDto], Runnable],
Callable[[CreateAgentDto], Awaitable[Runnable]],
BaseAgentFactory,
],
) -> None:
Expand Down
Loading