README.md (2 changes: 1 addition & 1 deletion)

@@ -139,7 +139,7 @@ class MyAgentFactory(AgentFactory):
         return create_react_agent(
             llm,
             [magic_number_tool],
-            messages_modifier="""You are a helpful assistant.""",
+            prompt="""You are a helpful assistant.""",
         )

     def create_llm(self, dto: CreateLLMDto) -> Runnable:
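Note: newer langgraph releases removed messages_modifier from create_react_agent in favor of prompt, which is what this README change tracks. A minimal sketch of the updated call, assuming an llm instance and the README's magic_number_tool are in scope:

    from langgraph.prebuilt import create_react_agent

    # `prompt` replaces the removed `messages_modifier` keyword.
    agent = create_react_agent(
        llm,
        [magic_number_tool],
        prompt="You are a helpful assistant.",
    )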
@@ -18,7 +18,7 @@ def handle(
        step_id = event["run_id"]
        name = event["name"]
        arguments = event["data"]["input"]
-        metadata = event.get("metadata", None)
+        metadata = event.get("metadata", {})
        tool_created_event = create_langchain_tool_run_step_tools_created(
            step_id=step_id,
            assistant_id=dto.assistant_id,
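Note: the new {} default is load-bearing. The event factory this handler calls now iterates metadata.items() to stringify values, which would raise AttributeError on None. A short illustration:

    metadata = {}  # event.get("metadata", {}), the new default
    coerced = {key: str(value) for key, value in metadata.items()}  # safe: yields {}
    # With the old default of None, metadata.items() would raise AttributeError.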
@@ -1,6 +1,6 @@
 import json
 import time
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 from openai.types.beta.assistant_stream_event import (
     ThreadMessageCreated,
     ThreadMessageDelta,
@@ -26,6 +26,7 @@
     FromLanggraphMessageChunkContent,
     create_text_message_delta,
 )
+from pydantic import BaseModel


 def create_thread_message_created_event(message: Message) -> ThreadMessageCreated:
@@ -63,12 +64,12 @@ def create_langchain_tool_run_step_tools_created(
     assistant_id: str,
     thread_id: str,
     status: Literal["in_progress", "cancelled", "failed", "completed", "expired"],
-    metadata: Optional[object] = None,
+    metadata: Optional[dict[str, Any]] = None,
     name: Optional[str] = None,
     arguments: Optional[dict[object]] = None,
     output: Optional[str] = None,
 ) -> ThreadRunStepDelta:
-
+    metadata = {key: str(value) for key, value in metadata.items()}
     return ThreadRunStepCreated(
         event="thread.run.step.created",
         data=create_langchain_tool_run_step(
@@ -89,11 +90,12 @@ def create_langchain_tool_thread_run_step_completed(
     assistant_id: str,
     thread_id: str,
     status: Literal["in_progress", "cancelled", "failed", "completed", "expired"],
-    metadata: Optional[object] = None,
+    metadata: Optional[dict[str, Any]] = None,
     name: Optional[str] = None,
     arguments: Optional[Union[dict[object], float, str]] = None,
     output: Optional[Union[dict[object], float, str]] = None,
 ) -> ThreadRunStepCompleted:
+    metadata = {key: str(value) for key, value in metadata.items()}
     return ThreadRunStepCompleted(
         event="thread.run.step.completed",
         data=create_langchain_tool_run_step(
@@ -114,11 +116,12 @@ def create_langchain_tool_run_step(
     assistant_id: str,
     thread_id: str,
     status: Literal["in_progress", "cancelled", "failed", "completed", "expired"],
-    metadata: Optional[object] = None,
+    metadata: Optional[dict[str, str] | BaseModel] = None,
     name: Optional[str] = None,
     arguments: Optional[Union[dict[object], float, str]] = None,
     output: Optional[Union[dict[object], float, str]] = None,
 ) -> RunStep:
+    metadata = metadata.model_dump() if isinstance(metadata, BaseModel) else metadata
     return RunStep(
         id=step_id,
         assistant_id=assistant_id,
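Note: the factories above share one normalization idea: metadata may arrive as a plain dict or a pydantic model, and the OpenAI run-step types expect string values. A standalone sketch of that pattern; the "or {}" guard against None is an addition here for illustration, since the functions above assume the caller passes a dict:

    from typing import Any, Optional, Union
    from pydantic import BaseModel

    def normalize_metadata(
        metadata: Optional[Union[dict[str, Any], BaseModel]],
    ) -> dict[str, str]:
        # Unwrap pydantic models into plain dicts (pydantic v2 API).
        if isinstance(metadata, BaseModel):
            metadata = metadata.model_dump()
        # Coerce every value to str and tolerate a missing mapping.
        return {key: str(value) for key, value in (metadata or {}).items()}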
@@ -11,7 +11,7 @@
     TextContentBlock,
     Text,
 )
-
+from pydantic import BaseModel
 from langchain_openai_api_bridge.assistant.adapter.openai_message_content_adapter import (
     to_openai_message_content_list,
 )
@@ -33,9 +33,9 @@ def create_message(
     content: Union[str, Iterable[MessageContentPartParam], None] = None,
     status: Literal["in_progress", "incomplete", "completed"] = "completed",
     run_id: Optional[str] = None,
-    metadata: Optional[object] = {},
+    metadata: Optional[dict[str, str] | BaseModel] = None,
 ) -> Message:
-
+    metadata = metadata.model_dump() if isinstance(metadata, BaseModel) else metadata
     return Message(
         id=id,
         thread_id=thread_id,
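Note: the dropped "metadata: Optional[object] = {}" default was the classic mutable-default pitfall: the dict is created once at definition time and shared by every call. Defaulting to None and normalizing inside the body sidesteps it:

    def counter(acc: dict = {}):  # one dict, created at def-time
        acc["n"] = acc.get("n", 0) + 1
        return acc

    counter()  # {'n': 1}
    counter()  # {'n': 2}: the "fresh" default already held state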
langchain_openai_api_bridge/assistant/assistant_run_service.py (31 changes: 17 additions & 14 deletions)

@@ -1,4 +1,4 @@
-from typing import AsyncIterator
+from typing import AsyncIterator, AsyncContextManager
 from openai.types.beta import AssistantStreamEvent
 from langchain_openai_api_bridge.assistant.adapter.langgraph_event_to_openai_assistant_event_stream import (
     LanggraphEventToOpenAIAssistantEventStream,
@@ -37,24 +37,27 @@ def create(self, dto: ThreadRunsDto):
             status="queued",
         )

-    async def ainvoke(self, agent: Runnable, dto: ThreadRunsDto):
+    async def ainvoke(self, agent: AsyncContextManager[Runnable], dto: ThreadRunsDto):
         input = self.thread_message_service.retreive_input(thread_id=dto.thread_id)

-        return await agent.ainvoke(
-            input={"messages": input},
-        )
+        async with agent as runnable:
+            return await runnable.ainvoke(
+                input={"messages": input},
+            )

-    def astream(
-        self, agent: Runnable, dto: ThreadRunsDto
+    async def astream(
+        self, agent: AsyncContextManager[Runnable], dto: ThreadRunsDto
     ) -> AsyncIterator[AssistantStreamEvent]:

         input = self.thread_message_service.retreive_input(thread_id=dto.thread_id)

-        astream_events = agent.astream_events(
-            input={"messages": input},
-            version="v2",
-        )
+        async with agent as runnable:
+            astream_events = runnable.astream_events(
+                input={"messages": input},
+                version="v2",
+            )

-        return self.stream_adapter.to_openai_assistant_event_stream(
-            astream_events=astream_events, dto=dto
-        )
+            async for it in self.stream_adapter.to_openai_assistant_event_stream(
+                astream_events=astream_events, dto=dto
+            ):
+                yield it
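Note: the service now receives the agent as an async context manager yielding a Runnable, so per-run setup and teardown (checkpointer connections, client sessions) can wrap each invocation. A hypothetical factory-side sketch; build_runnable and close_resources are assumed names, not part of this PR:

    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def make_agent():
        agent = build_runnable()  # assumed constructor for the Runnable
        try:
            yield agent  # consumed by `async with agent as runnable:` above
        finally:
            await close_resources(agent)  # assumed async cleanup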
@@ -1,4 +1,3 @@
-import json
 from typing import AsyncIterator
 from starlette.responses import StreamingResponse
 from openai.types.beta import AssistantStreamEvent
@@ -29,7 +28,7 @@ def to_streaming_response(
         )

     def __serialize_event(self, event: AssistantStreamEvent):
-        return self.__str_event(event.event, json.dumps(event.data.dict()))
+        return self.__str_event(event.event, event.data.model_dump_json())

     def __str_event(self, event: str, data: str) -> str:
         return f"event: {event}\ndata: {data}\n\n"
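Note: json.dumps(event.data.dict()) is the pydantic v1 idiom; .dict() is deprecated in v2, and model_dump_json() serializes in one step while handling nested models and enums. For example:

    from pydantic import BaseModel

    class Payload(BaseModel):
        id: str
        value: int

    Payload(id="a", value=1).model_dump_json()  # '{"id":"a","value":1}'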
@@ -4,6 +4,7 @@
 from langchain_openai_api_bridge.assistant.adapter.openai_message_factory import (
     create_message,
 )
+from pydantic import BaseModel
 from .message_repository import (
     MessageRepository,
 )
@@ -27,7 +28,7 @@ def create(
         content: Union[str, Iterable[MessageContentPartParam], None] = None,
         status: Literal["in_progress", "incomplete", "completed"] = "completed",
         run_id: Optional[str] = None,
-        metadata: Optional[object] = {},
+        metadata: Optional[dict[str, str] | BaseModel] = None,
     ) -> Message:
         id = str(uuid.uuid4())
         message = create_message(
@@ -3,14 +3,16 @@
 from typing import Literal, Optional
 from openai.types.beta import Thread, ThreadDeleted
 from openai.pagination import SyncCursorPage
+from pydantic import BaseModel
 from .thread_repository import ThreadRepository


 class InMemoryThreadRepository(ThreadRepository):
     def __init__(self, data: Optional[dict[str, Thread]] = None) -> None:
         self.threads = data or {}

-    def create(self, metadata: Optional[object] = None) -> Thread:
+    def create(self, metadata: Optional[dict[str, str] | BaseModel] = None) -> Thread:
+        metadata = metadata.model_dump() if isinstance(metadata, BaseModel) else metadata
         thread_id = str(uuid.uuid4())
         thread = self.__create_thread(thread_id=thread_id, metadata=metadata)
         self.threads[thread_id] = thread
@@ -20,10 +22,11 @@ def create(self, metadata: Optional[object] = None) -> Thread:
     def update(
         self,
         thread_id: str,
-        metadata: Optional[object] = None,
+        metadata: Optional[dict[str, str] | BaseModel] = None,
     ) -> Thread:
         if thread_id not in self.threads:
             raise ValueError(f"Thread with id {thread_id} not found")
+        metadata = metadata.model_dump() if isinstance(metadata, BaseModel) else metadata

         thread = self.threads[thread_id].copy(deep=True)
         thread.metadata = metadata
@@ -62,7 +65,7 @@ def delete(
         return self.__create_thread_deleted(thread_id=thread_id)

     @staticmethod
-    def __create_thread(thread_id: str, metadata: Optional[object] = None) -> Thread:
+    def __create_thread(thread_id: str, metadata: Optional[dict[str, str]] = None) -> Thread:
         return Thread(
             id=thread_id,
             object="thread",
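Note: thread metadata can now be a plain dict or any pydantic model; the repository unwraps models before storing them on the Thread. A hypothetical usage sketch (RunContext is an invented model, not part of this PR):

    from pydantic import BaseModel

    class RunContext(BaseModel):  # hypothetical metadata model
        user_id: str

    repo = InMemoryThreadRepository()
    thread = repo.create(metadata=RunContext(user_id="u-123"))
    # thread.metadata is now the plain dict {"user_id": "u-123"}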
@@ -4,6 +4,7 @@
 from openai.types.beta import thread_create_params
 from openai.pagination import SyncCursorPage
 from openai.types.beta.threads.message import Message, Attachment
+from pydantic import BaseModel


 class MessageRepository(ABC):
@@ -23,7 +24,7 @@ def create(
         assistant_id: Optional[str] = None,
         attachments: Optional[List[Attachment]] = None,
         run_id: Optional[str] = None,
-        metadata: Optional[dict] = {},
+        metadata: Optional[dict[str, str] | BaseModel] = None,
     ) -> Message:
         pass

@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 from openai.types.beta import Thread, ThreadDeleted
 from openai.pagination import SyncCursorPage
+from pydantic import BaseModel


 class ThreadRepository(ABC):
@@ -14,7 +15,7 @@
     @abstractmethod
     def create(
         self,
-        metadata: Optional[object] = None,
+        metadata: Optional[dict[str, str] | BaseModel] = None,
     ) -> Thread:
         # client.beta.threads.create(messages)
         pass
@@ -23,7 +24,7 @@ def create(
     def update(
         self,
         thread_id: str,
-        metadata: Optional[object] = None,
+        metadata: Optional[dict[str, str] | BaseModel] = None,
     ) -> Thread:
         # client.beta.threads.create(messages)
         pass
@@ -1,6 +1,6 @@
-from typing import AsyncIterator, List, Optional
+from typing import AsyncIterator, List, Optional, AsyncContextManager
 from langchain_core.runnables import Runnable
-from langgraph.graph.graph import CompiledGraph
+from langgraph.graph.state import CompiledStateGraph
 from langchain_openai_api_bridge.chat_completion.langchain_invoke_adapter import (
     LangchainInvokeAdapter,
 )
@@ -15,7 +15,7 @@ class ChatCompletionCompatibleAPI:

     @staticmethod
     def from_agent(
-        agent: Runnable,
+        agent: AsyncContextManager[Runnable],
         llm_model: str,
         system_fingerprint: Optional[str] = "",
         event_adapter: callable = lambda event: None,
@@ -31,43 +31,45 @@ def __init__(
         self,
         stream_adapter: LangchainStreamAdapter,
         invoke_adapter: LangchainInvokeAdapter,
-        agent: Runnable,
+        agent: AsyncContextManager[Runnable],
         event_adapter: callable = lambda event: None,
     ) -> None:
         self.stream_adapter = stream_adapter
         self.invoke_adapter = invoke_adapter
         self.agent = agent
         self.event_adapter = event_adapter

-    def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
-        input = self.__to_input(messages)
-        astream_event = self.agent.astream_events(
-            input=input,
-            version="v2",
-        )
-        return ato_dict(
-            self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, event_adapter=self.event_adapter)
-        )
+    async def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
+        async with self.agent as runnable:
+            input = self.__to_input(runnable, messages)
+            astream_event = runnable.astream_events(
+                input=input,
+                version="v2",
+            )
+            async for it in ato_dict(
+                self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, event_adapter=self.event_adapter)
+            ):
+                yield it

-    def invoke(self, messages: List[OpenAIChatMessage]) -> dict:
-        input = self.__to_input(messages)
-
-        result = self.agent.invoke(
-            input=input,
-        )
+    async def ainvoke(self, messages: List[OpenAIChatMessage]) -> dict:
+        async with self.agent as runnable:
+            input = self.__to_input(runnable, messages)
+            result = await runnable.ainvoke(
+                input=input,
+            )

-        return self.invoke_adapter.to_chat_completion_object(result).dict()
+        return self.invoke_adapter.to_chat_completion_object(result).model_dump()

-    def __to_input(self, messages: List[OpenAIChatMessage]):
-        if isinstance(self.agent, CompiledGraph):
+    def __to_input(self, runnable: Runnable, messages: List[OpenAIChatMessage]):
+        if isinstance(runnable, CompiledStateGraph):
             return self.__to_react_agent_input(messages)
         else:
             return self.__to_chat_model_input(messages)

     def __to_react_agent_input(self, messages: List[OpenAIChatMessage]):
         return {
-            "messages": [message.dict() for message in messages],
+            "messages": [message.model_dump() for message in messages],
         }

     def __to_chat_model_input(self, messages: List[OpenAIChatMessage]):
-        return [message.dict() for message in messages]
+        return [message.model_dump() for message in messages]
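Note: both entry points are now coroutines that enter the agent context on each call. A hypothetical consumer sketch, reusing the make_agent factory idea from above; since a generator-based asynccontextmanager object is single-use, a fresh one is built per API instance:

    import asyncio

    async def main(messages):
        api = ChatCompletionCompatibleAPI.from_agent(make_agent(), llm_model="gpt-4o")
        async for chunk in api.astream(messages):
            print(chunk)  # OpenAI-style chat.completion.chunk dicts

    # asyncio.run(main(messages))  # messages: List[OpenAIChatMessage]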
@@ -1,15 +1,15 @@
 from typing import List
 from langchain_core.messages import BaseMessage
 from langchain_anthropic import ChatAnthropic
-
+from typing import ClassVar
 from .anthropic_openai_compatible_chat_model_adapter import (
     AnthropicOpenAICompatibleChatModelAdapter,
 )


 class AnthropicOpenAICompatibleChatModel(ChatAnthropic):

-    adapter = AnthropicOpenAICompatibleChatModelAdapter()
+    adapter: ClassVar[AnthropicOpenAICompatibleChatModelAdapter] = AnthropicOpenAICompatibleChatModelAdapter()

     def _stream(self, messages: List[List[BaseMessage]], **kwargs):
         transformed_messages = self.adapter.to_openai_format_messages(messages)
@@ -1,15 +1,14 @@
-from typing import List
+from typing import List, ClassVar
 from langchain_core.messages import BaseMessage
-from langchain_llamacpp_chat_model import LlamaChatModel
-
+from langchain_community.chat_models import ChatLlamaCpp
 from langchain_openai_api_bridge.chat_model_adapter.llamacpp.llamacpp_openai_compatible_chat_model_adapter import (
     LlamacppOpenAICompatibleChatModelAdapter,
 )


-class LLamacppOpenAICompatibleChatModel(LlamaChatModel):
+class LLamacppOpenAICompatibleChatModel(ChatLlamaCpp):

-    adapter = LlamacppOpenAICompatibleChatModelAdapter()
+    adapter: ClassVar[LlamacppOpenAICompatibleChatModelAdapter] = LlamacppOpenAICompatibleChatModelAdapter()

     def _stream(self, messages: List[List[BaseMessage]], **kwargs):
         transformed_messages = self.adapter.to_openai_format_messages(messages)
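Note: the ClassVar annotations in the last two files are what pydantic requires. Langchain chat models are pydantic models, and an annotated class attribute is otherwise collected as a field and validated; ClassVar marks the adapter as a plain class attribute. Minimal illustration:

    from typing import ClassVar
    from pydantic import BaseModel

    class Example(BaseModel):
        name: str                      # pydantic field, validated per instance
        registry: ClassVar[dict] = {}  # class attribute, ignored by validation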