diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index 737f43bc..c07005f1 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -151,6 +151,20 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self._closing = False self.pipeline_factory = PipelineFactory(SecretsManager()) self.context_tracking: Optional[PipelineContext] = None + self.idle_timeout = 10 + self.idle_timer = None + + def _reset_idle_timer(self) -> None: + if self.idle_timer: + self.idle_timer.cancel() + self.idle_timer = asyncio.get_event_loop().call_later( + self.idle_timeout, self._handle_idle_timeout + ) + + def _handle_idle_timeout(self) -> None: + logger.warning("Idle timeout reached, closing connection") + if self.transport and not self.transport.is_closing(): + self.transport.close() def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: if method == "POST" and path == "v1/engines/copilot-codex/completions": @@ -215,6 +229,7 @@ def connection_made(self, transport: asyncio.Transport) -> None: self.transport = transport self.peername = transport.get_extra_info("peername") logger.debug(f"Client connected from {self.peername}") + self._reset_idle_timer() def get_headers_dict(self) -> Dict[str, str]: """Convert raw headers to dictionary format""" @@ -350,8 +365,10 @@ async def _forward_data_to_target(self, data: bytes) -> None: pipeline_output = pipeline_output.reconstruct() self.target_transport.write(pipeline_output) + def data_received(self, data: bytes) -> None: """Handle received data from client""" + self._reset_idle_timer() try: if not self._check_buffer_size(data): self.send_error_response(413, b"Request body too large") @@ -556,6 +573,7 @@ async def connect_to_target(self) -> None: logger.error(f"Error during TLS handshake: {e}") self.send_error_response(502, b"TLS handshake failed") + def send_error_response(self, status: int, message: bytes) -> None: """Send error response to client""" if self._closing: @@ -593,6 +611,37 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.buffer.clear() self.ssl_context = None + if self.idle_timer: + self.idle_timer.cancel() + + def eof_received(self) -> None: + print("in eof received") + """Handle connection loss""" + if self._closing: + return + + self._closing = True + logger.debug(f"EOF received from {self.peername}") + + # Close target transport if it exists and isn't already closing + if self.target_transport and not self.target_transport.is_closing(): + try: + self.target_transport.close() + except Exception as e: + logger.error(f"Error closing target transport when EOF: {e}") + + # Clear references to help with cleanup + self.transport = None + self.target_transport = None + self.buffer.clear() + self.ssl_context = None + + def pause_writing(self) -> None: + print("Transport buffer full, pausing writing") + + def resume_writing(self) -> None: + print("Transport buffer ready, resuming writing") + @classmethod async def create_proxy_server( cls, host: str, port: int, ssl_context: Optional[ssl.SSLContext] = None