11import asyncio
22import json
33from pathlib import Path
4- from typing import List , Optional
4+ from typing import List , Optional , Type
55
66import structlog
77from pydantic import BaseModel
88from sqlalchemy import TextClause , text
99from sqlalchemy .ext .asyncio import create_async_engine
1010
1111from codegate .db .fim_cache import FimCache
12- from codegate .db .models import Alert , Output , Prompt
13- from codegate .db .queries import (
14- AsyncQuerier ,
12+ from codegate .db .models import (
13+ Alert ,
1514 GetAlertsWithPromptAndOutputRow ,
1615 GetPromptWithOutputsRow ,
16+ Output ,
17+ Prompt ,
1718)
1819from codegate .pipeline .base import PipelineContext
1920
@@ -83,11 +84,9 @@ async def init_db(self):
8384 await self ._async_db_engine .dispose ()
8485
8586 async def _execute_update_pydantic_model (
86- self , model : BaseModel , sql_command : TextClause #
87+ self , model : BaseModel , sql_command : TextClause
8788 ) -> Optional [BaseModel ]:
88- # There are create method in queries.py automatically generated by sqlc
89- # However, the methods are buggy for Pydancti and don't work as expected.
90- # Manually writing the SQL query to insert Pydantic models.
89+ """Execute an update or insert command for a Pydantic model."""
9190 async with self ._async_db_engine .begin () as conn :
9291 try :
9392 result = await conn .execute (sql_command , model .model_dump ())
@@ -117,8 +116,9 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
117116 # logger.debug(f"Recorded request: {recorded_request}")
118117 return recorded_request # type: ignore
119118
120- async def update_request (self , initial_id : str ,
121- prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
119+ async def update_request (
120+ self , initial_id : str , prompt_params : Optional [Prompt ] = None
121+ ) -> Optional [Prompt ]:
122122 if prompt_params is None :
123123 return None
124124 prompt_params .id = initial_id # overwrite the initial id of the request
@@ -135,8 +135,9 @@ async def update_request(self, initial_id: str,
135135 # logger.debug(f"Recorded request: {recorded_request}")
136136 return updated_request # type: ignore
137137
138- async def record_outputs (self , outputs : List [Output ],
139- initial_id : Optional [str ]) -> Optional [Output ]:
138+ async def record_outputs (
139+ self , outputs : List [Output ], initial_id : Optional [str ]
140+ ) -> Optional [Output ]:
140141 if not outputs :
141142 return
142143
@@ -216,7 +217,7 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> tuple:
216217
217218 # If it's not a FIM prompt, we don't need to check anything else.
218219 if context .input_request .type != "fim" :
219- return True , ' add' , '' # Default to add if not FIM, since no cache check is required
220+ return True , " add" , "" # Default to add if not FIM, since no cache check is required
220221
221222 return fim_cache .could_store_fim_request (context ) # type: ignore
222223
@@ -229,7 +230,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
229230 if not should_record :
230231 logger .info ("Skipping record of context, not needed" )
231232 return
232- if action == ' add' :
233+ if action == " add" :
233234 await self .record_request (context .input_request )
234235 await self .record_outputs (context .output_responses , None )
235236 await self .record_alerts (context .alerts_raised , None )
@@ -257,18 +258,61 @@ class DbReader(DbCodeGate):
257258 def __init__ (self , sqlite_path : Optional [str ] = None ):
258259 super ().__init__ (sqlite_path )
259260
261+ async def _execute_select_pydantic_model (
262+ self , model_type : Type [BaseModel ], sql_command : TextClause
263+ ) -> Optional [BaseModel ]:
264+ async with self ._async_db_engine .begin () as conn :
265+ try :
266+ result = await conn .execute (sql_command )
267+ if not result :
268+ return None
269+ rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
270+ return rows
271+ except Exception as e :
272+ logger .error (f"Failed to select model: { model_type } ." , error = str (e ))
273+ return None
274+
260275 async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
261- conn = await self ._async_db_engine .connect ()
262- querier = AsyncQuerier (conn )
263- prompts = [prompt async for prompt in querier .get_prompt_with_outputs ()]
264- await conn .close ()
276+ sql = text (
277+ """
278+ SELECT
279+ p.id, p.timestamp, p.provider, p.request, p.type,
280+ o.id as output_id,
281+ o.output,
282+ o.timestamp as output_timestamp
283+ FROM prompts p
284+ LEFT JOIN outputs o ON p.id = o.prompt_id
285+ ORDER BY o.timestamp DESC
286+ """
287+ )
288+ prompts = await self ._execute_select_pydantic_model (GetPromptWithOutputsRow , sql )
265289 return prompts
266290
267291 async def get_alerts_with_prompt_and_output (self ) -> List [GetAlertsWithPromptAndOutputRow ]:
268- conn = await self ._async_db_engine .connect ()
269- querier = AsyncQuerier (conn )
270- prompts = [prompt async for prompt in querier .get_alerts_with_prompt_and_output ()]
271- await conn .close ()
292+ sql = text (
293+ """
294+ SELECT
295+ a.id,
296+ a.prompt_id,
297+ a.code_snippet,
298+ a.trigger_string,
299+ a.trigger_type,
300+ a.trigger_category,
301+ a.timestamp,
302+ p.timestamp as prompt_timestamp,
303+ p.provider,
304+ p.request,
305+ p.type,
306+ o.id as output_id,
307+ o.output,
308+ o.timestamp as output_timestamp
309+ FROM alerts a
310+ LEFT JOIN prompts p ON p.id = a.prompt_id
311+ LEFT JOIN outputs o ON p.id = o.prompt_id
312+ ORDER BY a.timestamp DESC
313+ """
314+ )
315+ prompts = await self ._execute_select_pydantic_model (GetAlertsWithPromptAndOutputRow , sql )
272316 return prompts
273317
274318
0 commit comments