Skip to content

Commit a375911

Browse files
whipser030, 黑布林, CaralHsi
authored
Feat: add recall strategy (#414)
* update reader and search strategy
* set strategy reader and search config
* fix install problem
* fix
* fix test

---------

Co-authored-by: 黑布林 <[email protected]>
Co-authored-by: CaralHsi <[email protected]>
1 parent 5923001 commit a375911

File tree

18 files changed

+1393
-41
lines changed

18 files changed

+1393
-41
lines changed

poetry.lock

Lines changed: 47 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,9 @@ all = [
107107
"markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)",
108108
"pymilvus (>=2.6.1,<3.0.0)",
109109
"datasketch (>=1.6.5,<2.0.0)",
110-
110+
"jieba (>=0.38.1,<0.42.1)",
111+
"rank-bm25 (>=0.2.2)",
112+
"cachetools (>=6.0.0)",
111113
# NOT exist in the above optional groups
112114
# Because they are either huge-size dependencies or infrequently used dependencies.
113115
# We kindof don't want users to install them.

src/memos/api/config.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -419,17 +419,31 @@ def get_embedder_config() -> dict[str, Any]:
419419
},
420420
}
421421

422+
@staticmethod
423+
def get_reader_config() -> dict[str, Any]:
424+
"""Get reader configuration."""
425+
return {
426+
"backend": os.getenv("MEM_READER_BACKEND", "simple_struct"),
427+
"config": {
428+
"chunk_type": os.getenv("MEM_READER_CHAT_CHUNK_TYPE", "default"),
429+
"chunk_length": int(os.getenv("MEM_READER_CHAT_CHUNK_TOKEN_SIZE", 1600)),
430+
"chunk_session": int(os.getenv("MEM_READER_CHAT_CHUNK_SESS_SIZE", 20)),
431+
"chunk_overlap": int(os.getenv("MEM_READER_CHAT_CHUNK_OVERLAP", 2)),
432+
},
433+
}
434+
422435
@staticmethod
423436
def get_internet_config() -> dict[str, Any]:
424437
"""Get embedder configuration."""
438+
reader_config = APIConfig.get_reader_config()
425439
return {
426440
"backend": "bocha",
427441
"config": {
428442
"api_key": os.getenv("BOCHA_API_KEY"),
429443
"max_results": 15,
430444
"num_per_request": 10,
431445
"reader": {
432-
"backend": "simple_struct",
446+
"backend": reader_config["backend"],
433447
"config": {
434448
"llm": {
435449
"backend": "openai",
@@ -455,6 +469,7 @@ def get_internet_config() -> dict[str, Any]:
455469
"min_sentences_per_chunk": 1,
456470
},
457471
},
472+
"chat_chunker": reader_config,
458473
},
459474
},
460475
},
@@ -656,6 +671,8 @@ def get_product_default_config() -> dict[str, Any]:
656671
openai_config = APIConfig.get_openai_config()
657672
qwen_config = APIConfig.qwen_config()
658673
vllm_config = APIConfig.vllm_config()
674+
reader_config = APIConfig.get_reader_config()
675+
659676
backend_model = {
660677
"openai": openai_config,
661678
"huggingface": qwen_config,
@@ -667,7 +684,7 @@ def get_product_default_config() -> dict[str, Any]:
667684
"user_id": os.getenv("MOS_USER_ID", "root"),
668685
"chat_model": {"backend": backend, "config": backend_model[backend]},
669686
"mem_reader": {
670-
"backend": "simple_struct",
687+
"backend": reader_config["backend"],
671688
"config": {
672689
"llm": APIConfig.get_memreader_config(),
673690
"embedder": APIConfig.get_embedder_config(),
@@ -680,6 +697,7 @@ def get_product_default_config() -> dict[str, Any]:
680697
"min_sentences_per_chunk": 1,
681698
},
682699
},
700+
"chat_chunker": reader_config,
683701
},
684702
},
685703
"enable_textual_memory": True,
@@ -750,6 +768,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
750768
qwen_config = APIConfig.qwen_config()
751769
vllm_config = APIConfig.vllm_config()
752770
mysql_config = APIConfig.get_mysql_config()
771+
reader_config = APIConfig.get_reader_config()
753772
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
754773
backend_model = {
755774
"openai": openai_config,
@@ -764,7 +783,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
764783
"config": backend_model[backend],
765784
},
766785
"mem_reader": {
767-
"backend": "simple_struct",
786+
"backend": reader_config["backend"],
768787
"config": {
769788
"llm": APIConfig.get_memreader_config(),
770789
"embedder": APIConfig.get_embedder_config(),
@@ -777,6 +796,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
777796
"min_sentences_per_chunk": 1,
778797
},
779798
},
799+
"chat_chunker": reader_config,
780800
},
781801
},
782802
"enable_textual_memory": True,
@@ -845,6 +865,10 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
845865
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
846866
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
847867
},
868+
"search_strategy": {
869+
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
870+
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
871+
},
848872
},
849873
},
850874
"act_mem": {}
@@ -912,6 +936,10 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
912936
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
913937
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
914938
},
939+
"search_strategy": {
940+
"bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
941+
"cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
942+
},
915943
"mode": os.getenv("ASYNC_MODE", "sync"),
916944
},
917945
},

src/memos/configs/mem_reader.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,11 +36,19 @@ def parse_datetime(cls, value):
3636
description="whether remove example in memory extraction prompt to save token",
3737
)
3838

39+
chat_chunker: dict[str, Any] = Field(
40+
default=None, description="Configuration for the MemReader chat chunk strategy"
41+
)
42+
3943

4044
class SimpleStructMemReaderConfig(BaseMemReaderConfig):
4145
"""SimpleStruct MemReader configuration class."""
4246

4347

48+
class StrategyStructMemReaderConfig(BaseMemReaderConfig):
49+
"""StrategyStruct MemReader configuration class."""
50+
51+
4452
class MemReaderConfigFactory(BaseConfig):
4553
"""Factory class for creating MemReader configurations."""
4654

@@ -49,6 +57,7 @@ class MemReaderConfigFactory(BaseConfig):
4957

5058
backend_to_class: ClassVar[dict[str, Any]] = {
5159
"simple_struct": SimpleStructMemReaderConfig,
60+
"strategy_struct": StrategyStructMemReaderConfig,
5261
}
5362

5463
@field_validator("backend")

src/memos/configs/memory.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,13 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
184184
),
185185
)
186186

187+
search_strategy: dict[str, bool] | None = Field(
188+
default=None,
189+
description=(
190+
'Set search strategy for this memory configuration.{"bm25": true, "cot": false}'
191+
),
192+
)
193+
187194
mode: str | None = Field(
188195
default="sync",
189196
description=("whether use asynchronous mode in memory add"),

src/memos/llms/openai.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,15 +58,18 @@ def clear_cache(cls):
5858
logger.info("OpenAI LLM instance cache cleared")
5959

6060
@timed(log=True, log_prefix="OpenAI LLM")
61-
def generate(self, messages: MessageList) -> str:
62-
"""Generate a response from OpenAI LLM."""
61+
def generate(self, messages: MessageList, **kwargs) -> str:
62+
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
63+
temperature = kwargs.get("temperature", self.config.temperature)
64+
max_tokens = kwargs.get("max_tokens", self.config.max_tokens)
65+
top_p = kwargs.get("top_p", self.config.top_p)
6366
response = self.client.chat.completions.create(
6467
model=self.config.model_name_or_path,
6568
messages=messages,
6669
extra_body=self.config.extra_body,
67-
temperature=self.config.temperature,
68-
max_tokens=self.config.max_tokens,
69-
top_p=self.config.top_p,
70+
temperature=temperature,
71+
max_tokens=max_tokens,
72+
top_p=top_p,
7073
)
7174
logger.info(f"Response from OpenAI: {response.model_dump_json()}")
7275
response_content = response.choices[0].message.content

src/memos/mem_reader/factory.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from memos.configs.mem_reader import MemReaderConfigFactory
44
from memos.mem_reader.base import BaseMemReader
55
from memos.mem_reader.simple_struct import SimpleStructMemReader
6+
from memos.mem_reader.strategy_struct import StrategyStructMemReader
67
from memos.memos_tools.singleton import singleton_factory
78

89

@@ -11,6 +12,7 @@ class MemReaderFactory(BaseMemReader):
1112

1213
backend_to_class: ClassVar[dict[str, Any]] = {
1314
"simple_struct": SimpleStructMemReader,
15+
"strategy_struct": StrategyStructMemReader,
1416
}
1517

1618
@classmethod

0 commit comments

Comments (0)