Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 106 additions & 37 deletions p2p-webrtc/video-transform/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,146 @@
#

import argparse
import asyncio
import sys
import uuid
from contextlib import asynccontextmanager
from typing import Dict
from http import HTTPMethod
from typing import Any, Dict, List, Optional, TypedDict

import uvicorn
from bot import run_bot
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI
from fastapi import BackgroundTasks, FastAPI, Request, Response
from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection
from pipecat.transports.smallwebrtc.connection import IceServer
from pipecat.transports.smallwebrtc.request_handler import (
IceCandidate,
SmallWebRTCPatchRequest,
SmallWebRTCRequest,
SmallWebRTCRequestHandler,
)
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI

# Load environment variables
load_dotenv(override=True)

app = FastAPI()

# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}

ice_servers = [
IceServer(
urls="stun:stun.l.google.com:19302",
)
]

# Mount the frontend at /
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)

# Initialize the SmallWebRTC request handler
small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler()

# In-memory store of active sessions: session_id -> session info
active_sessions: Dict[str, Dict[str, Any]] = {}


@app.get("/", include_in_schema=False)
async def root_redirect():
return RedirectResponse(url="/prebuilt/")


@app.post("/api/offer")
async def offer(request: dict, background_tasks: BackgroundTasks):
pc_id = request.get("pc_id")

if pc_id and pc_id in pcs_map:
pipecat_connection = pcs_map[pc_id]
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(
sdp=request["sdp"], type=request["type"], restart_pc=request.get("restart_pc", False)
)
else:
pipecat_connection = SmallWebRTCConnection(ice_servers)
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks):
"""Handle WebRTC offer requests via SmallWebRTCRequestHandler."""

@pipecat_connection.event_handler("closed")
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
pcs_map.pop(webrtc_connection.pc_id, None)
# Prepare runner arguments with the callback to run your bot
async def webrtc_connection_callback(connection):
background_tasks.add_task(run_bot, connection)

background_tasks.add_task(run_bot, pipecat_connection)
# Delegate handling to SmallWebRTCRequestHandler
answer = await small_webrtc_handler.handle_web_request(
request=request,
webrtc_connection_callback=webrtc_connection_callback,
)
return answer

answer = pipecat_connection.get_answer()
# Updating the peer connection inside the map
pcs_map[answer["pc_id"]] = pipecat_connection

return answer
@app.patch("/api/offer")
async def ice_candidate(request: SmallWebRTCPatchRequest):
"""Handle WebRTC new ice candidate requests."""
logger.debug(f"Received patch request: {request}")
await small_webrtc_handler.handle_patch_request(request)
return {"status": "success"}


@app.post("/start")
async def rtvi_start(request: Request):
"""Mimic Pipecat Cloud's /start endpoint."""

class IceConfig(TypedDict):
iceServers: List[IceServer]

class StartBotResult(TypedDict, total=False):
sessionId: str
iceConfig: Optional[IceConfig]

# Parse the request body
try:
request_data = await request.json()
logger.debug(f"Received request: {request_data}")
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
request_data = {}

# Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud
session_id = str(uuid.uuid4())
active_sessions[session_id] = request_data

result: StartBotResult = {"sessionId": session_id}
if request_data.get("enableDefaultIceServers"):
result["iceConfig"] = IceConfig(
iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])]
)

return result


@app.api_route(
"/sessions/{session_id}/{path:path}",
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)
async def proxy_request(
session_id: str, path: str, request: Request, background_tasks: BackgroundTasks
):
"""Mimic Pipecat Cloud's proxy."""
active_session = active_sessions.get(session_id)
if active_session is None:
return Response(content="Invalid or not-yet-ready session_id", status_code=404)

if path.endswith("api/offer"):
# Parse the request body and convert to SmallWebRTCRequest
try:
request_data = await request.json()
if request.method == HTTPMethod.POST.value:
webrtc_request = SmallWebRTCRequest(
sdp=request_data["sdp"],
type=request_data["type"],
pc_id=request_data.get("pc_id"),
restart_pc=request_data.get("restart_pc"),
request_data=request_data,
)
return await offer(webrtc_request, background_tasks)
elif request.method == HTTPMethod.PATCH.value:
patch_request = SmallWebRTCPatchRequest(
pc_id=request_data["pc_id"],
candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])],
)
return await ice_candidate(patch_request)
except Exception as e:
logger.error(f"Failed to parse WebRTC request: {e}")
return Response(content="Invalid WebRTC request", status_code=400)

logger.info(f"Received request for path: {path}")
return Response(status_code=200)


@asynccontextmanager
async def lifespan(app: FastAPI):
yield # Run app
coros = [pc.disconnect() for pc in pcs_map.values()]
await asyncio.gather(*coros)
pcs_map.clear()
await small_webrtc_handler.close()


if __name__ == "__main__":
Expand Down