Skip to content

Commit cf646ff

Browse files
authored
Merge pull request #104 from pipecat-ai/filipi/video_transform_start_endpoint
Refactoring p2p video_transform example
2 parents 38c63a0 + 270bb5b commit cf646ff

File tree

1 file changed

+106
-37
lines changed

1 file changed

+106
-37
lines changed

p2p-webrtc/video-transform/server/server.py

Lines changed: 106 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,77 +5,146 @@
55
#
66

77
import argparse
8-
import asyncio
98
import sys
9+
import uuid
1010
from contextlib import asynccontextmanager
11-
from typing import Dict
11+
from http import HTTPMethod
12+
from typing import Any, Dict, List, Optional, TypedDict
1213

1314
import uvicorn
1415
from bot import run_bot
1516
from dotenv import load_dotenv
16-
from fastapi import BackgroundTasks, FastAPI
17+
from fastapi import BackgroundTasks, FastAPI, Request, Response
1718
from fastapi.responses import RedirectResponse
1819
from loguru import logger
19-
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
20+
from pipecat.transports.smallwebrtc.connection import IceServer
21+
from pipecat.transports.smallwebrtc.request_handler import (
22+
IceCandidate,
23+
SmallWebRTCPatchRequest,
24+
SmallWebRTCRequest,
25+
SmallWebRTCRequestHandler,
26+
)
2027
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
2128

2229
# Load environment variables
2330
load_dotenv(override=True)
2431

2532
app = FastAPI()
2633

27-
# Store connections by pc_id
28-
pcs_map: Dict[str, SmallWebRTCConnection] = {}
29-
30-
ice_servers = [
31-
IceServer(
32-
urls="stun:stun.l.google.com:19302",
33-
)
34-
]
35-
3634
# Mount the frontend at /
3735
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)
3836

37+
# Initialize the SmallWebRTC request handler
38+
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler()
39+
40+
# In-memory store of active sessions: session_id -> session info
41+
active_sessions: Dict[str, Dict[str, Any]] = {}
42+
3943

4044
@app.get("/", include_in_schema=False)
4145
async def root_redirect():
4246
return RedirectResponse(url="/prebuilt/")
4347

4448

4549
@app.post("/api/offer")
46-
async def offer(request: dict, background_tasks: BackgroundTasks):
47-
pc_id = request.get("pc_id")
48-
49-
if pc_id and pc_id in pcs_map:
50-
pipecat_connection = pcs_map[pc_id]
51-
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
52-
await pipecat_connection.renegotiate(
53-
sdp=request["sdp"], type=request["type"], restart_pc=request.get("restart_pc", False)
54-
)
55-
else:
56-
pipecat_connection = SmallWebRTCConnection(ice_servers)
57-
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
50+
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
51+
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""
5852

59-
@pipecat_connection.event_handler("closed")
60-
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
61-
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
62-
pcs_map.pop(webrtc_connection.pc_id, None)
53+
# Prepare runner arguments with the callback to run your bot
54+
async def webrtc_connection_callback(connection):
55+
background_tasks.add_task(run_bot, connection)
6356

64-
background_tasks.add_task(run_bot, pipecat_connection)
57+
# Delegate handling to SmallWebRTCRequestHandler
58+
answer = await small_webrtc_handler.handle_web_request(
59+
request=request,
60+
webrtc_connection_callback=webrtc_connection_callback,
61+
)
62+
return answer
6563

66-
answer = pipecat_connection.get_answer()
67-
# Updating the peer connection inside the map
68-
pcs_map[answer["pc_id"]] = pipecat_connection
6964

70-
return answer
65+
@app.patch("/api/offer")
66+
async def ice_candidate(request: SmallWebRTCPatchRequest):
67+
"""Handle WebRTC new ice candidate requests."""
68+
logger.debug(f"Received patch request: {request}")
69+
await small_webrtc_handler.handle_patch_request(request)
70+
return {"status": "success"}
71+
72+
73+
@app.post("/start")
74+
async def rtvi_start(request: Request):
75+
"""Mimic Pipecat Cloud's /start endpoint."""
76+
77+
class IceConfig(TypedDict):
78+
iceServers: List[IceServer]
79+
80+
class StartBotResult(TypedDict, total=False):
81+
sessionId: str
82+
iceConfig: Optional[IceConfig]
83+
84+
# Parse the request body
85+
try:
86+
request_data = await request.json()
87+
logger.debug(f"Received request: {request_data}")
88+
except Exception as e:
89+
logger.error(f"Failed to parse request body: {e}")
90+
request_data = {}
91+
92+
# Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
93+
session_id = str(uuid.uuid4())
94+
active_sessions[session_id] = request_data
95+
96+
result: StartBotResult = {"sessionId": session_id}
97+
if request_data.get("enableDefaultIceServers"):
98+
result["iceConfig"] = IceConfig(
99+
iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
100+
)
101+
102+
return result
103+
104+
105+
@app.api_route(
106+
"/sessions/{session_id}/{path:path}",
107+
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
108+
)
109+
async def proxy_request(
110+
session_id: str, path: str, request: Request, background_tasks: BackgroundTasks
111+
):
112+
"""Mimic Pipecat Cloud's proxy."""
113+
active_session = active_sessions.get(session_id)
114+
if active_session is None:
115+
return Response(content="Invalid or not-yet-ready session_id", status_code=404)
116+
117+
if path.endswith("api/offer"):
118+
# Parse the request body and convert to SmallWebRTCRequest
119+
try:
120+
request_data = await request.json()
121+
if request.method == HTTPMethod.POST.value:
122+
webrtc_request = SmallWebRTCRequest(
123+
sdp=request_data["sdp"],
124+
type=request_data["type"],
125+
pc_id=request_data.get("pc_id"),
126+
restart_pc=request_data.get("restart_pc"),
127+
request_data=request_data,
128+
)
129+
return await offer(webrtc_request, background_tasks)
130+
elif request.method == HTTPMethod.PATCH.value:
131+
patch_request = SmallWebRTCPatchRequest(
132+
pc_id=request_data["pc_id"],
133+
candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])],
134+
)
135+
return await ice_candidate(patch_request)
136+
except Exception as e:
137+
logger.error(f"Failed to parse WebRTC request: {e}")
138+
return Response(content="Invalid WebRTC request", status_code=400)
139+
140+
logger.info(f"Received request for path: {path}")
141+
return Response(status_code=200)
71142

72143

73144
@asynccontextmanager
74145
async def lifespan(app: FastAPI):
75146
yield # Run app
76-
coros = [pc.disconnect() for pc in pcs_map.values()]
77-
await asyncio.gather(*coros)
78-
pcs_map.clear()
147+
await small_webrtc_handler.close()
79148

80149

81150
if __name__ == "__main__":

0 commit comments

Comments
 (0)