From 8a3b9aa6d247fc57dbfd05e8c51f2c49552af01b Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:35:04 +0800 Subject: [PATCH] Revert "Feat: add recall strategy (#414)" This reverts commit a375911827b4c6fe3fd82758a23a6b6cb0c9adec. --- poetry.lock | 50 +-- pyproject.toml | 4 +- src/memos/api/config.py | 34 +- src/memos/configs/mem_reader.py | 9 - src/memos/configs/memory.py | 7 - src/memos/llms/openai.py | 13 +- src/memos/mem_reader/factory.py | 2 - src/memos/mem_reader/strategy_struct.py | 138 ------- src/memos/memories/textual/simple_tree.py | 18 - src/memos/memories/textual/tree.py | 16 - .../tree_text_memory/retrieve/bm25_util.py | 186 --------- .../tree_text_memory/retrieve/recall.py | 107 +---- .../retrieve/retrieve_utils.py | 378 ------------------ .../tree_text_memory/retrieve/searcher.py | 88 +--- .../retrieve/task_goal_parser.py | 7 +- .../templates/mem_reader_strategy_prompts.py | 279 ------------- src/memos/templates/mem_search_prompts.py | 93 ----- .../textual/test_tree_task_goal_parser.py | 5 + 18 files changed, 41 insertions(+), 1393 deletions(-) delete mode 100644 src/memos/mem_reader/strategy_struct.py delete mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py delete mode 100644 src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py delete mode 100644 src/memos/templates/mem_reader_strategy_prompts.py delete mode 100644 src/memos/templates/mem_search_prompts.py diff --git a/poetry.lock b/poetry.lock index 926d580fb..44265bca8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -192,19 +192,6 @@ torch = ">=1.0.0" tqdm = ">=4.31.1" transformers = ">=3.0.0" -[[package]] -name = "cachetools" -version = "6.2.1" -description = "Extensible memoizing collections and decorators" -optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701"}, - {file = "cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201"}, -] - [[package]] name = "certifi" version = "2025.7.14" @@ -1566,18 +1553,6 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] -[[package]] -name = "jieba" -version = "0.42" -description = "Chinese Words Segmentation Utilities" -optional = true -python-versions = "*" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "jieba-0.42.tar.gz", hash = "sha256:34a3c960cc2943d9da16d6d2565110cf5f305921a67413dddf04f84de69c939b"}, -] - [[package]] name = "jinja2" version = "3.1.6" @@ -4148,25 +4123,6 @@ urllib3 = ">=1.26.14,<3" fastembed = ["fastembed (>=0.7,<0.8)"] fastembed-gpu = ["fastembed-gpu (>=0.7,<0.8)"] -[[package]] -name = "rank-bm25" -version = "0.2.2" -description = "Various BM25 algorithms for document ranking" -optional = true -python-versions = "*" -groups = ["main"] -markers = "extra == \"all\"" -files = [ - {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, - {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, -] - -[package.dependencies] -numpy = "*" - -[package.extras] -dev = ["pytest"] - [[package]] name = "redis" version = "6.2.0" @@ -6396,7 +6352,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["cachetools", "chonkie", "datasketch", "jieba", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] @@ -6406,4 +6362,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "ec17679a44205ada4494fbc485ac592883281fde273d5e73d6b8cbc6f7f9ed10" +content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" diff --git a/pyproject.toml b/pyproject.toml index 2f88797a8..3745582f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,9 +107,7 @@ all = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", "pymilvus (>=2.6.1,<3.0.0)", "datasketch (>=1.6.5,<2.0.0)", - "jieba (>=0.38.1,<0.42.1)", - "rank-bm25 (>=0.2.2)", - "cachetools (>=6.0.0)", + # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. # We kindof don't want users to install them. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 405e8068d..7ac882d6c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -419,23 +419,9 @@ def get_embedder_config() -> dict[str, Any]: }, } - @staticmethod - def get_reader_config() -> dict[str, Any]: - """Get reader configuration.""" - return { - "backend": os.getenv("MEM_READER_BACKEND", "simple_struct"), - "config": { - "chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"), - "chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)), - "chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)), - "chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)), - }, - } - @staticmethod def get_internet_config() -> dict[str, Any]: """Get embedder configuration.""" - reader_config = APIConfig.get_reader_config() return { "backend": "bocha", "config": { @@ -443,7 +429,7 @@ def get_internet_config() -> dict[str, Any]: "max_results": 15, "num_per_request": 10, "reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": { "backend": "openai", @@ -469,7 +455,6 @@ def get_internet_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, }, @@ -671,8 +656,6 @@ def get_product_default_config() -> dict[str, Any]: openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() - reader_config = APIConfig.get_reader_config() - backend_model = { "openai": openai_config, "huggingface": qwen_config, @@ -684,7 +667,7 @@ def get_product_default_config() -> dict[str, Any]: "user_id": os.getenv("MOS_USER_ID", "root"), "chat_model": {"backend": backend, "config": backend_model[backend]}, "mem_reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -697,7 +680,6 @@ def get_product_default_config() -> dict[str, Any]: "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -768,7 +750,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() mysql_config = APIConfig.get_mysql_config() - reader_config = APIConfig.get_reader_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") backend_model = { "openai": openai_config, @@ -783,7 +764,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "config": backend_model[backend], }, "mem_reader": { - "backend": reader_config["backend"], + "backend": "simple_struct", "config": { "llm": APIConfig.get_memreader_config(), "embedder": APIConfig.get_embedder_config(), @@ -796,7 +777,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "min_sentences_per_chunk": 1, }, }, - "chat_chunker": reader_config, }, }, "enable_textual_memory": True, @@ -865,10 +845,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, - "search_strategy": { - "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), - "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), - }, }, }, "act_mem": {} @@ -936,10 +912,6 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, - "search_strategy": { - "bm25": bool(os.getenv("BM25_CALL", "false") == "true"), - "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"), - }, "mode": os.getenv("ASYNC_MODE", "sync"), }, }, diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index dc8d37a35..1c62087a3 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -36,19 +36,11 @@ def parse_datetime(cls, value): description="whether remove example in memory extraction prompt to save token", ) - chat_chunker: dict[str, Any] = Field( - default=None, description="Configuration for the MemReader chat chunk strategy" - ) - class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" -class StrategyStructMemReaderConfig(BaseMemReaderConfig): - """StrategyStruct MemReader configuration class.""" - - class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" @@ -57,7 +49,6 @@ class MemReaderConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, - "strategy_struct": StrategyStructMemReaderConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 49320fbf5..bf2493567 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -184,13 +184,6 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ), ) - search_strategy: dict[str, bool] | None = Field( - default=None, - description=( - 'Set search strategy for this memory configuration.{"bm25": true, "cot": false}' - ), - ) - mode: str | None = Field( default="sync", description=("whether use asynchronous mode in memory add"), diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1a1703340..ca1df5c1f 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -58,18 +58,15 @@ def clear_cache(cls): logger.info("OpenAI LLM instance cache cleared") @timed(log=True, log_prefix="OpenAI LLM") - def generate(self, messages: MessageList, **kwargs) -> str: - """Generate a response from OpenAI LLM, optionally overriding generation params.""" - temperature = kwargs.get("temperature", self.config.temperature) - max_tokens = kwargs.get("max_tokens", self.config.max_tokens) - top_p = kwargs.get("top_p", self.config.top_p) + def generate(self, messages: MessageList) -> str: + """Generate a response from OpenAI LLM.""" response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, extra_body=self.config.extra_body, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") response_content = response.choices[0].message.content diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 2205a0215..52eed8d9d 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -3,7 +3,6 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader -from memos.mem_reader.strategy_struct import StrategyStructMemReader from memos.memos_tools.singleton import singleton_factory @@ -12,7 +11,6 @@ class MemReaderFactory(BaseMemReader): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReader, - "strategy_struct": StrategyStructMemReader, } @classmethod diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py deleted file mode 100644 index 2cac1652a..000000000 --- a/src/memos/mem_reader/strategy_struct.py +++ /dev/null @@ -1,138 +0,0 @@ -import os - -from abc import ABC - -from memos import log -from memos.configs.mem_reader import StrategyStructMemReaderConfig -from memos.configs.parser import ParserConfigFactory -from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang -from memos.parsers.factory import ParserFactory -from memos.templates.mem_reader_prompts import ( - SIMPLE_STRUCT_DOC_READER_PROMPT, - SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, - SIMPLE_STRUCT_MEM_READER_EXAMPLE, - SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, -) -from memos.templates.mem_reader_strategy_prompts import ( - STRATEGY_STRUCT_MEM_READER_PROMPT, - STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, -) - - -logger = log.get_logger(__name__) -STRATEGY_PROMPT_DICT = { - "chat": { - "en": STRATEGY_STRUCT_MEM_READER_PROMPT, - "zh": STRATEGY_STRUCT_MEM_READER_PROMPT_ZH, - "en_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE, - "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, - }, - "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, -} - - -class StrategyStructMemReader(SimpleStructMemReader, ABC): - """Naive implementation of MemReader.""" - - def __init__(self, config: StrategyStructMemReaderConfig): - super().__init__(config) - self.chat_chunker = config.chat_chunker["config"] - - def _get_llm_response(self, mem_str: str) -> dict: - lang = detect_lang(mem_str) - template = STRATEGY_PROMPT_DICT["chat"][lang] - examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] - prompt = template.replace("${conversation}", mem_str) - if self.config.remove_prompt_example: - prompt = prompt.replace(examples, "") - messages = [{"role": "user", "content": prompt}] - try: - response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) - except Exception as e: - logger.error(f"[LLM] Exception during chat generation: {e}") - response_json = { - "memory list": [ - { - "key": mem_str[:10], - "memory_type": "UserMemory", - "value": mem_str, - "tags": [], - } - ], - "summary": mem_str, - } - return response_json - - def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: - """ - Get raw information from scene_data. - If scene_data contains dictionaries, convert them to strings. - If scene_data contains file paths, parse them using the parser. - - Args: - scene_data: List of dialogue information or document paths - type: Type of scene data: ['doc', 'chat'] - Returns: - List of strings containing the processed scene data - """ - results = [] - - if type == "chat": - if self.chat_chunker["chunk_type"] == "content_length": - content_len_thredshold = self.chat_chunker["chunk_length"] - for items in scene_data: - if not items: - continue - - results.append([]) - current_length = 0 - - for _i, item in enumerate(items): - content_length = ( - len(item.get("content", "")) - if isinstance(item, dict) - else len(str(item)) - ) - if not results[-1]: - results[-1].append(item) - current_length = content_length - continue - - if current_length + content_length <= content_len_thredshold: - results[-1].append(item) - current_length += content_length - else: - overlap_item = results[-1][-1] - overlap_length = ( - len(overlap_item.get("content", "")) - if isinstance(overlap_item, dict) - else len(str(overlap_item)) - ) - - results.append([overlap_item, item]) - current_length = overlap_length + content_length - elif type == "doc": - parser_config = ParserConfigFactory.model_validate( - { - "backend": "markitdown", - "config": {}, - } - ) - parser = ParserFactory.from_config(parser_config) - for item in scene_data: - try: - if os.path.exists(item): - try: - parsed_text = parser.parse(item) - results.append({"file": item, "text": parsed_text}) - except Exception as e: - logger.error(f"[SceneParser] Error parsing {item}: {e}") - continue - else: - parsed_text = item - results.append({"file": "pure_text", "text": parsed_text}) - except Exception as e: - print(f"Error parsing file {item}: {e!s}") - - return results diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 6974dbe8f..8ce81a8bd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -12,7 +12,6 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker from memos.types import MessageList @@ -63,19 +62,6 @@ def __init__( self.graph_store: Neo4jGraphDB = graph_db logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") - time_start_bm = time.time() - self.search_strategy = config.search_strategy - self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None - ) - logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") - - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) - time_start_rr = time.time() self.reranker = reranker logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") @@ -186,10 +172,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -197,10 +181,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index a58f993bb..56c8117e9 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -16,7 +16,6 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -46,17 +45,6 @@ def __init__(self, config: TreeTextMemoryConfig): ) self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder) self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db) - - self.search_strategy = config.search_strategy - self.bm25_retriever = ( - EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None - ) - self.vec_cot = ( - self.search_strategy["cot"] - if self.search_strategy and "cot" in self.search_strategy - else False - ) - if config.reranker is None: default_cfg = RerankerConfigFactory.model_validate( { @@ -197,10 +185,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=None, moscube=moscube, - vec_cot=self.vec_cot, ) else: searcher = Searcher( @@ -208,10 +194,8 @@ def search( self.graph_store, self.embedder, self.reranker, - bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, moscube=moscube, - vec_cot=self.vec_cot, ) return searcher.search(query, top_k, info, mode, memory_type, search_filter) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py b/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py deleted file mode 100644 index 4aca4022f..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +++ /dev/null @@ -1,186 +0,0 @@ -import threading - -import numpy as np - -from sklearn.feature_extraction.text import TfidfVectorizer - -from memos.dependency import require_python_package -from memos.log import get_logger -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer -from memos.utils import timed - - -logger = get_logger(__name__) -# Global model cache -_CACHE_LOCK = threading.Lock() - - -class EnhancedBM25: - """Enhanced BM25 with Spacy tokenization and TF-IDF reranking""" - - @require_python_package(import_name="cachetools", install_command="pip install cachetools") - def __init__(self, tokenizer=None, en_model="en_core_web_sm", zh_model="zh_core_web_sm"): - """ - Initialize Enhanced BM25 with memory management - """ - if tokenizer is None: - self.tokenizer = FastTokenizer() - else: - self.tokenizer = tokenizer - self._current_tfidf = None - - global _BM25_CACHE - from cachetools import LRUCache - - _BM25_CACHE = LRUCache(maxsize=100) - - def _tokenize_doc(self, text): - """ - Tokenize a single document using SpacyTokenizer - """ - return self.tokenizer.tokenize_mixed(text, lang="auto") - - @require_python_package(import_name="rank_bm25", install_command="pip install rank_bm25") - def _prepare_corpus_data(self, corpus, corpus_name="default"): - from rank_bm25 import BM25Okapi - - with _CACHE_LOCK: - if corpus_name in _BM25_CACHE: - print("hit::", corpus_name) - return _BM25_CACHE[corpus_name] - print("not hit::", corpus_name) - - tokenized_corpus = [self._tokenize_doc(doc) for doc in corpus] - bm25_model = BM25Okapi(tokenized_corpus) - _BM25_CACHE[corpus_name] = bm25_model - return bm25_model - - def clear_cache(self, corpus_name=None): - """Clear cache for specific corpus or clear all cache""" - with _CACHE_LOCK: - if corpus_name: - if corpus_name in _BM25_CACHE: - del _BM25_CACHE[corpus_name] - else: - _BM25_CACHE.clear() - - def get_cache_info(self): - """Get current cache information""" - with _CACHE_LOCK: - return { - "cache_size": len(_BM25_CACHE), - "max_cache_size": 100, - "cached_corpora": list(_BM25_CACHE.keys()), - } - - def _search_docs( - self, - query: str, - corpus: list[str], - corpus_name="test", - top_k=50, - use_tfidf=False, - rerank_candidates_multiplier=2, - cleanup=False, - ): - """ - Args: - query: Search query string - corpus: List of document texts - top_k: Number of top results to return - rerank_candidates_multiplier: Multiplier for candidate selection - cleanup: Whether to cleanup memory after search (default: True) - """ - if not corpus: - return [] - - logger.info(f"Searching {len(corpus)} documents for query: '{query}'") - - try: - # Prepare BM25 model - bm25_model = self._prepare_corpus_data(corpus, corpus_name=corpus_name) - tokenized_query = self._tokenize_doc(query) - tokenized_query = list(dict.fromkeys(tokenized_query)) - - # Get BM25 scores - bm25_scores = bm25_model.get_scores(tokenized_query) - - # Select candidates - candidate_count = min(top_k * rerank_candidates_multiplier, len(corpus)) - candidate_indices = np.argsort(bm25_scores)[-candidate_count:][::-1] - combined_scores = bm25_scores[candidate_indices] - - if use_tfidf: - # Create TF-IDF for this search - tfidf = TfidfVectorizer( - tokenizer=self._tokenize_doc, lowercase=False, token_pattern=None - ) - tfidf_matrix = tfidf.fit_transform(corpus) - - # TF-IDF reranking - query_vec = tfidf.transform([query]) - tfidf_similarities = ( - (tfidf_matrix[candidate_indices] * query_vec.T).toarray().flatten() - ) - - # Combine scores - combined_scores = 0.7 * bm25_scores[candidate_indices] + 0.3 * tfidf_similarities - - sorted_candidate_indices = candidate_indices[np.argsort(combined_scores)[::-1][:top_k]] - sorted_combined_scores = np.sort(combined_scores)[::-1][:top_k] - - # build result list - bm25_recalled_results = [] - for rank, (doc_idx, combined_score) in enumerate( - zip(sorted_candidate_indices, sorted_combined_scores, strict=False), 1 - ): - bm25_score = bm25_scores[doc_idx] - - candidate_pos = np.where(candidate_indices == doc_idx)[0][0] - tfidf_score = tfidf_similarities[candidate_pos] if use_tfidf else 0 - - bm25_recalled_results.append( - { - "text": corpus[doc_idx], - "bm25_score": float(bm25_score), - "tfidf_score": float(tfidf_score), - "combined_score": float(combined_score), - "rank": rank, - "doc_index": int(doc_idx), - } - ) - - logger.debug(f"Search completed: found {len(bm25_recalled_results)} results") - return bm25_recalled_results - - except Exception as e: - logger.error(f"BM25 search failed: {e}") - return [] - finally: - # Always cleanup if requested - if cleanup: - self._cleanup_memory() - - @timed - def search(self, query: str, node_dicts: list[dict], corpus_name="default", **kwargs): - """ - Search with BM25 and optional TF-IDF reranking - """ - try: - corpus_list = [] - for node_dict in node_dicts: - corpus_list.append( - " ".join([node_dict["metadata"]["key"]] + node_dict["metadata"]["tags"]) - ) - - recalled_results = self._search_docs( - query, corpus_list, corpus_name=corpus_name, **kwargs - ) - bm25_searched_nodes = [] - for item in recalled_results: - doc_idx = item["doc_index"] - bm25_searched_nodes.append(node_dicts[doc_idx]) - return bm25_searched_nodes - except Exception as e: - logger.error(f"Error in bm25 search: {e}") - return [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index b7383aa13..c1ade3021 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -5,7 +5,6 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal @@ -17,18 +16,11 @@ class GraphMemoryRetriever: Unified memory retriever that combines both graph-based and vector-based retrieval logic. """ - def __init__( - self, - graph_store: Neo4jGraphDB, - embedder: OllamaEmbedder, - bm25_retriever: EnhancedBM25 | None = None, - ): + def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder): self.graph_store = graph_store self.embedder = embedder - self.bm25_retriever = bm25_retriever self.max_workers = 10 self.filter_weight = 0.6 - self.use_bm25 = bool(self.bm25_retriever) def retrieve( self, @@ -39,7 +31,6 @@ def retrieve( query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -67,7 +58,7 @@ def retrieve( ) return [TextualMemoryItem.from_dict(record) for record in working_memories] - with ContextThreadPoolExecutor(max_workers=3) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search @@ -79,23 +70,12 @@ def retrieve( search_filter=search_filter, user_name=user_name, ) - if self.use_bm25: - future_bm25 = executor.submit( - self._bm25_recall, - query, - parsed_goal, - memory_scope, - top_k=top_k, - user_name=user_name, - search_filter=id_filter, - ) graph_results = future_graph.result() vector_results = future_vector.result() - bm25_results = future_bm25.result() if self.use_bm25 else [] # Merge and deduplicate by ID - combined = {item.id: item for item in graph_results + vector_results + bm25_results} + combined = {item.id: item for item in graph_results + vector_results} graph_ids = {item.id for item in graph_results} combined_ids = set(combined.keys()) @@ -163,27 +143,6 @@ def _graph_recall( - tags must overlap with at least 2 input tags - scope filters by memory_type if provided """ - - def process_node(node): - meta = node.get("metadata", {}) - node_key = meta.get("key") - node_tags = meta.get("tags", []) or [] - - keep = False - # key equals to node_key - if parsed_goal.keys and node_key in parsed_goal.keys: - keep = True - # overlap tags more than 2 - elif parsed_goal.tags: - node_tags_list = [tag.lower() for tag in node_tags] - overlap = len(set(node_tags_list) & set(parsed_goal.tags)) - if overlap >= 2: - keep = True - - if keep: - return TextualMemoryItem.from_dict(node) - return None - candidate_ids = set() # 1) key-based OR branch @@ -214,16 +173,22 @@ def process_node(node): ) final_nodes = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: - futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)} - temp_results = [None] * len(node_dicts) - - for future in concurrent.futures.as_completed(futures): - original_index = futures[future] - result = future.result() - temp_results[original_index] = result + for node in node_dicts: + meta = node.get("metadata", {}) + node_key = meta.get("key") + node_tags = meta.get("tags", []) or [] - final_nodes = [result for result in temp_results if result is not None] + keep = False + # key equals to node_key + if parsed_goal.keys and node_key in parsed_goal.keys: + keep = True + # overlap tags more than 2 + elif parsed_goal.tags: + overlap = len(set(node_tags) & set(parsed_goal.tags)) + if overlap >= 2: + keep = True + if keep: + final_nodes.append(TextualMemoryItem.from_dict(node)) return final_nodes def _vector_recall( @@ -231,7 +196,7 @@ def _vector_recall( query_embedding: list[list[float]], memory_scope: str, top_k: int = 20, - max_num: int = 5, + max_num: int = 3, status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, @@ -304,37 +269,3 @@ def search_path_b(): or [] ) return [TextualMemoryItem.from_dict(n) for n in node_dicts] - - def _bm25_recall( - self, - query: str, - parsed_goal: ParsedTaskGoal, - memory_scope: str, - top_k: int = 20, - user_name: str | None = None, - search_filter: dict | None = None, - ) -> list[TextualMemoryItem]: - """ - Perform BM25-based retrieval. - """ - if not self.bm25_retriever: - return [] - key_filters = [ - {"field": "memory_type", "op": "=", "value": memory_scope}, - ] - # corpus_name is user_name + user_id - corpus_name = f"{user_name}" if user_name else "" - if search_filter is not None: - for key in search_filter: - value = search_filter[key] - key_filters.append({"field": key, "op": "=", "value": value}) - corpus_name += "".join(list(search_filter.values())) - candidate_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) - - bm25_query = " ".join(list({query, *parsed_goal.keys})) - bm25_results = self.bm25_retriever.search( - bm25_query, node_dicts, top_k=top_k, corpus_name=corpus_name - ) - - return [TextualMemoryItem.from_dict(n) for n in bm25_results] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py deleted file mode 100644 index eec827c86..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ /dev/null @@ -1,378 +0,0 @@ -import json -import re - -from pathlib import Path - -from memos.dependency import require_python_package -from memos.log import get_logger - - -logger = get_logger(__name__) - - -def find_project_root(marker=".git"): - """Find the project root directory by marking the file""" - current = Path(__file__).resolve() - while current != current.parent: - if (current / marker).exists(): - return current - current = current.parent - logger.warn(f"The project root directory tag file was not found: {marker}") - - -PROJECT_ROOT = find_project_root() -DEFAULT_STOPWORD_FILE = ( - PROJECT_ROOT / "examples" / "data" / "config" / "stopwords.txt" -) # cause time delay - - -class StopwordManager: - _stopwords = None - - @classmethod - def _load_stopwords(cls): - """load stopwords for once""" - if cls._stopwords is not None: - return cls._stopwords - - stopwords = set() - try: - with open(DEFAULT_STOPWORD_FILE, encoding="utf-8") as f: - stopwords = {line.strip() for line in f if line.strip()} - logger.info("Stopwords loaded successfully.") - except Exception as e: - logger.warning(f"Error loading stopwords: {e}, using default stopwords.") - stopwords = cls._load_default_stopwords() - - cls._stopwords = stopwords - return stopwords - - @classmethod - def _load_default_stopwords(cls): - """load stop words""" - chinese_stop_words = { - "的", - "了", - "在", - "是", - "我", - "有", - "和", - "就", - "不", - "人", - "都", - "一", - "一个", - "上", - "也", - "很", - "到", - "说", - "要", - "去", - "你", - "会", - "着", - "没有", - "看", - "好", - "自己", - "这", - "那", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "这个", - "那个", - "这些", - "那些", - "怎么", - "什么", - "为什么", - "如何", - "哪里", - "谁", - "几", - "多少", - "这样", - "那样", - "这么", - "那么", - } - english_stop_words = { - "the", - "a", - "an", - "and", - "or", - "but", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "as", - "is", - "are", - "was", - "were", - "be", - "been", - "have", - "has", - "had", - "do", - "does", - "did", - "will", - "would", - "could", - "should", - "may", - "might", - "must", - "this", - "that", - "these", - "those", - "i", - "you", - "he", - "she", - "it", - "we", - "they", - "me", - "him", - "her", - "us", - "them", - "my", - "your", - "his", - "its", - "our", - "their", - "mine", - "yours", - "hers", - "ours", - "theirs", - } - chinese_punctuation = { - ",", - "。", - "!", - "?", - ";", - ":", - "「", - "」", - "『", - "』", - "【", - "】", - "(", - ")", - "《", - "》", - "—", - "…", - "~", - "·", - "、", - "“", - "”", - "‘", - "’", - "〈", - "〉", - "〖", - "〗", - "〝", - "〞", - "{", - "}", - "〔", - "〕", - "¡", - "¿", - } - english_punctuation = { - ",", - ".", - "!", - "?", - ";", - ":", - '"', - "'", - "(", - ")", - "[", - "]", - "{", - "}", - "<", - ">", - "/", - "\\", - "|", - "-", - "_", - "=", - "+", - "@", - "#", - "$", - "%", - "^", - "&", - "*", - "~", - "`", - "¡", - "¿", - } - numbers = { - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "零", - "一", - "二", - "三", - "四", - "五", - "六", - "七", - "八", - "九", - "十", - "百", - "千", - "万", - "亿", - } - whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} - - return ( - chinese_stop_words - | english_stop_words - | chinese_punctuation - | english_punctuation - | numbers - | whitespace - ) - - @classmethod - def get_stopwords(cls): - if cls._stopwords is None: - cls._load_stopwords() - return cls._stopwords - - @classmethod - def filter_words(cls, words): - if cls._stopwords is None: - cls._load_stopwords() - return [word for word in words if word not in cls._stopwords and word.strip()] - - @classmethod - def is_stopword(cls, word): - if cls._stopwords is None: - cls._load_stopwords() - return word in cls._stopwords - - @classmethod - def reload_stopwords(cls, file_path=None): - cls._stopwords = None - if file_path: - global DEFAULT_STOPWORD_FILE - DEFAULT_STOPWORD_FILE = file_path - cls._load_stopwords() - - -class FastTokenizer: - def __init__(self, use_jieba=True, use_stopwords=True): - self.use_jieba = use_jieba - self.use_stopwords = use_stopwords - if self.use_stopwords: - self.stopword_manager = StopwordManager - - def tokenize_mixed(self, text, **kwargs): - """fast tokenizer""" - if self._is_chinese(text): - return self._tokenize_chinese(text) - else: - return self._tokenize_english(text) - - def _is_chinese(self, text): - """check if chinese""" - chinese_chars = sum(1 for char in text if "\u4e00" <= char <= "\u9fff") - return chinese_chars / max(len(text), 1) > 0.3 - - @require_python_package( - import_name="jieba", - install_command="pip install jieba", - install_link="https://github.com/fxsjy/jieba", - ) - def _tokenize_chinese(self, text): - """split zh jieba""" - import jieba - - tokens = jieba.lcut(text) if self.use_jieba else list(text) - tokens = [token.strip() for token in tokens if token.strip()] - if self.use_stopwords: - return self.stopword_manager.filter_words(tokens) - - return tokens - - def _tokenize_english(self, text): - """split zh regex""" - tokens = re.findall(r"\b[a-zA-Z0-9]+\b", text.lower()) - if self.use_stopwords: - return self.stopword_manager.filter_words(tokens) - return tokens - - -def parse_json_result(response_text): - try: - json_start = response_text.find("{") - response_text = response_text[json_start:] - response_text = response_text.replace("```", "").strip() - if not response_text.endswith("}"): - response_text += "}" - return json.loads(response_text) - except json.JSONDecodeError as e: - logger.error(f"[JSONParse] Failed to decode JSON: {e}\nRaw:\n{response_text}") - return {} - except Exception as e: - logger.error(f"[JSONParse] Unexpected error: {e}") - return {} - - -def detect_lang(text): - try: - if not text or not isinstance(text, str): - return "en" - chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" - chinese_chars = re.findall(chinese_pattern, text) - if len(chinese_chars) / len(re.sub(r"[\s\d\W]", "", text)) > 0.3: - return "zh" - return "en" - except Exception: - return "en" diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 563695c68..9d540b311 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -9,18 +9,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( - detect_lang, - parse_json_result, -) from memos.reranker.base import BaseReranker -from memos.templates.mem_search_prompts import ( - COT_PROMPT, - COT_PROMPT_ZH, - SIMPLE_COT_PROMPT, - SIMPLE_COT_PROMPT_ZH, -) from memos.utils import timed from .reasoner import MemoryReasoner @@ -29,10 +18,6 @@ logger = get_logger(__name__) -COT_DICT = { - "fast": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, - "fine": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, -} class Searcher: @@ -42,24 +27,20 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, - bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, - vec_cot: bool = False, ): self.graph_store = graph_store self.embedder = embedder - self.llm = dispatcher_llm self.task_goal_parser = TaskGoalParser(dispatcher_llm) - self.graph_retriever = GraphMemoryRetriever(graph_store, embedder, bm25_retriever) + self.graph_retriever = GraphMemoryRetriever(self.graph_store, self.embedder) self.reranker = reranker self.reasoner = MemoryReasoner(dispatcher_llm) # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = vec_cot self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -250,12 +231,6 @@ def _retrieve_paths( ): """Run A/B/C retrieval paths in parallel""" tasks = [] - id_filter = { - "user_id": info.get("user_id", None), - "session_id": info.get("session_id", None), - } - id_filter = {k: v for k, v in id_filter.items() if v is not None} - with ContextThreadPoolExecutor(max_workers=3) as executor: tasks.append( executor.submit( @@ -267,7 +242,6 @@ def _retrieve_paths( memory_type, search_filter, user_name, - id_filter, ) ) tasks.append( @@ -280,7 +254,6 @@ def _retrieve_paths( memory_type, search_filter, user_name, - id_filter, ) ) tasks.append( @@ -326,7 +299,6 @@ def _retrieve_from_working_memory( memory_type, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -339,7 +311,6 @@ def _retrieve_from_working_memory( memory_scope="WorkingMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) return self.reranker.rerank( query=query, @@ -361,22 +332,11 @@ def _retrieve_from_long_term_and_user( memory_type, search_filter: dict | None = None, user_name: str | None = None, - id_filter: dict | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] tasks = [] - # chain of thinking - cot_embeddings = [] - if self.vec_cot: - queries = self._cot_query(query) - if len(queries) > 1: - cot_embeddings = self.embedder.embed(queries) - cot_embeddings.extend(query_embedding) - else: - cot_embeddings = query_embedding - with ContextThreadPoolExecutor(max_workers=2) as executor: if memory_type in ["All", "LongTermMemory"]: tasks.append( @@ -384,12 +344,11 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=cot_embeddings, + query_embedding=query_embedding, top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) ) if memory_type in ["All", "UserMemory"]: @@ -398,12 +357,11 @@ def _retrieve_from_long_term_and_user( self.graph_retriever.retrieve, query=query, parsed_goal=parsed_goal, - query_embedding=cot_embeddings, + query_embedding=query_embedding, top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, user_name=user_name, - id_filter=id_filter, ) ) @@ -484,7 +442,6 @@ def _deduplicate_results(self, results): @timed def _sort_and_trim(self, results, top_k): """Sort results by score and trim to top_k""" - sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: @@ -534,42 +491,3 @@ def _update_usage_history_worker( self.graph_store.update_node(item_id, {"usage": usage_list}, user_name=user_name) except Exception: logger.exception("[USAGE] update usage failed") - - def _cot_query( - self, - query, - mode="fast", - split_num: int = 3, - context: list[str] | None = None, - ) -> list[str]: - """Generate chain-of-thought queries""" - - lang = detect_lang(query) - if mode == "fine" and context: - template = COT_DICT["fine"][lang] - prompt = ( - template.replace("${original_query}", query) - .replace("${split_num_threshold}", str(split_num)) - .replace("${context}", "\n".join(context)) - ) - else: - template = COT_DICT["fast"][lang] - prompt = template.replace("${original_query}", query).replace( - "${split_num_threshold}", str(split_num) - ) - logger.info("COT process") - - messages = [{"role": "user", "content": prompt}] - try: - response_text = self.llm.generate(messages, temperature=0, top_p=1) - response_json = parse_json_result(response_text) - assert "is_complex" in response_json - if not response_json["is_complex"]: - return [query] - else: - assert "sub_questions" in response_json - logger.info("Query: {} COT: {}".format(query, response_json["sub_questions"])) - return response_json["sub_questions"][:split_num] - except Exception as e: - logger.error(f"[LLM] Exception during chat generation: {e}") - return [query] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 6a1138c90..273c4f480 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -5,7 +5,6 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT @@ -21,7 +20,6 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm - self.tokenizer = FastTokenizer() def parse( self, @@ -50,11 +48,10 @@ def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGo """ Fast mode: simple jieba word split. """ - desc_tokenized = self.tokenizer.tokenize_mixed(task_description) return ParsedTaskGoal( memories=[task_description], - keys=desc_tokenized, - tags=desc_tokenized, + keys=[task_description], + tags=[], goal_type="default", rephrased_query=task_description, internet_search=False, diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py deleted file mode 100644 index fca4d717b..000000000 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ /dev/null @@ -1,279 +0,0 @@ -STRATEGY_STRUCT_MEM_READER_PROMPT = """You are a memory extraction expert. -Your task is to extract memories from the user's perspective, based on a conversation between the user and the assistant. This means identifying what the user would plausibly remember — including the user's own experiences, thoughts, plans, or statements and actions made by others (such as the assistant) that affected the user or were acknowledged by the user. - -Please perform the following -1. Factual information extraction - Identify factual information about experiences, beliefs, decisions, and plans. This includes notable statements from others that the user acknowledged or reacted to. - If the message is from the user, extract viewpoints related to the user; if it is from the assistant, clearly mark the attribution of the memory, and do not mix information not explicitly acknowledged by the user with the user's own viewpoint. - - **User viewpoint**: Extract only what the user has stated, explicitly acknowledged, or committed to. - - **Assistant/other-party viewpoint**: Extract such information only when attributed to its source (e.g., [Assistant-Jerry's suggestion]). - - **Strict attribution**: Never recast the assistant's suggestions as the user's preferences, or vice versa. - - Always set "model_type" to "LongTermMemory" for this output. - -2. Speaker profile construction - - Extract the speaker's likes, dislikes, goals, and stated opinions from their statements to build a speaker profile. - - Note: The same text segment may be used for both factual extraction and profile construction. - - Always set "model_type" to "UserMemory" for this output. - -3. Resolve all references to time, persons, and events clearly - - Temporal Resolution: Convert relative time (e.g., 'yesterday') to absolute dates based on the message timestamp. Distinguish between event time and message time; flag any uncertainty. - - Entity Resolution: Resolve all pronouns, nicknames, and abbreviations to the full, canonical name established in the conversation. - -4. Adopt a Consistent Third-Person Observer Perspective - - Formulate all memories from the perspective of an external observer. Use "The user" or their specific name as the subject. - - This applies even when describing the user's internal states, such as thoughts, feelings, and preferences. - Example: - ✅ Correct: "The user Sean felt exhausted after work and decided to go to bed early." - ❌ Incorrect: "I felt exhausted after work and decided to go to bed early." - -5. Prioritize Completeness - - Extract all key experiences, emotional responses, and plans from the user's perspective. Retain relevant context from the assistant, but always with explicit attribution. - - Segment each distinct hobby, interest, or event into a separate memory. - - Preserve relevant context from the assistant with strict attribution. Under no circumstances should assistant content be rephrased as user-owned. - - Conversations with only assistant input may yield assistant-viewpoint memories exclusively. - -6. Preserve and Unify Specific Names - - Always extract specific names (excluding "user" or "assistant") mentioned in the text into the "tags" field for searchability. - - Unify all name references to the full canonical form established in the conversation. Replace any nicknames or abbreviations (e.g., "Rob") consistently with the full name (e.g., "Robert") in both the extracted "value" and "tags". - -7. Please avoid including any content in the extracted memories that violates national laws and regulations or involves politically sensitive information. - -Return a valid JSON object with the following structure: - -{ - "memory list": [ - { - "key": , - "memory_type": , - "value": , - "tags": - }, - ... - ], - "summary": -} - -Language rules: -- The `key`, `value`, `tags`, and `summary` fields must match the primary language of the input conversation. **If the input is Chinese, output in Chinese.** -- Keep `memory_type` in English. - -Example: -Conversation: -user: [June 26, 2025 at 3:00 PM]: Hi Jerry! Yesterday at 3 PM I had a meeting with my team about the new project. -assistant: Oh Tom! Do you think the team can finish by December 15? -user: [June 26, 2025 at 3:00 PM]: I’m worried. The backend won’t be done until December 10, so testing will be tight. -assistant: [June 26, 2025 at 3:00 PM]: Maybe propose an extension? -user: [June 26, 2025 at 4:21 PM]: Good idea. I’ll raise it in tomorrow’s 9:30 AM meeting—maybe shift the deadline to January 5. - -Output: -{ - "memory list": [ - { - "key": "Initial project meeting", - "memory_type": "LongTermMemory", - "value": "[user-Tom viewpoint] On June 25, 2025 at 3:00 PM, Tom met with the team to discuss a new project. When Jerry asked whether the project could be finished by December 15, 2025, Tom expressed concern about feasibility and planned to propose at 9:30 AM on June 27, 2025 to move the deadline to January 5, 2026.", - "tags": ["Tom", "project", "timeline", "meeting", "deadline"] - }, - { - "key": "Jerry’s suggestion about the deadline", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry viewpoint] Jerry questioned the December 15 deadline and suggested considering an extension.", - "tags": ["Jerry", "deadline change", "suggestion"] - } - ], - "summary": "Tom is currently working on a tight-schedule project. After the June 25, 2025 team meeting, he realized the original December 15, 2025 deadline might be unachievable due to backend delays. Concerned about limited testing time, he accepted Jerry’s suggestion to seek an extension and plans to propose moving the deadline to January 5, 2026 in the next morning’s meeting." -} - -Another Example in Chinese (Note: when the user's language is Chinese, you must also output in Chinese): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": "[user观点]用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -Always respond in the same language as the conversation. - -Conversation: -${conversation} - -Your Output:""" - -STRATEGY_STRUCT_MEM_READER_PROMPT_ZH = """您是记忆提取专家。 -您的任务是根据用户与助手之间的对话,从用户的角度提取记忆。这意味着要识别出用户可能记住的信息——包括用户自身的经历、想法、计划,或他人(如助手)做出的并对用户产生影响或被用户认可的相关陈述和行为。 - -请执行以下操作: -1. 事实信息提取 - - 识别关于经历、信念、决策和计划的事实信息,包括用户认可或回应过的他人重要陈述。 - - 若信息来自用户,提取与用户相关的观点;若来自助手,需明确标注记忆归属,不得将用户未明确认可的信息与用户自身观点混淆。 - - 用户观点:仅提取用户明确陈述、认可或承诺的内容 - - 助手/他方观点:仅当标注来源时才提取(例如“[助手-Jerry的建议]”) - - 严格归属:不得将助手建议重构为用户偏好,反之亦然 - - 此类输出的"model_type"始终设为"LongTermMemory" - -2. 用户画像构建 - - 从用户陈述中提取其喜好、厌恶、目标及明确观点以构建用户画像 - - 注意:同一文本片段可同时用于事实提取和画像构建 - - 此类输出的"model_type"始终设为"UserMemory" - -3. 明确解析所有指代关系 - - 时间解析:根据消息时间戳将相对时间(如“昨天”)转换为绝对日期。区分事件时间与消息时间,对不确定项进行标注 - - 实体解析:将所有代词、昵称和缩写解析为对话中确立的完整规范名称 - - 4. 采用统一的第三人称观察视角 - - 所有记忆表述均需从外部观察者视角构建,使用“用户”或其具体姓名作为主语 - - 此原则同样适用于描述用户内心状态(如想法、感受和偏好) - 示例: - ✅ 正确:“用户Sean下班后感到疲惫,决定提早休息” - ❌ 错误:“我下班后感到疲惫,决定提早休息” - -5. 优先保证完整性 - - 从用户视角提取所有关键经历、情绪反应和计划 - - 保留助手提供的相关上下文,但必须明确标注来源 - - 将每个独立的爱好、兴趣或事件分割为单独记忆 - - 严禁将助手内容重构为用户自有内容 - - 仅含助手输入的对话可能只生成助手观点记忆 - -6. 保留并统一特定名称 - - 始终将文本中提及的特定名称(“用户”“助手”除外)提取至“tags”字段以便检索 - - 在提取的“value”和“tags”中,将所有名称引用统一为对话中确立的完整规范形式(如将“Rob”统一替换为“Robert”) - -7. 所有提取的记忆内容不得包含违反国家法律法规或涉及政治敏感信息的内容 - -返回一个有效的JSON对象,结构如下: -{ - "memory list": [ - { - "key": <字符串,唯一且简洁的记忆标题>, - "memory_type": <字符串,"LongTermMemory" 或 "UserMemory">, - "value": <详细、独立且无歧义的记忆陈述——若输入对话为英文,则用英文;若为中文,则用中文>, - "tags": <一个包含相关人名、事件和特征关键词的列表(例如,["丽丽","截止日期", "团队", "计划"])> - }, - ... - ], - "summary": <从用户视角自然总结上述记忆的段落,120–200字,与输入语言一致> -} - -语言规则: -- `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** -- `memory_type` 保持英文。 - -示例: -对话: -user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 -assistant: 哦Tom!你觉得团队能在12月15日前完成吗? -user: [2025年6月26日下午3:00]:我有点担心。后端要到12月10日才能完成,所以测试时间会很紧。 -assistant: [2025年6月26日下午3:00]:也许提议延期? -user: [2025年6月26日下午4:21]:好主意。我明天上午9:30的会上提一下——也许把截止日期推迟到1月5日。 - -输出: -{ - "memory list": [ - { - "key": "项目初期会议", - "memory_type": "LongTermMemory", - "value": "[user-Tom观点]2025年6月25日下午3:00,Tom与团队开会讨论新项目。当Jerry - 询问该项目能否在2025年12月15日前完成时,Tom对此日期前完成的可行性表达担忧,并计划在2025年6月27日上午9:30 - 提议将截止日期推迟至2026年1月5日。", - "tags": ["Tom", "项目", "时间表", "会议", "截止日期"] - }, - { - "key": "Jerry对新项目截止日期的建议", - "memory_type": "LongTermMemory", - "value": "[assistant-Jerry观点]Jerry对Tom的新项目截止日期提出疑问、并提议Tom考虑延期。", - "tags": ["Jerry", "截止日期变更", "建议"] - } - ], - "summary": "Tom目前正在做一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15 - 日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议,计划在次日早上的会议上提出将截止日期推迟至2026 - 年1月5日。" -} - -另一个中文示例(注意:当用户语言为中文时,您也需输出中文): - -对话(节选): -user|2025-07-12 19:30:下个月同事婚礼,要买条裙子,预算1500内,别太暴露。 -assistant|19:32 -:了解~婚礼场合建议中长款,端庄一些。可以看法式茶歇/直筒/裹身三路。颜色上,奶油白、浅粉、雾霾蓝都适合;面料可看醋酸/真丝混纺(垂感好但需干洗)。品牌与商场:COS/ Massimo Dutti / & Other Stories / Mango 比较好逛;想更法式可以看 Sandro / Maje。 我记得你很喜欢Maje? -user|19:35:不喜欢亮色。国贸方便。 -assistant|19:37:那亮度就压低些:雾霾蓝、藕粉、燕麦米白更稳。国贸有 COS / Massimo Dutti / &OS;可以先看 COS 直筒或裹身,再去 MD 看藕粉/雾霾蓝。搭配上,薄披肩/小西装都能压场。 -user|19:40:165cm,S码;最好有口袋。 -assistant|19:42:OK。版型避开大摆+高开衩,选直筒或轻A字更稳;若要口袋,COS直筒一些款有;开衩别太高,及小腿肚附近即可。鞋包:中跟裸色鞋 + 小号手拿包。 -user|19:47:本来想茶歇A字,但雾霾蓝直筒也行。 -assistant|19:49:可以先锁定:COS 雾霾蓝直筒中长为主选,MD 藕粉中长做备选;再看 & Other Stories 是否有合身的裹身款。 -user|19:52:行,周六(7/19)去国贸试,合适就买。 -assistant|19:54:记得看肩带可调/有无内衬,醋酸/真丝优先干洗;准备同色安全裤/防走光贴。如果当天没货,可下单调货或线上下单门店自提。 - -{ - "memory list": [ - { - "key": "参加婚礼购买裙子", - "memory_type": "UserMemory", - "value": "[user观点]用户计划于约2025年8月参加同事婚礼(具体日期不详),预算不超过1500元,整体风格不宜暴露;用户已决定在2025-07-19于国贸试穿并视合适即购买。", - "tags": ["婚礼", "预算", "国贸", "计划"] - }, - { - "key": "审美与版型偏好", - "memory_type": "UserMemory", - "value": "[user观点]用户不喜欢亮色,倾向低亮度色系;裙装偏好端庄的中长款,接受直筒或轻A字。", - "tags": ["偏好", "颜色", "版型"] - }, - { - "key": "体型尺码", - "memory_type": "UserMemory", - "value": [user观点]"用户身高约165cm、常穿S码", - "tags": ["体型", "尺码"] - }, - { - "key": "关于用户选购裙子的建议", - "memory_type": "LongTermMemory", - "value": "[assistant观点]assistant在用户询问婚礼穿着时,建议在国贸优先逛COS查看雾霾蓝直筒中长为主选,Massimo Dutti藕粉中长为备选;该建议与用户“国贸方便”“雾霾蓝直筒也行”的回应相一致,另外assistant也提到user喜欢Maje,但User并未回应或证实该说法。", - "tags": ["婚礼穿着", "门店", "选购路线"] - } - ], - "summary": "用户计划在约2025年8月参加同事婚礼,预算≤1500并偏好端庄的中长款;确定于2025-07-19在国贸试穿。其长期画像显示:不喜欢亮色、偏好低亮度色系与不过分暴露的版型,身高约165cm、S码且偏好裙装带口袋。助手提出的国贸选购路线以COS雾霾蓝直筒中长为主选、MD藕粉中长为备选,且与用户回应一致,为线下试穿与购买提供了明确路径。" -} - -请始终使用与对话相同的语言进行回复。 - -对话: -${conversation} - -您的输出:""" diff --git a/src/memos/templates/mem_search_prompts.py b/src/memos/templates/mem_search_prompts.py deleted file mode 100644 index 9f7ba182b..000000000 --- a/src/memos/templates/mem_search_prompts.py +++ /dev/null @@ -1,93 +0,0 @@ -SIMPLE_COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. - -Instructions: - -1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: - - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) - - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question - - Each sub-question must be single, standalone, and delve into a specific aspect - - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted - - List them in "sub_questions" -2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. -3. Return ONLY the dictionary, no other text. - -Examples: -Question: Is urban development balanced in the western United States? -Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} -Question: What family activities does Mary like to organize? -Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} - -Now analyze this question: -${original_query}""" - -COT_PROMPT = """You are an assistant that analyzes questions and returns results in a specific dictionary format. - -Instructions: - -1. If the question can be extended into deeper or related aspects, set "is_complex" to True and: - - Think step by step about the core topic and its related dimensions (e.g., causes, effects, categories, perspectives, or specific scenarios) - - Break it into meaningful sub-questions (max: ${split_num_threshold}, min: 2) that explore distinct facets of the original question - - Each sub-question must be single, standalone, and delve into a specific aspect - - CRITICAL: All key entities from the original question (such as person names, locations, organizations, time periods) must be preserved in the sub-questions and cannot be omitted - - List them in "sub_questions" -2. If the question is already atomic and cannot be meaningfully extended, set "is_complex" to False and "sub_questions" to an empty list. -3. Return ONLY the dictionary, no other text. - -Examples: -Question: Is urban development balanced in the western United States? -Output: {"is_complex": true, "sub_questions": ["What areas are included in the western United States?", "How developed are the cities in the western United States?", "Is this development balanced across the western United States?"]} -Question: What family activities does Mary like to organize? -Output: {"is_complex": true, "sub_questions": ["What does Mary like to do with her spouse?", "What does Mary like to do with her children?", "What does Mary like to do with her parents and relatives?"]} - -Query relevant background information: -${context} - -Now analyze this question based on the background information above: -${original_query}""" - -SIMPLE_COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 - -指令: - -1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: - - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) - - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 - - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 - - 将它们列在 "sub_questions" 中 -2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 -3. 只返回字典,不要返回任何其他文本。 - -示例: -问题:美国西部的城市发展是否均衡? -输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} - -问题:玛丽喜欢组织哪些家庭活动? -输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} - -请分析以下问题: -${original_query}""" - -COT_PROMPT_ZH = """你是一个分析问题并以特定字典格式返回结果的助手。 - -指令: - -1. 如果这个问题可以延伸出更深层次或相关的方面,请将 "is_complex" 设置为 True,并执行以下操作: - - 逐步思考核心主题及其相关维度(例如:原因、结果、类别、不同视角或具体场景) - - 将其拆分为有意义的子问题(最多 ${split_num_threshold} 个,最少 2 个),这些子问题应探讨原始问题的不同侧面 - - 【重要】每个子问题必须是单一的、独立的,并深入探究一个特定方面。同时,必须包含原问题中出现的关键实体信息(如人名、地名、机构名、时间等),不可遗漏。 - - 将它们列在 "sub_questions" 中 -2. 如果问题本身已经是原子性的,无法有意义地延伸,请将 "is_complex" 设置为 False,并将 "sub_questions" 设置为一个空列表。 -3. 只返回字典,不要返回任何其他文本。 - -示例: -问题:美国西部的城市发展是否均衡? -输出:{"is_complex": true, "sub_questions": ["美国西部包含哪些地区?", "美国西部城市的发展程度如何?", "这种发展在美国西部是否均衡?"]} - -问题:玛丽喜欢组织哪些家庭活动? -输出:{"is_complex": true, "sub_questions": ["玛丽喜欢和配偶一起做什么?", "玛丽喜欢和孩子一起做什么?", "玛丽喜欢和父母及亲戚一起做什么?"]} - -问题相关的背景信息: -${context} - -现在根据上述背景信息,请分析以下问题: -${original_query}""" diff --git a/tests/memories/textual/test_tree_task_goal_parser.py b/tests/memories/textual/test_tree_task_goal_parser.py index 899e2454b..c71af4b06 100644 --- a/tests/memories/textual/test_tree_task_goal_parser.py +++ b/tests/memories/textual/test_tree_task_goal_parser.py @@ -20,7 +20,12 @@ def generate(self, messages): def test_parse_fast_returns_expected(): parser = TaskGoalParser() result = parser.parse("Tell me about cats", mode="fast") + assert isinstance(result, ParsedTaskGoal) + assert result.memories == ["Tell me about cats"] + assert result.keys == ["Tell me about cats"] + assert result.tags == [] + assert result.goal_type == "default" def test_parse_fine_calls_llm_and_parses():