Skip to content

Commit f6e96d5

Browse files
fridayLCaralHsi
andauthored
Feat: add reranker strategies and update configs (#390)
* feat:change reranking source filed * fix: code ci * feat: add reranker strategy * fix: code suffix * fix: code suffix * fix:change strategy name * fix: code format * feat: update memory strategies * fix: code ci --------- Co-authored-by: CaralHsi <[email protected]>
1 parent 0b2b6ed commit f6e96d5

File tree

15 files changed

+893
-21
lines changed

15 files changed

+893
-21
lines changed

src/memos/api/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,16 @@ def get_reranker_config() -> dict[str, Any]:
132132
"""Get embedder configuration."""
133133
embedder_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge")
134134

135-
if embedder_backend == "http_bge":
135+
if embedder_backend in ["http_bge", "http_bge_strategy"]:
136136
return {
137-
"backend": "http_bge",
137+
"backend": embedder_backend,
138138
"config": {
139139
"url": os.getenv("MOS_RERANKER_URL"),
140140
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
141141
"timeout": 10,
142142
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
143143
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
144+
"reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"),
144145
},
145146
}
146147
else:

src/memos/graph_dbs/neo4j.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def remove_oldest_memory(
157157
"""
158158
if not self.config.use_multi_db and (self.config.user_name or user_name):
159159
query += f"\nAND n.user_name = '{user_name}'"
160-
160+
keep_latest = int(keep_latest)
161161
query += f"""
162162
WITH n ORDER BY n.updated_at DESC
163163
SKIP {keep_latest}

src/memos/reranker/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ class BaseReranker(ABC):
1616
def rerank(
1717
self,
1818
query: str,
19-
graph_results: list,
19+
graph_results: list[TextualMemoryItem],
2020
top_k: int,
21+
search_filter: dict | None = None,
2122
**kwargs,
2223
) -> list[tuple[TextualMemoryItem, float]]:
2324
"""Return top_k (item, score) sorted by score desc."""

src/memos/reranker/concat.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,49 @@
22

33
from typing import Any
44

5+
from memos.memories.textual.item import SourceMessage
6+
57

68
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
79

810

11+
def get_encoded_tokens(content: str) -> int:
12+
"""
13+
Get encoded tokens.
14+
Args:
15+
content: str
16+
Returns:
17+
int: Encoded tokens.
18+
"""
19+
return len(content)
20+
21+
22+
def truncate_data(data: list[str | dict[str, Any] | Any], max_tokens: int) -> list[str]:
23+
"""
24+
Truncate data to max tokens.
25+
Args:
26+
data: List of strings or dictionaries.
27+
max_tokens: Maximum number of tokens.
28+
Returns:
29+
str: Truncated string.
30+
"""
31+
truncated_string = ""
32+
for item in data:
33+
if isinstance(item, SourceMessage):
34+
content = getattr(item, "content", "")
35+
chat_time = getattr(item, "chat_time", "")
36+
if not content:
37+
continue
38+
truncated_string += f"[{chat_time}]: {content}\n"
39+
if get_encoded_tokens(truncated_string) > max_tokens:
40+
break
41+
return truncated_string
42+
43+
944
def process_source(
10-
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
45+
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None,
46+
recent_num: int = 10,
47+
max_tokens: int = 2048,
1148
) -> str:
1249
"""
1350
Args:
@@ -23,19 +60,16 @@ def process_source(
2360
memory = None
2461
for item in items:
2562
memory, source = item
26-
for content in source:
27-
if isinstance(content, str):
28-
if "assistant:" in content:
29-
continue
30-
concat_data.append(content)
63+
concat_data.extend(source[-recent_num:])
64+
truncated_string = truncate_data(concat_data, max_tokens)
3165
if memory is not None:
32-
concat_data = [memory, *concat_data]
33-
return "\n".join(concat_data)
66+
truncated_string = f"{memory}\n{truncated_string}"
67+
return truncated_string
3468

3569

3670
def concat_original_source(
3771
graph_results: list,
38-
merge_field: list[str] | None = None,
72+
rerank_source: str | None = None,
3973
) -> list[str]:
4074
"""
4175
Merge memory items with original dialogue.
@@ -45,14 +79,16 @@ def concat_original_source(
4579
Returns:
4680
list[str]: List of memory and concat orginal memory.
4781
"""
48-
if merge_field is None:
49-
merge_field = ["sources"]
82+
merge_field = []
83+
merge_field = ["sources"] if rerank_source is None else rerank_source.split(",")
5084
documents = []
5185
for item in graph_results:
5286
memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
5387
sources = []
5488
for field in merge_field:
55-
source = getattr(item.metadata, field, "")
89+
source = getattr(item.metadata, field, None)
90+
if source is None:
91+
continue
5692
sources.append((memory, source))
5793
concat_string = process_source(sources)
5894
documents.append(concat_string)

src/memos/reranker/cosine_local.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from typing import TYPE_CHECKING
55

6+
from memos.log import get_logger
7+
68
from .base import BaseReranker
79

810

@@ -16,6 +18,8 @@
1618
except Exception:
1719
_HAS_NUMPY = False
1820

21+
logger = get_logger(__name__)
22+
1923

2024
def _cosine_one_to_many(q: list[float], m: list[list[float]]) -> list[float]:
2125
"""
@@ -92,5 +96,5 @@ def get_weight(it: TextualMemoryItem) -> float:
9296
chosen = {it.id for it, _ in top_items}
9397
remain = [(it, -1.0) for it in graph_results if it.id not in chosen]
9498
top_items.extend(remain[: top_k - len(top_items)])
95-
99+
logger.info(f"CosineLocalReranker rerank result: {top_items[:1]}")
96100
return top_items

src/memos/reranker/factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from .cosine_local import CosineLocalReranker
1010
from .http_bge import HTTPBGEReranker
11+
from .http_bge_strategy import HTTPBGERerankerStrategy
1112
from .noop import NoopReranker
1213

1314

@@ -45,4 +46,14 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
4546
if backend in {"noop", "none", "disabled"}:
4647
return NoopReranker()
4748

49+
if backend in {"http_bge_strategy", "bge_strategy"}:
50+
return HTTPBGERerankerStrategy(
51+
reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"),
52+
model=c.get("model", "bge-reranker-v2-m3"),
53+
timeout=int(c.get("timeout", 10)),
54+
headers_extra=c.get("headers_extra"),
55+
rerank_source=c.get("rerank_source"),
56+
reranker_strategy=c.get("reranker_strategy"),
57+
)
58+
4859
raise ValueError(f"Unknown reranker backend: {cfg.backend}")

src/memos/reranker/http_bge.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
model: str = "bge-reranker-v2-m3",
8181
timeout: int = 10,
8282
headers_extra: dict | None = None,
83-
rerank_source: list[str] | None = None,
83+
rerank_source: str | None = None,
8484
boost_weights: dict[str, float] | None = None,
8585
boost_default: float = 0.0,
8686
warn_unknown_filter_keys: bool = True,
@@ -107,7 +107,7 @@ def __init__(
107107
self.model = model
108108
self.timeout = timeout
109109
self.headers_extra = headers_extra or {}
110-
self.concat_source = rerank_source
110+
self.rerank_source = rerank_source
111111

112112
self.boost_weights = (
113113
DEFAULT_BOOST_WEIGHTS.copy()
@@ -152,8 +152,8 @@ def rerank(
152152
# Build a mapping from "payload docs index" -> "original graph_results index"
153153
# Only include items that have a non-empty string memory. This ensures that
154154
# any index returned by the server can be mapped back correctly.
155-
if self.concat_source:
156-
documents = concat_original_source(graph_results, self.concat_source)
155+
if self.rerank_source:
156+
documents = concat_original_source(graph_results, self.rerank_source)
157157
else:
158158
documents = [
159159
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)

0 commit comments

Comments
 (0)