Skip to content

Commit aa80863

Browse files
Wang-Daojiyuan.wang
andauthored
Feat/pref optimize update (#425)
* add hybrid search and fine extractor * add dialog and modify spliter chunk * optmize the update and retriever code * modify pref field * add pref mem update srategy * add pref mem update srategy * fix bug in pre_commit * modify pref filed * fix bug * fix pre_commit * fix bug in adder --------- Co-authored-by: yuan.wang <[email protected]>
1 parent 0e7128e commit aa80863

File tree

13 files changed

+130
-50
lines changed

13 files changed

+130
-50
lines changed

evaluation/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,4 @@ get `questions_32k.csv` and `shared_contexts_32k.jsonl` from https://huggingface
8484
# Specify the model and memory backend you want to use (e.g., mem0, zep, etc.)
8585
# If you want to use MIRIX, edit the the configuration in ./scripts/personamem/config.yaml
8686
./scripts/run_pm_eval.sh
87-
```
87+
```

evaluation/scripts/PrefEval/pref_mem0.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def add_memory_for_line(
5656

5757
for idx, _ in enumerate(conversation[::2]):
5858
msg_idx = idx * 2
59-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
59+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
6060
timestamp_add = int(time.time() * 100)
6161

6262
if record_id not in success_records:

evaluation/scripts/PrefEval/pref_memobase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from openai import OpenAI
1313
from tqdm import tqdm
1414

15+
1516
ROOT_DIR = os.path.dirname(
1617
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1718
)
@@ -68,7 +69,7 @@ def add_memory_for_line(
6869
)
6970
for idx, _ in enumerate(conversation[::2]):
7071
msg_idx = idx * 2
71-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
72+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
7273

7374
if record_id not in success_records:
7475
mem_client.add(messages=conversation[msg_idx : msg_idx + 2], user_id=user_id)

evaluation/scripts/PrefEval/pref_memos.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from openai import OpenAI
1313
from tqdm import tqdm
1414

15+
1516
ROOT_DIR = os.path.dirname(
1617
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1718
)
@@ -49,7 +50,7 @@ def add_memory_for_line(
4950

5051
for idx, _ in enumerate(conversation[::2]):
5152
msg_idx = idx * 2
52-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
53+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
5354

5455
if record_id not in success_records:
5556
mem_client.add(

evaluation/scripts/PrefEval/pref_memu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from openai import OpenAI
1515
from tqdm import tqdm
1616

17+
1718
ROOT_DIR = os.path.dirname(
1819
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1920
)
@@ -56,7 +57,7 @@ def add_memory_for_line(
5657

5758
for idx, _ in enumerate(conversation[::2]):
5859
msg_idx = idx * 2
59-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
60+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
6061

6162
if record_id not in success_records:
6263
mem_client.add(

evaluation/scripts/PrefEval/pref_supermemory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from openai import OpenAI
1313
from tqdm import tqdm
1414

15+
1516
ROOT_DIR = os.path.dirname(
1617
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1718
)
@@ -54,7 +55,7 @@ def add_memory_for_line(
5455

5556
for idx, _ in enumerate(conversation[::2]):
5657
msg_idx = idx * 2
57-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
58+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
5859

5960
if record_id not in success_records:
6061
mem_client.add(

evaluation/scripts/PrefEval/pref_zep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from openai import OpenAI
1515
from tqdm import tqdm
1616

17+
1718
ROOT_DIR = os.path.dirname(
1819
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1920
)
@@ -56,7 +57,7 @@ def add_memory_for_line(
5657

5758
for idx, _ in enumerate(conversation[::2]):
5859
msg_idx = idx * 2
59-
record_id = f"{lib}_user_pref_eval_{i}_{version}_{str(msg_idx)}"
60+
record_id = f"{lib}_user_pref_eval_{i}_{version}_{msg_idx!s}"
6061

6162
if record_id not in success_records:
6263
mem_client.add(

evaluation/scripts/personamem/pm_ingestion.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from tqdm import tqdm
1212

13+
1314
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1415

1516

@@ -171,7 +172,9 @@ def ingest_conv(row_data, context, version, conv_idx, frame, success_records, f)
171172
client = MemosApiOnlineClient()
172173

173174
try:
174-
ingest_session(session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client)
175+
ingest_session(
176+
session=context, user_id=user_id, session_id=conv_idx, frame=frame, client=client
177+
)
175178
print(f"✅ Ingestion of conversation {conv_idx} completed")
176179
print("=" * 80)
177180

@@ -187,10 +190,9 @@ def main(frame, version, num_workers=2, clear=False):
187190
os.makedirs(f"results/pm/{frame}-{version}/", exist_ok=True)
188191
record_file = f"results/pm/{frame}-{version}/success_records.txt"
189192

190-
if clear:
191-
if os.path.exists(record_file):
192-
os.remove(record_file)
193-
print("🧹 Cleared progress records")
193+
if clear and os.path.exists(record_file):
194+
os.remove(record_file)
195+
print("🧹 Cleared progress records")
194196

195197
print("\n" + "=" * 80)
196198
print(f"🚀 PERSONAMEM INGESTION - {frame.upper()} v{version}".center(80))
@@ -205,15 +207,20 @@ def main(frame, version, num_workers=2, clear=False):
205207

206208
success_records = set()
207209
if os.path.exists(record_file):
208-
with open(record_file, "r") as f:
209-
success_records = set(line.strip() for line in f)
210-
print(f"📊 Found {len(success_records)} completed conversations, {total_rows - len(success_records)} remaining")
210+
with open(record_file) as f:
211+
success_records = {line.strip() for line in f}
212+
print(
213+
f"📊 Found {len(success_records)} completed conversations, {total_rows - len(success_records)} remaining"
214+
)
211215

212216
start_time = datetime.now()
213217
all_data = list(load_rows_with_context(question_csv_path, context_jsonl_path))
214218

215-
pending_data = [(idx, row_data, context) for idx, (row_data, context) in enumerate(all_data)
216-
if str(idx) not in success_records]
219+
pending_data = [
220+
(idx, row_data, context)
221+
for idx, (row_data, context) in enumerate(all_data)
222+
if str(idx) not in success_records
223+
]
217224

218225
if not pending_data:
219226
print("✅ All conversations have been processed!")
@@ -232,16 +239,16 @@ def main(frame, version, num_workers=2, clear=False):
232239
conv_idx=idx,
233240
frame=frame,
234241
success_records=success_records,
235-
f=f
242+
f=f,
236243
)
237244
futures.append(future)
238245

239246
completed_count = 0
240247
for future in tqdm(
241-
as_completed(futures), total=len(futures), desc="Processing conversations"
248+
as_completed(futures), total=len(futures), desc="Processing conversations"
242249
):
243250
try:
244-
result = future.result()
251+
future.result()
245252
completed_count += 1
246253
except Exception as exc:
247254
print(f"\n❌ Conversation generated an exception: {exc}")
@@ -261,13 +268,28 @@ def main(frame, version, num_workers=2, clear=False):
261268

262269
if __name__ == "__main__":
263270
parser = argparse.ArgumentParser(description="PersonaMem Ingestion Script")
264-
parser.add_argument("--lib", type=str,
265-
choices=["memos-api-online", "mem0", "mem0_graph", "memos-api", "memobase", "memu",
266-
"supermemory", "zep"],
267-
default='memos-api')
268-
parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.")
269-
parser.add_argument("--workers", type=int, default=3, help="Number of parallel workers for processing users.")
271+
parser.add_argument(
272+
"--lib",
273+
type=str,
274+
choices=[
275+
"memos-api-online",
276+
"mem0",
277+
"mem0_graph",
278+
"memos-api",
279+
"memobase",
280+
"memu",
281+
"supermemory",
282+
"zep",
283+
],
284+
default="memos-api",
285+
)
286+
parser.add_argument(
287+
"--version", type=str, default="default", help="Version of the evaluation framework."
288+
)
289+
parser.add_argument(
290+
"--workers", type=int, default=3, help="Number of parallel workers for processing users."
291+
)
270292
parser.add_argument("--clear", action="store_true", help="Clear progress and start fresh")
271293
args = parser.parse_args()
272294

273-
main(frame=args.lib, version=args.version, num_workers=args.workers, clear=args.clear)
295+
main(frame=args.lib, version=args.version, num_workers=args.workers, clear=args.clear)

evaluation/scripts/personamem/pm_metric.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,23 @@ def print_summary(results):
353353
parser.add_argument(
354354
"--lib",
355355
type=str,
356-
choices=["zep", "mem0", "mem0_graph", "memos-api", "memos-api-online", "memobase", "memu", "supermemory"],
356+
choices=[
357+
"zep",
358+
"mem0",
359+
"mem0_graph",
360+
"memos-api",
361+
"memos-api-online",
362+
"memobase",
363+
"memu",
364+
"supermemory",
365+
],
357366
required=True,
358367
help="Memory library to evaluate",
359368
default="memos-api",
360369
)
361-
parser.add_argument("--version", type=str, default="default", help="Evaluation framework version")
370+
parser.add_argument(
371+
"--version", type=str, default="default", help="Evaluation framework version"
372+
)
362373

363374
args = parser.parse_args()
364375
lib, version = args.lib, args.version

evaluation/scripts/personamem/pm_responses.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from openai import OpenAI
1111
from tqdm import tqdm
1212

13+
1314
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1415
import re
1516

@@ -153,9 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4):
153154
future_to_user_id[future] = user_id
154155

155156
for future in tqdm(
156-
as_completed(future_to_user_id),
157-
total=len(future_to_user_id),
158-
desc="📝 Generating responses",
157+
as_completed(future_to_user_id),
158+
total=len(future_to_user_id),
159+
desc="📝 Generating responses",
159160
):
160161
user_id = future_to_user_id[future]
161162
try:
@@ -184,12 +185,30 @@ def main(frame, version, num_runs=3, num_workers=4):
184185

185186
if __name__ == "__main__":
186187
parser = argparse.ArgumentParser(description="PersonaMem Response Generation Script")
187-
parser.add_argument("--lib", type=str,
188-
choices=["memos-api-online", "zep", "mem0", "mem0_graph", "memos-api", "memobase", "memu",
189-
"supermemory"], default='memos-api')
190-
parser.add_argument("--version", type=str, default="default", help="Version of the evaluation framework.")
191-
parser.add_argument("--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation.")
192-
parser.add_argument("--workers", type=int, default=10, help="Number of worker threads to use for processing.")
188+
parser.add_argument(
189+
"--lib",
190+
type=str,
191+
choices=[
192+
"memos-api-online",
193+
"zep",
194+
"mem0",
195+
"mem0_graph",
196+
"memos-api",
197+
"memobase",
198+
"memu",
199+
"supermemory",
200+
],
201+
default="memos-api",
202+
)
203+
parser.add_argument(
204+
"--version", type=str, default="default", help="Version of the evaluation framework."
205+
)
206+
parser.add_argument(
207+
"--num_runs", type=int, default=3, help="Number of runs for LLM-as-a-Judge evaluation."
208+
)
209+
parser.add_argument(
210+
"--workers", type=int, default=10, help="Number of worker threads to use for processing."
211+
)
193212

194213
args = parser.parse_args()
195214
main(frame=args.lib, version=args.version, num_runs=args.num_runs, num_workers=args.workers)

0 commit comments

Comments
 (0)