|
5 | 5 | # |
6 | 6 |
|
7 | 7 | import argparse |
8 | | -import asyncio |
9 | 8 | import sys |
| 9 | +import uuid |
10 | 10 | from contextlib import asynccontextmanager |
11 | | -from typing import Dict |
| 11 | +from http import HTTPMethod |
| 12 | +from typing import Any, Dict, List, Optional, TypedDict |
12 | 13 |
|
13 | 14 | import uvicorn |
14 | 15 | from bot import run_bot |
15 | 16 | from dotenv import load_dotenv |
16 | | -from fastapi import BackgroundTasks, FastAPI |
| 17 | +from fastapi import BackgroundTasks, FastAPI, Request, Response |
17 | 18 | from fastapi.responses import RedirectResponse |
18 | 19 | from loguru import logger |
19 | | -from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection |
| 20 | +from pipecat.transports.smallwebrtc.connection import IceServer |
| 21 | +from pipecat.transports.smallwebrtc.request_handler import ( |
| 22 | + IceCandidate, |
| 23 | + SmallWebRTCPatchRequest, |
| 24 | + SmallWebRTCRequest, |
| 25 | + SmallWebRTCRequestHandler, |
| 26 | +) |
20 | 27 | from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI |
21 | 28 |
|
22 | 29 | # Load environment variables |
23 | 30 | load_dotenv(override=True) |
24 | 31 |
|
25 | 32 | app = FastAPI() |
26 | 33 |
|
27 | | -# Store connections by pc_id |
28 | | -pcs_map: Dict[str, SmallWebRTCConnection] = {} |
29 | | - |
30 | | -ice_servers = [ |
31 | | - IceServer( |
32 | | - urls="stun:stun.l.google.com:19302", |
33 | | - ) |
34 | | -] |
35 | | - |
36 | 34 | # Mount the frontend at / |
37 | 35 | app.mount("/prebuilt", SmallWebRTCPrebuiltUI) |
38 | 36 |
|
| 37 | +# Initialize the SmallWebRTC request handler |
| 38 | +small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler() |
| 39 | + |
| 40 | +# In-memory store of active sessions: session_id -> session info |
| 41 | +active_sessions: Dict[str, Dict[str, Any]] = {} |
| 42 | + |
39 | 43 |
|
40 | 44 | @app.get("/", include_in_schema=False) |
41 | 45 | async def root_redirect(): |
42 | 46 | return RedirectResponse(url="/prebuilt/") |
43 | 47 |
|
44 | 48 |
|
45 | 49 | @app.post("/api/offer") |
46 | | -async def offer(request: dict, background_tasks: BackgroundTasks): |
47 | | - pc_id = request.get("pc_id") |
48 | | - |
49 | | - if pc_id and pc_id in pcs_map: |
50 | | - pipecat_connection = pcs_map[pc_id] |
51 | | - logger.info(f"Reusing existing connection for pc_id: {pc_id}") |
52 | | - await pipecat_connection.renegotiate( |
53 | | - sdp=request["sdp"], type=request["type"], restart_pc=request.get("restart_pc", False) |
54 | | - ) |
55 | | - else: |
56 | | - pipecat_connection = SmallWebRTCConnection(ice_servers) |
57 | | - await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"]) |
| 50 | +async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks): |
| 51 | + """Handle WebRTC offer requests via SmallWebRTCRequestHandler.""" |
58 | 52 |
|
59 | | - @pipecat_connection.event_handler("closed") |
60 | | - async def handle_disconnected(webrtc_connection: SmallWebRTCConnection): |
61 | | - logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}") |
62 | | - pcs_map.pop(webrtc_connection.pc_id, None) |
| 53 | + # Prepare runner arguments with the callback to run your bot |
| 54 | + async def webrtc_connection_callback(connection): |
| 55 | + background_tasks.add_task(run_bot, connection) |
63 | 56 |
|
64 | | - background_tasks.add_task(run_bot, pipecat_connection) |
| 57 | + # Delegate handling to SmallWebRTCRequestHandler |
| 58 | + answer = await small_webrtc_handler.handle_web_request( |
| 59 | + request=request, |
| 60 | + webrtc_connection_callback=webrtc_connection_callback, |
| 61 | + ) |
| 62 | + return answer |
65 | 63 |
|
66 | | - answer = pipecat_connection.get_answer() |
67 | | - # Updating the peer connection inside the map |
68 | | - pcs_map[answer["pc_id"]] = pipecat_connection |
69 | 64 |
|
70 | | - return answer |
| 65 | +@app.patch("/api/offer") |
| 66 | +async def ice_candidate(request: SmallWebRTCPatchRequest): |
| 67 | + """Handle WebRTC new ice candidate requests.""" |
| 68 | + logger.debug(f"Received patch request: {request}") |
| 69 | + await small_webrtc_handler.handle_patch_request(request) |
| 70 | + return {"status": "success"} |
| 71 | + |
| 72 | + |
| 73 | +@app.post("/start") |
| 74 | +async def rtvi_start(request: Request): |
| 75 | + """Mimic Pipecat Cloud's /start endpoint.""" |
| 76 | + |
| 77 | + class IceConfig(TypedDict): |
| 78 | + iceServers: List[IceServer] |
| 79 | + |
| 80 | + class StartBotResult(TypedDict, total=False): |
| 81 | + sessionId: str |
| 82 | + iceConfig: Optional[IceConfig] |
| 83 | + |
| 84 | + # Parse the request body |
| 85 | + try: |
| 86 | + request_data = await request.json() |
| 87 | + logger.debug(f"Received request: {request_data}") |
| 88 | + except Exception as e: |
| 89 | + logger.error(f"Failed to parse request body: {e}") |
| 90 | + request_data = {} |
| 91 | + |
| 92 | + # Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud |
| 93 | + session_id = str(uuid.uuid4()) |
| 94 | + active_sessions[session_id] = request_data |
| 95 | + |
| 96 | + result: StartBotResult = {"sessionId": session_id} |
| 97 | + if request_data.get("enableDefaultIceServers"): |
| 98 | + result["iceConfig"] = IceConfig( |
| 99 | + iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])] |
| 100 | + ) |
| 101 | + |
| 102 | + return result |
| 103 | + |
| 104 | + |
| 105 | +@app.api_route( |
| 106 | + "/sessions/{session_id}/{path:path}", |
| 107 | + methods=["GET", "POST", "PUT", "PATCH", "DELETE"], |
| 108 | +) |
| 109 | +async def proxy_request( |
| 110 | + session_id: str, path: str, request: Request, background_tasks: BackgroundTasks |
| 111 | +): |
| 112 | + """Mimic Pipecat Cloud's proxy.""" |
| 113 | + active_session = active_sessions.get(session_id) |
| 114 | + if active_session is None: |
| 115 | + return Response(content="Invalid or not-yet-ready session_id", status_code=404) |
| 116 | + |
| 117 | + if path.endswith("api/offer"): |
| 118 | + # Parse the request body and convert to SmallWebRTCRequest |
| 119 | + try: |
| 120 | + request_data = await request.json() |
| 121 | + if request.method == HTTPMethod.POST.value: |
| 122 | + webrtc_request = SmallWebRTCRequest( |
| 123 | + sdp=request_data["sdp"], |
| 124 | + type=request_data["type"], |
| 125 | + pc_id=request_data.get("pc_id"), |
| 126 | + restart_pc=request_data.get("restart_pc"), |
| 127 | + request_data=request_data, |
| 128 | + ) |
| 129 | + return await offer(webrtc_request, background_tasks) |
| 130 | + elif request.method == HTTPMethod.PATCH.value: |
| 131 | + patch_request = SmallWebRTCPatchRequest( |
| 132 | + pc_id=request_data["pc_id"], |
| 133 | + candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])], |
| 134 | + ) |
| 135 | + return await ice_candidate(patch_request) |
| 136 | + except Exception as e: |
| 137 | + logger.error(f"Failed to parse WebRTC request: {e}") |
| 138 | + return Response(content="Invalid WebRTC request", status_code=400) |
| 139 | + |
| 140 | + logger.info(f"Received request for path: {path}") |
| 141 | + return Response(status_code=200) |
71 | 142 |
|
72 | 143 |
|
73 | 144 | @asynccontextmanager |
74 | 145 | async def lifespan(app: FastAPI): |
75 | 146 | yield # Run app |
76 | | - coros = [pc.disconnect() for pc in pcs_map.values()] |
77 | | - await asyncio.gather(*coros) |
78 | | - pcs_map.clear() |
| 147 | + await small_webrtc_handler.close() |
79 | 148 |
|
80 | 149 |
|
81 | 150 | if __name__ == "__main__": |
|
0 commit comments