Skip to content

Commit b3b0baa

Browse files
Wang-Daojiyuan.wangfridayL
authored
Feat/standardized preference field (#440)
* standardized preference field * fix pre_commit --------- Co-authored-by: yuan.wang <[email protected]> Co-authored-by: chunyu li <[email protected]>
1 parent b3ec17a commit b3b0baa

File tree

6 files changed

+41
-45
lines changed

6 files changed

+41
-45
lines changed

src/memos/memories/textual/item.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,8 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
196196
dialog_id: str | None = Field(default=None, description="ID of the dialog.")
197197
original_text: str | None = Field(default=None, description="String of the dialog.")
198198
embedding: list[float] | None = Field(default=None, description="Vector of the dialog.")
199-
explicit_preference: str | None = Field(default=None, description="Explicit preference.")
199+
preference: str | None = Field(default=None, description="Preference.")
200200
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
201-
implicit_preference: str | None = Field(default=None, description="Implicit preference.")
202201

203202

204203
class TextualMemoryItem(BaseModel):

src/memos/memories/textual/prefer_text_memory/adder.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def _update_memory_op_trace(
103103
new_memories: list[TextualMemoryItem],
104104
retrieved_memories: list[MilvusVecDBItem],
105105
collection_name: str,
106-
preference_type: str,
107106
) -> list[str] | str:
108107
# create new vec db items
109108
new_vec_db_items: list[MilvusVecDBItem] = []
@@ -124,17 +123,19 @@ def _update_memory_op_trace(
124123
{
125124
"id": new_memory.id,
126125
"context_summary": new_memory.memory,
127-
"preference": new_memory.payload[preference_type],
126+
"preference": new_memory.payload["preference"],
128127
}
129128
for new_memory in new_vec_db_items
129+
if new_memory.payload.get("preference", None)
130130
]
131131
retrieved_mem_inputs = [
132132
{
133133
"id": mem.id,
134134
"context_summary": mem.memory,
135-
"preference": mem.payload[preference_type],
135+
"preference": mem.payload["preference"],
136136
}
137137
for mem in retrieved_memories
138+
if mem.payload.get("preference", None)
138139
]
139140

140141
rsp = self._judge_update_or_add_trace_op(
@@ -168,7 +169,7 @@ def execute_op(
168169
elif op_type == "update":
169170
if op["target_id"] in retrieved_mem_db_item_map:
170171
update_mem_db_item = retrieved_mem_db_item_map[op["target_id"]]
171-
update_mem_db_item.payload[preference_type] = op["new_preference"]
172+
update_mem_db_item.payload["preference"] = op["new_preference"]
172173
update_mem_db_item.payload["updated_at"] = datetime.now().isoformat()
173174
update_mem_db_item.memory = op["new_context_summary"]
174175
update_mem_db_item.original_text = op["new_context_summary"]
@@ -198,7 +199,6 @@ def _update_memory_fine(
198199
new_memory: TextualMemoryItem,
199200
retrieved_memories: list[MilvusVecDBItem],
200201
collection_name: str,
201-
preference_type: str,
202202
) -> str:
203203
payload = new_memory.to_dict()["metadata"]
204204
fields_to_remove = {"dialog_id", "original_text", "embedding"}
@@ -211,19 +211,15 @@ def _update_memory_fine(
211211
payload=payload,
212212
)
213213

214-
new_mem_input = {
215-
"memory": new_memory.memory,
216-
"preference": new_memory.metadata.explicit_preference
217-
if preference_type == "explicit_preference"
218-
else new_memory.metadata.implicit_preference,
219-
}
214+
new_mem_input = {"memory": new_memory.memory, "preference": new_memory.metadata.preference}
220215
retrieved_mem_inputs = [
221216
{
222217
"id": mem.id,
223218
"memory": mem.memory,
224-
"preference": mem.payload[preference_type],
219+
"preference": mem.payload["preference"],
225220
}
226221
for mem in retrieved_memories
222+
if mem.payload.get("preference", None)
227223
]
228224
rsp = self._judge_update_or_add_fine(
229225
new_mem=json.dumps(new_mem_input),
@@ -240,7 +236,7 @@ def _update_memory_fine(
240236
)
241237
if need_update and update_item and rsp:
242238
update_vec_db_item = update_item[0]
243-
update_vec_db_item.payload[preference_type] = rsp["new_preference"]
239+
update_vec_db_item.payload["preference"] = rsp["new_preference"]
244240
update_vec_db_item.payload["updated_at"] = vec_db_item.payload["updated_at"]
245241
update_vec_db_item.memory = rsp["new_memory"]
246242
update_vec_db_item.original_text = vec_db_item.original_text
@@ -287,23 +283,19 @@ def _update_memory(
287283
new_memory: TextualMemoryItem,
288284
retrieved_memories: list[MilvusVecDBItem],
289285
collection_name: str,
290-
preference_type: str,
291286
update_mode: str = "fast",
292287
) -> list[str] | str | None:
293288
"""Update the memory.
294289
Args:
295290
new_memory: TextualMemoryItem
296291
retrieved_memories: list[MilvusVecDBItem]
297292
collection_name: str
298-
preference_type: str
299293
update_mode: str, "fast" or "fine"
300294
"""
301295
if update_mode == "fast":
302296
return self._update_memory_fast(new_memory, retrieved_memories, collection_name)
303297
elif update_mode == "fine":
304-
return self._update_memory_fine(
305-
new_memory, retrieved_memories, collection_name, preference_type
306-
)
298+
return self._update_memory_fine(new_memory, retrieved_memories, collection_name)
307299
else:
308300
raise ValueError(f"Invalid update mode: {update_mode}")
309301

@@ -330,7 +322,6 @@ def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str |
330322
memory,
331323
search_results,
332324
collection_name,
333-
preference_type,
334325
update_mode=os.getenv("PREFERENCE_ADDER_MODE", "fast"),
335326
)
336327

@@ -369,18 +360,24 @@ def process_memory_batch(self, memories: list[TextualMemoryItem], *args, **kwarg
369360
explicit_recalls = list({recall.id: recall for recall in explicit_recalls}.values())
370361
implicit_recalls = list({recall.id: recall for recall in implicit_recalls}.values())
371362

372-
explicit_added_ids = self._update_memory_op_trace(
373-
explicit_new_mems,
374-
explicit_recalls,
375-
pref_type_collection_map["explicit_preference"],
376-
"explicit_preference",
377-
)
378-
implicit_added_ids = self._update_memory_op_trace(
379-
implicit_new_mems,
380-
implicit_recalls,
381-
pref_type_collection_map["implicit_preference"],
382-
"implicit_preference",
383-
)
363+
# 使用线程池并行处理显式和隐式偏好
364+
with ContextThreadPoolExecutor(max_workers=2) as executor:
365+
explicit_future = executor.submit(
366+
self._update_memory_op_trace,
367+
explicit_new_mems,
368+
explicit_recalls,
369+
pref_type_collection_map["explicit_preference"],
370+
)
371+
implicit_future = executor.submit(
372+
self._update_memory_op_trace,
373+
implicit_new_mems,
374+
implicit_recalls,
375+
pref_type_collection_map["implicit_preference"],
376+
)
377+
378+
explicit_added_ids = explicit_future.result()
379+
implicit_added_ids = implicit_future.result()
380+
384381
return explicit_added_ids + implicit_added_ids
385382

386383
def process_memory_single(

src/memos/memories/textual/prefer_text_memory/extractor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
6767
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
6868
response = response.strip().replace("```json", "").replace("```", "").strip()
6969
result = json.loads(response)
70+
for d in result:
71+
d["preference"] = d.pop("explicit_preference")
7072
return result
7173
except Exception as e:
7274
logger.error(f"Error extracting explicit preference: {e}, return None")
@@ -88,6 +90,7 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
8890
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
8991
response = response.strip().replace("```json", "").replace("```", "").strip()
9092
result = json.loads(response)
93+
result["preference"] = result.pop("implicit_preference")
9194
return result
9295
except Exception as e:
9396
logger.error(f"Error extracting implicit preferences: {e}, return None")

src/memos/memories/textual/prefer_text_memory/retrievers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def retrieve(
106106
metadata=PreferenceTextualMemoryMetadata(**pref.payload),
107107
)
108108
for pref in explicit_prefs
109-
if pref.payload["explicit_preference"]
109+
if pref.payload.get("preference", None)
110110
]
111111

112112
implicit_prefs_mem = [
@@ -116,7 +116,7 @@ def retrieve(
116116
metadata=PreferenceTextualMemoryMetadata(**pref.payload),
117117
)
118118
for pref in implicit_prefs
119-
if pref.payload["implicit_preference"]
119+
if pref.payload.get("preference", None)
120120
]
121121

122122
reranker_map = {

src/memos/memories/textual/prefer_text_memory/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ def deduplicate_preferences(
4646

4747
for i, pref in enumerate(prefs):
4848
# Extract preference text
49-
if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference:
50-
text = pref.metadata.implicit_preference
51-
elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference:
52-
text = pref.metadata.explicit_preference
49+
if hasattr(pref.metadata, "preference") and pref.metadata.preference:
50+
text = pref.metadata.preference
5351
else:
5452
text = pref.memory
5553

src/memos/templates/instruction_completion.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ def instruct_completion(
1212
implicit_pref = []
1313
for memory in memories:
1414
pref_type = memory.get("metadata", {}).get("preference_type")
15+
pref = memory.get("metadata", {}).get("preference", None)
16+
if not pref:
17+
continue
1518
if pref_type == "explicit_preference":
16-
pref = memory.get("metadata", {}).get("explicit_preference", None)
17-
if pref:
18-
explicit_pref.append(pref)
19+
explicit_pref.append(pref)
1920
elif pref_type == "implicit_preference":
20-
pref = memory.get("metadata", {}).get("implicit_preference", None)
21-
if pref:
22-
implicit_pref.append(pref)
21+
implicit_pref.append(pref)
2322

2423
explicit_pref_str = (
2524
"Explicit Preference:\n"

0 commit comments

Comments
 (0)