Skip to content

Commit ce34bd1

Browse files
wustzdycaocuilonglijicodeCaralHsi
authored
add polardb (#395)
* add polardb.py * add polardb.py * add polar factory * delete * update get_memory_count * update get_memory_count * update node_not_exist * update remove_oldest_memory * fix * update get_node * update get_node * update update_node * update delete_node * add edge * add create_extension,create_graph,create_edge * add add_edge * edge_exist * edge_exist * edge_exist * update edge_exists * update polardb.py * update get_children_with_embeddings * update get_children_with_embeddings * update get_subgraph * update get_grouped_counts * update get_all_memory_items * update export_graph * remove * insert Memory * fix add_node * fix polardb.py * fix * fix get_subgraph * fix * get_grouped_counts * update get_by_metadata * get_grouped_counts * update get_grouped_counts * update get_grouped_counts * get_grouped_counts * update get_nodes * update search_by_embedding filter user_name * update search_by_embedding filter user_name * add filter user_name for update_node * get_structure_optimization_candidates * add filter user_name for update_node * fix * fix * fix * feat: 增加polardb的启动配置 * fix * fix * fix get_structure_optimization_candidates * fix get_all_memory_items * fix get_all_memory_items * remove embedding for get_nodes * fix get_structure_optimization_candidates * add _parse_node_new * update get_all_memory_items * update get_all_memory_items * update get_all_memory_items for include_embedding * feat: server router add polardb config * feat: server router add polardb config * update get_all_memory_items for include_embedding False * update get_all_memory_items for include_embedding False * fix * fix get_all_memory_items * update get_all_memory_items for include_embedding False * fix get_all_memory_items * update get_all_memory_items for include_embedding False * update get_grouped_counts * update get_grouped_counts * add_node and graph_id * fix * fix get_all_memory_items false * fix * fix get_all_memory_items true * fix * fix * fix * fix export_graph * fix export_graph * fix get_by_metadata * update get_neighbors_by_tag * update get_neighbors_by_tag * update get_neighbors_by_tag * fix * fix * add import_graph * fix * add get_edges * add clear * get_neighbors_by_tag * get_neighbors_by_tag * update get_by_metadata * search_by_emdedding remove embedding * fix:parseJson.py * fix:get_my_metadata * fix * fix get_by_metadata result * update polardb.py * fix _coerce_metadata * feat: add rerank time * feat: add rerank time * fix:node_not_exist * import node * import node * feat: fix merge_config_with_default * import node * fix * fix * feat: fix polardb * feat: fix scheduler method name * fix get_by_metadata for "query": "How long ago was Caroline's 18th birthday?" * fix get_by_metadata for "query": "How long ago was Caroline's 18th birthday?" * fix get_node format_param_value * feat: fix CONFIG * fix * feat: fix import * feat: delete test file * feat: fix polardb * feat: fix recall * Comment out unused configuration handling code Commented out code related to auto_create and embedding_dimension handling. * fix * feat: fix polardb * import polardb * feat: fix polardb * fix * feat: fix polardb * fix * fix * feat: fix polardb * feat: delete polardb * feat: fix utils * feat: fix polardb * feat: format polardb * feat: format utils --------- Co-authored-by: ccl <[email protected]> Co-authored-by: liji <[email protected]> Co-authored-by: CaralHsi <[email protected]>
1 parent 84adda6 commit ce34bd1

File tree

11 files changed

+2953
-37
lines changed

11 files changed

+2953
-37
lines changed

docker/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6
157157
watchfiles==1.1.0
158158
websockets==15.0.1
159159
xlrd==2.0.2
160-
xlsxwriter==3.2.5
160+
xlsxwriter==3.2.5

src/memos/api/config.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,32 @@ def get_milvus_config():
309309
"password": os.getenv("MILVUS_PASSWORD", "12345678"),
310310
}
311311

312+
@staticmethod
313+
def get_polardb_config(user_id: str | None = None) -> dict[str, Any]:
314+
"""Get PolarDB configuration."""
315+
use_multi_db = os.getenv("POLAR_DB_USE_MULTI_DB", "false").lower() == "true"
316+
317+
if use_multi_db:
318+
# Multi-DB mode: each user gets their own database (physical isolation)
319+
db_name = f"memos{user_id.replace('-', '')}" if user_id else "memos_default"
320+
user_name = None
321+
else:
322+
# Shared-DB mode: all users share one database with user_name tag (logical isolation)
323+
db_name = os.getenv("POLAR_DB_DB_NAME", "shared_memos_db")
324+
user_name = f"memos{user_id.replace('-', '')}" if user_id else "memos_default"
325+
326+
return {
327+
"host": os.getenv("POLAR_DB_HOST", "localhost"),
328+
"port": int(os.getenv("POLAR_DB_PORT", "5432")),
329+
"user": os.getenv("POLAR_DB_USER", "root"),
330+
"password": os.getenv("POLAR_DB_PASSWORD", "123456"),
331+
"db_name": db_name,
332+
"user_name": user_name,
333+
"use_multi_db": use_multi_db,
334+
"auto_create": True,
335+
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)),
336+
}
337+
312338
@staticmethod
313339
def get_mysql_config() -> dict[str, Any]:
314340
"""Get MySQL configuration."""
@@ -540,6 +566,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
540566
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
541567
neo4j_config = APIConfig.get_neo4j_config(user_id)
542568
nebular_config = APIConfig.get_nebular_config(user_id)
569+
polardb_config = APIConfig.get_polardb_config(user_id)
543570
internet_config = (
544571
APIConfig.get_internet_config()
545572
if os.getenv("ENABLE_INTERNET", "false").lower() == "true"
@@ -549,6 +576,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
549576
"neo4j-community": neo4j_community_config,
550577
"neo4j": neo4j_config,
551578
"nebular": nebular_config,
579+
"polardb": polardb_config,
552580
}
553581
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
554582
if graph_db_backend in graph_db_backend_map:
@@ -607,10 +635,12 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
607635
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
608636
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
609637
nebular_config = APIConfig.get_nebular_config(user_id="default")
638+
polardb_config = APIConfig.get_polardb_config(user_id="default")
610639
graph_db_backend_map = {
611640
"neo4j-community": neo4j_community_config,
612641
"neo4j": neo4j_config,
613642
"nebular": nebular_config,
643+
"polardb": polardb_config,
614644
}
615645
internet_config = (
616646
APIConfig.get_internet_config()

src/memos/api/routers/server_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
6969
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
7070
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
7171
"nebular": APIConfig.get_nebular_config(user_id=user_id),
72+
"polardb": APIConfig.get_polardb_config(user_id=user_id),
7273
}
7374

7475
graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()

src/memos/configs/graph_db.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,59 @@ def validate_config(self):
154154
return self
155155

156156

157+
class PolarDBGraphDBConfig(BaseConfig):
158+
"""
159+
PolarDB-specific configuration.
160+
161+
Key concepts:
162+
- `db_name`: The name of the target PolarDB database
163+
- `user_name`: Used for logical tenant isolation if needed
164+
- `auto_create`: Whether to automatically create the target database if it does not exist
165+
- `use_multi_db`: Whether to use multi-database mode for physical isolation
166+
167+
Example:
168+
---
169+
host = "localhost"
170+
port = 5432
171+
user = "postgres"
172+
password = "password"
173+
db_name = "memos_db"
174+
user_name = "alice"
175+
use_multi_db = True
176+
auto_create = True
177+
"""
178+
179+
host: str = Field(..., description="Database host")
180+
port: int = Field(default=5432, description="Database port")
181+
user: str = Field(..., description="Database user")
182+
password: str = Field(..., description="Database password")
183+
db_name: str = Field(..., description="The name of the target PolarDB database")
184+
user_name: str | None = Field(
185+
default=None,
186+
description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)",
187+
)
188+
auto_create: bool = Field(
189+
default=False,
190+
description="Whether to auto-create the database if it does not exist",
191+
)
192+
use_multi_db: bool = Field(
193+
default=True,
194+
description=(
195+
"If True: use multi-database mode for physical isolation; "
196+
"each tenant typically gets a separate database. "
197+
"If False: use a single shared database with logical isolation by user_name."
198+
),
199+
)
200+
embedding_dimension: int = Field(default=1024, description="Dimension of vector embedding")
201+
202+
@model_validator(mode="after")
203+
def validate_config(self):
204+
"""Validate config."""
205+
if not self.db_name:
206+
raise ValueError("`db_name` must be provided")
207+
return self
208+
209+
157210
class GraphDBConfigFactory(BaseModel):
158211
backend: str = Field(..., description="Backend for graph database")
159212
config: dict[str, Any] = Field(..., description="Configuration for the graph database backend")
@@ -162,6 +215,7 @@ class GraphDBConfigFactory(BaseModel):
162215
"neo4j": Neo4jGraphDBConfig,
163216
"neo4j-community": Neo4jCommunityGraphDBConfig,
164217
"nebular": NebulaGraphDBConfig,
218+
"polardb": PolarDBGraphDBConfig,
165219
}
166220

167221
@field_validator("backend")

src/memos/graph_dbs/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from memos.graph_dbs.nebular import NebulaGraphDB
66
from memos.graph_dbs.neo4j import Neo4jGraphDB
77
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
8+
from memos.graph_dbs.polardb import PolarDBGraphDB
89

910

1011
class GraphStoreFactory(BaseGraphDB):
@@ -14,6 +15,7 @@ class GraphStoreFactory(BaseGraphDB):
1415
"neo4j": Neo4jGraphDB,
1516
"neo4j-community": Neo4jCommunityGraphDB,
1617
"nebular": NebulaGraphDB,
18+
"polardb": PolarDBGraphDB,
1719
}
1820

1921
@classmethod

0 commit comments

Comments
 (0)