Skip to content

Commit e79a9ab

Browse files
authored
feat: fix polardb value (#445)
1 parent fd56f64 commit e79a9ab

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in
363363
WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype
364364
"""
365365
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
366-
params = [f'"{memory_type}"', f'"{user_name}"']
366+
params = [self.format_param_value(memory_type), self.format_param_value(user_name)]
367367

368368
# Get a connection from the pool
369369
conn = self._get_connection()
@@ -389,7 +389,7 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
389389
"""
390390
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
391391
query += "\nLIMIT 1"
392-
params = [f'"{scope}"', f'"{user_name}"']
392+
params = [self.format_param_value(scope), self.format_param_value(user_name)]
393393

394394
# Get a connection from the pool
395395
conn = self._get_connection()
@@ -427,7 +427,11 @@ def remove_oldest_memory(
427427
ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC
428428
OFFSET %s
429429
"""
430-
select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest]
430+
select_params = [
431+
self.format_param_value(memory_type),
432+
self.format_param_value(user_name),
433+
keep_latest,
434+
]
431435
conn = self._get_connection()
432436
try:
433437
with conn.cursor() as cursor:
@@ -501,19 +505,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N
501505
SET properties = %s, embedding = %s
502506
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
503507
"""
504-
params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"']
508+
params = [
509+
json.dumps(properties),
510+
json.dumps(embedding_vector),
511+
self.format_param_value(id),
512+
]
505513
else:
506514
query = f"""
507515
UPDATE "{self.db_name}_graph"."Memory"
508516
SET properties = %s
509517
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
510518
"""
511-
params = [json.dumps(properties), f'"{id}"']
519+
params = [json.dumps(properties), self.format_param_value(id)]
512520

513521
# Only add user filter when user_name is provided
514522
if user_name is not None:
515523
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
516-
params.append(f'"{user_name}"')
524+
params.append(self.format_param_value(user_name))
517525

518526
# Get a connection from the pool
519527
conn = self._get_connection()
@@ -538,12 +546,12 @@ def delete_node(self, id: str, user_name: str | None = None) -> None:
538546
DELETE FROM "{self.db_name}_graph"."Memory"
539547
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
540548
"""
541-
params = [f'"{id}"']
549+
params = [self.format_param_value(id)]
542550

543551
# Only add user filter when user_name is provided
544552
if user_name is not None:
545553
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
546-
params.append(f'"{user_name}"')
554+
params.append(self.format_param_value(user_name))
547555

548556
# Get a connection from the pool
549557
conn = self._get_connection()
@@ -831,28 +839,17 @@ def get_node(
831839

832840
select_fields = "id, properties, embedding" if include_embedding else "id, properties"
833841

834-
# Helper function to format parameter value
835-
def format_param_value(value: str) -> str:
836-
"""Format parameter value to handle both quoted and unquoted formats"""
837-
# Remove outer quotes if they exist
838-
if value.startswith('"') and value.endswith('"'):
839-
# Already has double quotes, return as is
840-
return value
841-
else:
842-
# Add double quotes
843-
return f'"{value}"'
844-
845842
query = f"""
846843
SELECT {select_fields}
847844
FROM "{self.db_name}_graph"."Memory"
848845
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
849846
"""
850-
params = [format_param_value(id)]
847+
params = [self.format_param_value(id)]
851848

852849
# Only add user filter when user_name is provided
853850
if user_name is not None:
854851
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
855-
params.append(format_param_value(user_name))
852+
params.append(self.format_param_value(user_name))
856853

857854
conn = self._get_connection()
858855
try:
@@ -930,7 +927,7 @@ def get_nodes(
930927
where_conditions.append(
931928
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype"
932929
)
933-
params.append(f"{id_val}")
930+
params.append(self.format_param_value(id_val))
934931

935932
where_clause = " OR ".join(where_conditions)
936933

@@ -942,7 +939,7 @@ def get_nodes(
942939

943940
user_name = user_name if user_name else self.config.user_name
944941
query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
945-
params.append(f'"{user_name}"')
942+
params.append(self.format_param_value(user_name))
946943

947944
conn = self._get_connection()
948945
try:
@@ -2616,7 +2613,7 @@ def get_neighbors_by_tag(
26162613
exclude_conditions.append(
26172614
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype"
26182615
)
2619-
params.append(f'"{exclude_id}"')
2616+
params.append(self.format_param_value(exclude_id))
26202617
where_clauses.append(f"({' AND '.join(exclude_conditions)})")
26212618

26222619
# Status filter - keep only 'activated'
@@ -2633,7 +2630,7 @@ def get_neighbors_by_tag(
26332630
where_clauses.append(
26342631
"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
26352632
)
2636-
params.append(f'"{user_name}"')
2633+
params.append(self.format_param_value(user_name))
26372634

26382635
# Testing showed no data; annotate.
26392636
where_clauses.append(
@@ -3022,3 +3019,18 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
30223019
if tgt in id_map:
30233020
edge["target"] = id_map[tgt]
30243021
return data
3022+
3023+
def format_param_value(self, value: str | None) -> str:
3024+
"""Format parameter value to handle both quoted and unquoted formats"""
3025+
# Handle None value
3026+
if value is None:
3027+
logger.warning(f"format_param_value: value is None")
3028+
return "null"
3029+
3030+
# Remove outer quotes if they exist
3031+
if value.startswith('"') and value.endswith('"'):
3032+
# Already has double quotes, return as is
3033+
return value
3034+
else:
3035+
# Add double quotes
3036+
return f'"{value}"'

0 commit comments

Comments
 (0)