Skip to content

Commit b8cd27b

Browse files
authored
Revert "fix: fix search config input bug; patch retrieve_utils path set; adjust reader strategy template." (#453)
Revert "fix: fix search config input bug; patch retrieve_utils path set; adjust reader strategy template." This reverts commit 88699f9.
1 parent f3e7338 commit b8cd27b

File tree

9 files changed

+95
-70
lines changed

9 files changed

+95
-70
lines changed

src/memos/api/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def get_reader_config() -> dict[str, Any]:
427427
"config": {
428428
"chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"),
429429
"chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)),
430-
"chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 10)),
430+
"chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)),
431431
"chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)),
432432
},
433433
}

src/memos/configs/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
184184
),
185185
)
186186

187-
search_strategy: dict[str, Any] | None = Field(
187+
search_strategy: dict[str, bool] | None = Field(
188188
default=None,
189189
description=(
190190
'Set search strategy for this memory configuration.{"bm25": true, "cot": false}'

src/memos/mem_reader/strategy_struct.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _get_llm_response(self, mem_str: str) -> dict:
4343
template = STRATEGY_PROMPT_DICT["chat"][lang]
4444
examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"]
4545
prompt = template.replace("${conversation}", mem_str)
46-
if self.config.remove_prompt_example: # TODO unused
46+
if self.config.remove_prompt_example:
4747
prompt = prompt.replace(examples, "")
4848
messages = [{"role": "user", "content": prompt}]
4949
try:
@@ -112,19 +112,6 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:
112112

113113
results.append([overlap_item, item])
114114
current_length = overlap_length + content_length
115-
else:
116-
cut_size, cut_overlap = (
117-
self.chat_chunker["chunk_session"],
118-
self.chat_chunker["chunk_overlap"],
119-
)
120-
for items in scene_data:
121-
step = cut_size - cut_overlap
122-
end = len(items) - cut_overlap
123-
if end <= 0:
124-
results.extend([items[:]])
125-
else:
126-
results.extend([items[i : i + cut_size] for i in range(0, end, step)])
127-
128115
elif type == "doc":
129116
parser_config = ParserConfigFactory.model_validate(
130117
{

src/memos/memories/textual/simple_tree.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ def __init__(
6666
time_start_bm = time.time()
6767
self.search_strategy = config.search_strategy
6868
self.bm25_retriever = (
69-
EnhancedBM25()
70-
if self.search_strategy and self.search_strategy.get("bm25", False)
71-
else None
69+
EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
7270
)
7371
logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")
7472

src/memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@ class ParsedTaskGoal:
1313
rephrased_query: str | None = None
1414
internet_search: bool = False
1515
goal_type: str | None = None # e.g., 'default', 'explanation', etc.
16-
context: str = ""

src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def find_project_root(marker=".git"):
1717
if (current / marker).exists():
1818
return current
1919
current = current.parent
20-
return Path(".")
20+
logger.warn(f"The project root directory tag file was not found: {marker}")
2121

2222

2323
PROJECT_ROOT = find_project_root()

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030

3131
logger = get_logger(__name__)
3232
COT_DICT = {
33-
"fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH},
34-
"fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH},
33+
"fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH},
34+
"fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH},
3535
}
3636

3737

@@ -59,8 +59,12 @@ def __init__(
5959
# Create internet retriever from config if provided
6060
self.internet_retriever = internet_retriever
6161
self.moscube = moscube
62-
self.vec_cot = search_strategy.get("cot", False) if search_strategy else False
63-
self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False
62+
self.vec_cot = (
63+
search_strategy.get("vec_cot", "false") == "true" if search_strategy else False
64+
)
65+
self.use_fast_graph = (
66+
search_strategy.get("fast_graph", "false") == "true" if search_strategy else False
67+
)
6468

6569
self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
6670

@@ -283,7 +287,6 @@ def _retrieve_paths(
283287
search_filter,
284288
user_name,
285289
id_filter,
286-
mode=mode,
287290
)
288291
)
289292
tasks.append(
@@ -366,7 +369,6 @@ def _retrieve_from_long_term_and_user(
366369
search_filter: dict | None = None,
367370
user_name: str | None = None,
368371
id_filter: dict | None = None,
369-
mode: str = "fast",
370372
):
371373
"""Retrieve and rerank from LongTermMemory and UserMemory"""
372374
results = []
@@ -375,7 +377,7 @@ def _retrieve_from_long_term_and_user(
375377
# chain of thinking
376378
cot_embeddings = []
377379
if self.vec_cot:
378-
queries = self._cot_query(query, mode=mode, context=parsed_goal.context)
380+
queries = self._cot_query(query)
379381
if len(queries) > 1:
380382
cot_embeddings = self.embedder.embed(queries)
381383
cot_embeddings.extend(query_embedding)
@@ -564,6 +566,7 @@ def _cot_query(
564566
prompt = template.replace("${original_query}", query).replace(
565567
"${split_num_threshold}", str(split_num)
566568
)
569+
logger.info("COT process")
567570

568571
messages = [{"role": "user", "content": prompt}]
569572
try:

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def parse(
3939
- mode == 'fine': use LLM to parse structured topic/keys/tags
4040
"""
4141
if mode == "fast":
42-
return self._parse_fast(task_description, context=context, **kwargs)
42+
return self._parse_fast(task_description, **kwargs)
4343
elif mode == "fine":
4444
if not self.llm:
4545
raise ValueError("LLM not provided for slow mode.")
@@ -51,7 +51,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
5151
"""
5252
Fast mode: simple jieba word split.
5353
"""
54-
context = kwargs.get("context", "")
5554
use_fast_graph = kwargs.get("use_fast_graph", False)
5655
if use_fast_graph:
5756
desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
@@ -62,7 +61,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
6261
goal_type="default",
6362
rephrased_query=task_description,
6463
internet_search=False,
65-
context=context,
6664
)
6765
else:
6866
return ParsedTaskGoal(
@@ -72,7 +70,6 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
7270
goal_type="default",
7371
rephrased_query=task_description,
7472
internet_search=False,
75-
context=context,
7673
)
7774

7875
def _parse_fine(
@@ -94,17 +91,16 @@ def _parse_fine(
9491
logger.info(f"Parsing Goal... LLM input is {prompt}")
9592
response = self.llm.generate(messages=[{"role": "user", "content": prompt}])
9693
logger.info(f"Parsing Goal... LLM Response is {response}")
97-
return self._parse_response(response, context=context)
94+
return self._parse_response(response)
9895
except Exception:
9996
logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}")
100-
return self._parse_fast(query, context=context)
97+
return self._parse_fast(query)
10198

102-
def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
99+
def _parse_response(self, response: str) -> ParsedTaskGoal:
103100
"""
104101
Parse LLM JSON output safely.
105102
"""
106103
try:
107-
context = kwargs.get("context", "")
108104
response = response.replace("```", "").replace("json", "").strip()
109105
response_json = eval(response)
110106
return ParsedTaskGoal(
@@ -114,7 +110,6 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
114110
rephrased_query=response_json.get("rephrased_instruction", None),
115111
internet_search=response_json.get("internet_search", False),
116112
goal_type=response_json.get("goal_type", "default"),
117-
context=context,
118113
)
119114
except Exception as e:
120115
raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e

0 commit comments

Comments (0)