2 changes: 2 additions & 0 deletions src/memos/api/config.py
@@ -866,6 +866,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
},
@@ -937,6 +938,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
},
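A minimal sketch of how the new flag surfaces at runtime, assuming the same env-var convention as the neighboring BM25_CALL and VEC_COT_CALL entries and that those two variables are unset (note that `==` already yields a bool, so the `bool()` wrapper above is redundant but harmless):

```python
import os

# Standalone reproduction of the strategy-dict construction above.
os.environ["FAST_GRAPH"] = "true"  # opt in to the new fast graph recall path

search_strategy = {
    "fast_graph": os.getenv("FAST_GRAPH", "false") == "true",
    "bm25": os.getenv("BM25_CALL", "false") == "true",
    "cot": os.getenv("VEC_COT_CALL", "false") == "true",
}
print(search_strategy)  # {'fast_graph': True, 'bm25': False, 'cot': False}
```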
10 changes: 2 additions & 8 deletions src/memos/memories/textual/simple_tree.py
@@ -70,12 +70,6 @@ def __init__(
)
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")

self.vec_cot = (
self.search_strategy["cot"]
if self.search_strategy and "cot" in self.search_strategy
else False
)

time_start_rr = time.time()
self.reranker = reranker
logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
@@ -189,7 +183,7 @@ def search(
bm25_retriever=self.bm25_retriever,
internet_retriever=None,
moscube=moscube,
vec_cot=self.vec_cot,
search_strategy=self.search_strategy,
)
else:
searcher = Searcher(
@@ -200,7 +194,7 @@ def search(
bm25_retriever=self.bm25_retriever,
internet_retriever=self.internet_retriever,
moscube=moscube,
vec_cot=self.vec_cot,
search_strategy=self.search_strategy,
)
return searcher.search(
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
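The block deleted above (and its twin in tree.py below) duplicated flag parsing in every memory class; the whole dict is now handed to Searcher instead. The deleted derivation, reproduced standalone for reference:

```python
# Read "cot" defensively from a strategy dict that may be None or
# missing the key — this is the logic the PR removes from simple_tree.py.
def old_vec_cot(search_strategy: dict | None) -> bool:
    return (
        search_strategy["cot"]
        if search_strategy and "cot" in search_strategy
        else False
    )

assert old_vec_cot({"cot": True}) is True
assert old_vec_cot(None) is False
```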
7 changes: 2 additions & 5 deletions src/memos/memories/textual/tree.py
@@ -51,11 +51,6 @@ def __init__(self, config: TreeTextMemoryConfig):
self.bm25_retriever = (
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
)
self.vec_cot = (
self.search_strategy["cot"]
if self.search_strategy and "cot" in self.search_strategy
else False
)

if config.reranker is None:
default_cfg = RerankerConfigFactory.model_validate(
@@ -143,6 +138,7 @@ def get_searcher(
self.reranker,
internet_retriever=None,
moscube=moscube,
search_strategy=self.search_strategy,
)
else:
searcher = Searcher(
@@ -152,6 +148,7 @@ def get_searcher(
self.reranker,
internet_retriever=self.internet_retriever,
moscube=moscube,
search_strategy=self.search_strategy,
)
return searcher

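Centralizing that read in Searcher.__init__ means one defensive lookup instead of two copies. A hedged sketch of the pattern (the helper name is hypothetical, not in the PR):

```python
def flag(strategy: dict | None, key: str) -> bool:
    """Read one boolean strategy flag; a missing dict or key means False."""
    return bool(strategy.get(key, False)) if strategy else False

assert flag({"cot": True, "fast_graph": False}, "cot") is True
assert flag({}, "fast_graph") is False
assert flag(None, "bm25") is False
```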
135 changes: 96 additions & 39 deletions src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -40,6 +40,7 @@ def retrieve(
search_filter: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
use_fast_graph: bool = False,
) -> list[TextualMemoryItem]:
"""
Perform hybrid memory retrieval:
@@ -69,7 +70,13 @@

with ContextThreadPoolExecutor(max_workers=3) as executor:
# Structured graph-based retrieval
future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name)
future_graph = executor.submit(
self._graph_recall,
parsed_goal,
memory_scope,
user_name,
use_fast_graph=use_fast_graph,
)
# Vector similarity search
future_vector = executor.submit(
self._vector_recall,
@@ -155,14 +162,15 @@ def retrieve_from_cube(
return list(combined.values())

def _graph_recall(
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
) -> list[TextualMemoryItem]:
"""
Perform structured node-based retrieval from Neo4j.
- keys must match exactly (n.key IN keys)
- tags must overlap with at least 2 input tags
- scope filters by memory_type if provided
"""
use_fast_graph = kwargs.get("use_fast_graph", False)

def process_node(node):
meta = node.get("metadata", {})
@@ -184,47 +192,96 @@ def process_node(node):
return TextualMemoryItem.from_dict(node)
return None

candidate_ids = set()

# 1) key-based OR branch
if parsed_goal.keys:
key_filters = [
{"field": "key", "op": "in", "value": parsed_goal.keys},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
candidate_ids.update(key_ids)

# 2) tag-based OR branch
if parsed_goal.tags:
tag_filters = [
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
candidate_ids.update(tag_ids)

# No matches → return empty
if not candidate_ids:
return []
if not use_fast_graph:
candidate_ids = set()

# Load nodes and post-filter
node_dicts = self.graph_store.get_nodes(
list(candidate_ids), include_embedding=False, user_name=user_name
)
# 1) key-based OR branch
if parsed_goal.keys:
key_filters = [
{"field": "key", "op": "in", "value": parsed_goal.keys},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
key_ids = self.graph_store.get_by_metadata(key_filters)
candidate_ids.update(key_ids)

# 2) tag-based OR branch
if parsed_goal.tags:
tag_filters = [
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
tag_ids = self.graph_store.get_by_metadata(tag_filters)
candidate_ids.update(tag_ids)

final_nodes = []
with ContextThreadPoolExecutor(max_workers=3) as executor:
futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)}
temp_results = [None] * len(node_dicts)
# No matches → return empty
if not candidate_ids:
return []

# Load nodes and post-filter
node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)

final_nodes = []
for node in node_dicts:
meta = node.get("metadata", {})
node_key = meta.get("key")
node_tags = meta.get("tags", []) or []

keep = False
# keep if the node key matches one of the parsed keys
if parsed_goal.keys and node_key in parsed_goal.keys:
keep = True
# keep if at least two tags overlap
elif parsed_goal.tags:
overlap = len(set(node_tags) & set(parsed_goal.tags))
if overlap >= 2:
keep = True
if keep:
final_nodes.append(TextualMemoryItem.from_dict(node))
return final_nodes
else:
candidate_ids = set()

# 1) key-based OR branch
if parsed_goal.keys:
key_filters = [
{"field": "key", "op": "in", "value": parsed_goal.keys},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
candidate_ids.update(key_ids)

# 2) tag-based OR branch
if parsed_goal.tags:
tag_filters = [
{"field": "tags", "op": "contains", "value": parsed_goal.tags},
{"field": "memory_type", "op": "=", "value": memory_scope},
]
tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
candidate_ids.update(tag_ids)

# No matches → return empty
if not candidate_ids:
return []

# Load nodes and post-filter
node_dicts = self.graph_store.get_nodes(
list(candidate_ids), include_embedding=False, user_name=user_name
)

final_nodes = []
with ContextThreadPoolExecutor(max_workers=3) as executor:
futures = {
executor.submit(process_node, node): i for i, node in enumerate(node_dicts)
}
temp_results = [None] * len(node_dicts)

for future in concurrent.futures.as_completed(futures):
original_index = futures[future]
result = future.result()
temp_results[original_index] = result
for future in concurrent.futures.as_completed(futures):
original_index = futures[future]
result = future.result()
temp_results[original_index] = result

final_nodes = [result for result in temp_results if result is not None]
return final_nodes
final_nodes = [result for result in temp_results if result is not None]
return final_nodes

def _vector_recall(
self,
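Both _graph_recall branches apply the same keep rule when post-filtering candidates; they differ in whether user_name is passed to the metadata lookups and whether filtering runs through a thread pool. A condensed, hypothetical restatement of the rule described in the docstring:

```python
def keep_node(node: dict, keys: list[str], tags: list[str]) -> bool:
    """Keep a node if its key matches exactly, or >= 2 of its tags overlap."""
    meta = node.get("metadata", {})
    if keys and meta.get("key") in keys:
        return True
    node_tags = meta.get("tags") or []
    return bool(tags) and len(set(node_tags) & set(tags)) >= 2

assert keep_node({"metadata": {"key": "diet"}}, ["diet"], [])
assert not keep_node({"metadata": {"tags": ["food"]}}, [], ["food", "health"])
assert keep_node({"metadata": {"tags": ["food", "health"]}}, [], ["food", "health"])
```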
13 changes: 11 additions & 2 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -45,7 +45,7 @@ def __init__(
bm25_retriever: EnhancedBM25 | None = None,
internet_retriever: None = None,
moscube: bool = False,
vec_cot: bool = False,
search_strategy: dict | None = None,
):
self.graph_store = graph_store
self.embedder = embedder
@@ -59,7 +59,12 @@ def __init__(
# Create internet retriever from config if provided
self.internet_retriever = internet_retriever
self.moscube = moscube
self.vec_cot = vec_cot
# strategy values are booleans built in config.py ("cot" / "fast_graph" keys)
self.vec_cot = bool(search_strategy.get("cot", False)) if search_strategy else False
self.use_fast_graph = (
bool(search_strategy.get("fast_graph", False)) if search_strategy else False
)

self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")

@@ -226,6 +231,7 @@ def _parse_task(
context="\n".join(context),
conversation=info.get("chat_history", []),
mode=mode,
use_fast_graph=self.use_fast_graph,
)

query = parsed_goal.rephrased_query or query
@@ -340,6 +346,7 @@ def _retrieve_from_working_memory(
search_filter=search_filter,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
)
return self.reranker.rerank(
query=query,
@@ -390,6 +397,7 @@ def _retrieve_from_long_term_and_user(
search_filter=search_filter,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
)
)
if memory_type in ["All", "UserMemory"]:
@@ -404,6 +412,7 @@ def _retrieve_from_long_term_and_user(
search_filter=search_filter,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
)
)

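Searcher now derives both flags once at construction and forwards use_fast_graph with every retrieve call. A standalone sketch of that derivation, assuming the boolean values built in config.py (the function name is hypothetical):

```python
def derive_flags(search_strategy: dict | None) -> tuple[bool, bool]:
    """Mirror of the Searcher.__init__ logic: returns (vec_cot, use_fast_graph)."""
    if not search_strategy:
        return False, False
    return (
        bool(search_strategy.get("cot", False)),
        bool(search_strategy.get("fast_graph", False)),
    )

assert derive_flags({"fast_graph": True}) == (False, True)
assert derive_flags(None) == (False, False)
```

The final hunks below belong to the task-goal parser, whose fast path gains the same flag.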
@@ -29,6 +29,7 @@ def parse(
context: str = "",
conversation: list[dict] | None = None,
mode: str = "fast",
**kwargs,
) -> ParsedTaskGoal:
"""
Parse user input into structured semantic layers.
@@ -38,27 +39,38 @@
- mode == 'fine': use LLM to parse structured topic/keys/tags
"""
if mode == "fast":
return self._parse_fast(task_description)
return self._parse_fast(task_description, **kwargs)
elif mode == "fine":
if not self.llm:
raise ValueError("LLM not provided for fine mode.")
return self._parse_fine(task_description, context, conversation)
else:
raise ValueError(f"Unknown mode: {mode}")

def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
"""
Fast mode: jieba-based mixed tokenization when use_fast_graph is set; otherwise the raw query serves as the single key with no tags.
"""
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
return ParsedTaskGoal(
memories=[task_description],
keys=desc_tokenized,
tags=desc_tokenized,
goal_type="default",
rephrased_query=task_description,
internet_search=False,
)
use_fast_graph = kwargs.get("use_fast_graph", False)
if use_fast_graph:
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
return ParsedTaskGoal(
memories=[task_description],
keys=desc_tokenized,
tags=desc_tokenized,
goal_type="default",
rephrased_query=task_description,
internet_search=False,
)
else:
return ParsedTaskGoal(
memories=[task_description],
keys=[task_description],
tags=[],
goal_type="default",
rephrased_query=task_description,
internet_search=False,
)

def _parse_fine(
self, query: str, context: str = "", conversation: list[dict] | None = None
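The two fast-parse shapes differ only in keys and tags: with use_fast_graph the query is tokenized (tokenize_mixed handles mixed Chinese/English text), which feeds the >= 2 tag-overlap filter in _graph_recall; without it the whole query becomes the single key and tags stay empty, so recall falls back to exact key matching. A stand-in sketch with a whitespace tokenizer:

```python
def parse_fast(query: str, use_fast_graph: bool = False) -> dict:
    """Stand-in for _parse_fast; the real code uses self.tokenizer.tokenize_mixed."""
    if use_fast_graph:
        tokens = query.split()  # whitespace stand-in for the mixed tokenizer
        return {"memories": [query], "keys": tokens, "tags": tokens}
    return {"memories": [query], "keys": [query], "tags": []}

print(parse_fast("travel plans for spring", use_fast_graph=True)["tags"])
# ['travel', 'plans', 'for', 'spring']
print(parse_fast("travel plans for spring")["keys"])
# ['travel plans for spring']
```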