Skip to content

Commit c95c0fd

Browse files
authored
feat: support non react-agent for chat completion api (#33)
1 parent 5e1d1b7 commit c95c0fd

File tree

7 files changed

+166
-36
lines changed

7 files changed

+166
-36
lines changed

langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import AsyncIterator, List, Optional
22
from langchain_core.runnables import Runnable
3+
from langgraph.graph.graph import CompiledGraph
34
from langchain_openai_api_bridge.chat_completion.langchain_invoke_adapter import (
45
LangchainInvokeAdapter,
56
)
@@ -43,13 +44,24 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
4344
)
4445

4546
def invoke(self, messages: List[OpenAIChatMessage]) -> dict:
47+
input = self.__to_input(messages)
48+
4649
result = self.agent.invoke(
47-
input=self.__to_input(messages),
50+
input=input,
4851
)
4952

5053
return self.invoke_adapter.to_chat_completion_object(result).dict()
5154

5255
def __to_input(self, messages: List[OpenAIChatMessage]):
56+
if isinstance(self.agent, CompiledGraph):
57+
return self.__to_react_agent_input(messages)
58+
else:
59+
return self.__to_chat_model_input(messages)
60+
61+
def __to_react_agent_input(self, messages: List[OpenAIChatMessage]):
5362
return {
5463
"messages": [message.dict() for message in messages],
5564
}
65+
66+
def __to_chat_model_input(self, messages: List[OpenAIChatMessage]):
67+
return [message.dict() for message in messages]

langchain_openai_api_bridge/chat_completion/langchain_invoke_adapter.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
OpenAIChatCompletionObject,
1111
OpenAIChatMessage,
1212
)
13+
from langchain_core.messages import AIMessage
1314

1415

1516
class LangchainInvokeAdapter:
@@ -18,20 +19,38 @@ def __init__(self, llm_model: str, system_fingerprint: str = ""):
1819
self.system_fingerprint = system_fingerprint
1920

2021
def to_chat_completion_object(self, invoke_result) -> OpenAIChatCompletionObject:
21-
last_message = invoke_result["messages"][-1]
22+
message = self.__create_openai_chat_message(invoke_result)
23+
id = self.__get_id(invoke_result)
2224

2325
return ChatCompletionObjectFactory.create(
24-
id=last_message.id,
26+
id=id,
2527
model=self.llm_model,
2628
system_fingerprint=self.system_fingerprint,
2729
choices=[
2830
OpenAIChatCompletionChoice(
2931
index=0,
30-
message=OpenAIChatMessage(
31-
role=to_openai_role(last_message.type),
32-
content=to_string_content(content=last_message.content),
33-
),
32+
message=message,
3433
finish_reason="stop",
3534
)
3635
],
3736
)
37+
38+
def __get_id(self, invoke_result):
39+
if isinstance(invoke_result, AIMessage):
40+
return invoke_result.id
41+
42+
last_message = invoke_result["messages"][-1]
43+
return last_message.id
44+
45+
def __create_openai_chat_message(self, invoke_result) -> OpenAIChatMessage:
46+
if isinstance(invoke_result, AIMessage):
47+
return OpenAIChatMessage(
48+
role=to_openai_role(invoke_result.type),
49+
content=to_string_content(content=invoke_result.content),
50+
)
51+
52+
last_message = invoke_result["messages"][-1]
53+
return OpenAIChatMessage(
54+
role=to_openai_role(last_message.type),
55+
content=to_string_content(content=last_message.content),
56+
)
Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
22
from langchain_core.runnables import Runnable
3-
from langchain_core.tools import tool
4-
from langgraph.prebuilt import create_react_agent
53
from langchain_openai import ChatOpenAI
64

75
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
86

97

10-
@tool
11-
def magic_number_tool(input: int) -> int:
12-
"""Applies a magic function to an input."""
13-
return input + 2
14-
15-
168
class MyOpenAIAgentFactory(BaseAgentFactory):
179

1810
def create_agent(self, dto: CreateAgentDto) -> Runnable:
@@ -23,8 +15,4 @@ def create_agent(self, dto: CreateAgentDto) -> Runnable:
2315
temperature=dto.temperature,
2416
)
2517

26-
return create_react_agent(
27-
llm,
28-
[magic_number_tool],
29-
messages_modifier="""You are a helpful assistant.""",
30-
)
18+
return llm

tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def openai_client():
1717

1818
def test_chat_completion_invoke(openai_client):
1919
chat_completion = openai_client.chat.completions.create(
20-
model="gpt-3.5-turbo",
20+
model="gpt-4o-mini",
2121
messages=[
2222
{
2323
"role": "user",
@@ -30,7 +30,7 @@ def test_chat_completion_invoke(openai_client):
3030

3131
def test_chat_completion_stream(openai_client):
3232
chunks = openai_client.chat.completions.create(
33-
model="gpt-3.5-turbo",
33+
model="gpt-4o-mini",
3434
messages=[{"role": "user", "content": 'Say "This is a test"'}],
3535
stream=True,
3636
)
@@ -42,17 +42,3 @@ def test_chat_completion_stream(openai_client):
4242
stream_output = "".join(every_content)
4343

4444
assert "This is a test" in stream_output
45-
46-
47-
def test_tool(openai_client):
48-
49-
chat_completion = openai_client.chat.completions.create(
50-
model="gpt-3.5-turbo",
51-
messages=[
52-
{
53-
"role": "user",
54-
"content": 'Say "Magic number of 2"',
55-
}
56-
],
57-
)
58-
assert "4" in chat_completion.choices[0].message.content
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
2+
from langchain_core.runnables import Runnable
3+
from langchain_core.tools import tool
4+
from langgraph.prebuilt import create_react_agent
5+
from langchain_openai import ChatOpenAI
6+
7+
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
8+
9+
10+
@tool
11+
def magic_number_tool(input: int) -> int:
12+
"""Applies a magic function to an input."""
13+
return input + 2
14+
15+
16+
class MyOpenAIReactAgentFactory(BaseAgentFactory):
17+
18+
def create_agent(self, dto: CreateAgentDto) -> Runnable:
19+
llm = ChatOpenAI(
20+
model=dto.model,
21+
api_key=dto.api_key,
22+
streaming=True,
23+
temperature=dto.temperature,
24+
)
25+
26+
return create_react_agent(
27+
llm,
28+
[magic_number_tool],
29+
messages_modifier="""You are a helpful assistant.""",
30+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from fastapi import FastAPI
2+
from fastapi.middleware.cors import CORSMiddleware
3+
from dotenv import load_dotenv, find_dotenv
4+
import uvicorn
5+
6+
from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
7+
LangchainOpenaiApiBridgeFastAPI,
8+
)
9+
from tests.test_functional.fastapi_chat_completion_react_agent_openai.my_openai_react_agent_factory import (
10+
MyOpenAIReactAgentFactory,
11+
)
12+
13+
_ = load_dotenv(find_dotenv())
14+
15+
16+
app = FastAPI(
17+
title="Langchain Agent OpenAI API Bridge",
18+
version="1.0",
19+
description="OpenAI API exposing langchain agent",
20+
)
21+
22+
app.add_middleware(
23+
CORSMiddleware,
24+
allow_origins=["*"],
25+
allow_credentials=True,
26+
allow_methods=["*"],
27+
allow_headers=["*"],
28+
expose_headers=["*"],
29+
)
30+
31+
bridge = LangchainOpenaiApiBridgeFastAPI(
32+
app=app, agent_factory_provider=lambda: MyOpenAIReactAgentFactory()
33+
)
34+
bridge.bind_openai_chat_completion()
35+
36+
if __name__ == "__main__":
37+
uvicorn.run(app, host="localhost")
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pytest
2+
from openai import OpenAI
3+
from fastapi.testclient import TestClient
4+
from react_agent_server_openai import app
5+
6+
7+
test_api = TestClient(app)
8+
9+
10+
@pytest.fixture
11+
def openai_client():
12+
return OpenAI(
13+
base_url="http://testserver/openai/v1",
14+
http_client=test_api,
15+
)
16+
17+
18+
def test_chat_completion_invoke(openai_client):
19+
chat_completion = openai_client.chat.completions.create(
20+
model="gpt-3.5-turbo",
21+
messages=[
22+
{
23+
"role": "user",
24+
"content": 'Say "This is a test"',
25+
}
26+
],
27+
)
28+
assert "This is a test" in chat_completion.choices[0].message.content
29+
30+
31+
def test_chat_completion_stream(openai_client):
32+
chunks = openai_client.chat.completions.create(
33+
model="gpt-3.5-turbo",
34+
messages=[{"role": "user", "content": 'Say "This is a test"'}],
35+
stream=True,
36+
)
37+
every_content = []
38+
for chunk in chunks:
39+
if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
40+
every_content.append(chunk.choices[0].delta.content)
41+
42+
stream_output = "".join(every_content)
43+
44+
assert "This is a test" in stream_output
45+
46+
47+
def test_tool(openai_client):
48+
49+
chat_completion = openai_client.chat.completions.create(
50+
model="gpt-3.5-turbo",
51+
messages=[
52+
{
53+
"role": "user",
54+
"content": 'Say "Magic number of 2"',
55+
}
56+
],
57+
)
58+
assert "4" in chat_completion.choices[0].message.content

0 commit comments

Comments
 (0)