Skip to content

Commit fef40e9

Browse files
authored
feat: async add api (#410)
* feat: update manager for async add * feat: modify tree and simple_tree, TODO: STILL NOT ALIGNED IN SOME FUNCTIONS * feat: modify schedule: add optional user_name in schedule message; modify user-name related graph query in scheduler * feat: finish server router for async mode * feat: update graph db * fix: add label in core * feat: add tree mode in config * feat: default llm token 8000 * fix: thread * feat: search mode in client: fast * tests: fix
1 parent 4ed7574 commit fef40e9

File tree

13 files changed

+419
-113
lines changed

13 files changed

+419
-113
lines changed

evaluation/scripts/utils/client.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,7 @@ def search(self, query, user_id, top_k):
181181
"mem_cube_id": user_id,
182182
"conversation_id": "",
183183
"top_k": top_k,
184-
"mode": "mixture",
184+
"mode": "fast",
185185
"handle_pref_mem": False,
186186
},
187187
ensure_ascii=False,

src/memos/api/config.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]:
2323
return {
2424
"model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"),
2525
"temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
26-
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")),
26+
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")),
2727
"top_p": float(os.getenv("MOS_TOP_P", "0.9")),
2828
"top_k": int(os.getenv("MOS_TOP_K", "50")),
2929
"remove_think_prefix": True,
@@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
672672
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
673673
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
674674
},
675+
"mode": os.getenv("ASYNC_MODE", "sync"),
675676
},
676677
},
677678
"act_mem": {}

src/memos/api/routers/server_router.py

Lines changed: 234 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
1+
import json
12
import os
3+
import time
24
import traceback
35

6+
from datetime import datetime
47
from typing import TYPE_CHECKING, Any
58

69
from fastapi import APIRouter, HTTPException
10+
from fastapi.responses import StreamingResponse
711

812
from memos.api.config import APIConfig
913
from memos.api.product_models import (
@@ -32,8 +36,12 @@
3236
from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
3337
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
3438
from memos.mem_scheduler.schemas.general_schemas import (
39+
ADD_LABEL,
40+
MEM_READ_LABEL,
41+
PREF_ADD_LABEL,
3542
SearchMode,
3643
)
44+
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
3745
from memos.memories.textual.prefer_text_memory.config import (
3846
AdderConfigFactory,
3947
ExtractorConfigFactory,
@@ -233,6 +241,7 @@ def init_server():
233241
chat_llm=llm,
234242
process_llm=mem_reader.llm,
235243
db_engine=BaseDBManager.create_default_sqlite_engine(),
244+
mem_reader=mem_reader,
236245
)
237246
mem_scheduler.current_mem_cube = naive_mem_cube
238247
mem_scheduler.start()
@@ -477,6 +486,13 @@ def add_memories(add_req: APIADDRequest):
477486
if not target_session_id:
478487
target_session_id = "default_session"
479488

489+
# If text memory backend works in async mode, submit tasks to scheduler
490+
try:
491+
sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync")
492+
except Exception:
493+
sync_mode = "sync"
494+
logger.info(f"Add sync_mode mode is: {sync_mode}")
495+
480496
def _process_text_mem() -> list[dict[str, str]]:
481497
memories_local = mem_reader.get_memory(
482498
[add_req.messages],
@@ -485,6 +501,7 @@ def _process_text_mem() -> list[dict[str, str]]:
485501
"user_id": add_req.user_id,
486502
"session_id": target_session_id,
487503
},
504+
mode="fast" if sync_mode == "async" else "fine",
488505
)
489506
flattened_local = [mm for m in memories_local for mm in m]
490507
logger.info(f"Memory extraction completed for user {add_req.user_id}")
@@ -496,6 +513,34 @@ def _process_text_mem() -> list[dict[str, str]]:
496513
f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
497514
f"in session {add_req.session_id}: {mem_ids_local}"
498515
)
516+
if sync_mode == "async":
517+
try:
518+
message_item_read = ScheduleMessageItem(
519+
user_id=add_req.user_id,
520+
session_id=target_session_id,
521+
mem_cube_id=add_req.mem_cube_id,
522+
mem_cube=naive_mem_cube,
523+
label=MEM_READ_LABEL,
524+
content=json.dumps(mem_ids_local),
525+
timestamp=datetime.utcnow(),
526+
user_name=add_req.mem_cube_id,
527+
)
528+
mem_scheduler.submit_messages(messages=[message_item_read])
529+
logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}")
530+
except Exception as e:
531+
logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True)
532+
else:
533+
message_item_add = ScheduleMessageItem(
534+
user_id=add_req.user_id,
535+
session_id=target_session_id,
536+
mem_cube_id=add_req.mem_cube_id,
537+
mem_cube=naive_mem_cube,
538+
label=ADD_LABEL,
539+
content=json.dumps(mem_ids_local),
540+
timestamp=datetime.utcnow(),
541+
user_name=add_req.mem_cube_id,
542+
)
543+
mem_scheduler.submit_messages(messages=[message_item_add])
499544
return [
500545
{
501546
"memory": memory.memory,
@@ -508,27 +553,46 @@ def _process_text_mem() -> list[dict[str, str]]:
508553
def _process_pref_mem() -> list[dict[str, str]]:
509554
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
510555
return []
511-
pref_memories_local = naive_mem_cube.pref_mem.get_memory(
512-
[add_req.messages],
513-
type="chat",
514-
info={
515-
"user_id": add_req.user_id,
516-
"session_id": target_session_id,
517-
},
518-
)
519-
pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
520-
logger.info(
521-
f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
522-
f"in session {add_req.session_id}: {pref_ids_local}"
523-
)
524-
return [
525-
{
526-
"memory": memory.memory,
527-
"memory_id": memory_id,
528-
"memory_type": memory.metadata.preference_type,
529-
}
530-
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
531-
]
556+
# Follow async behavior similar to core.py: enqueue when async
557+
if sync_mode == "async":
558+
try:
559+
messages_list = [add_req.messages]
560+
message_item_pref = ScheduleMessageItem(
561+
user_id=add_req.user_id,
562+
session_id=target_session_id,
563+
mem_cube_id=add_req.mem_cube_id,
564+
mem_cube=naive_mem_cube,
565+
label=PREF_ADD_LABEL,
566+
content=json.dumps(messages_list),
567+
timestamp=datetime.utcnow(),
568+
)
569+
mem_scheduler.submit_messages(messages=[message_item_pref])
570+
logger.info("Submitted preference add to scheduler (async mode)")
571+
except Exception as e:
572+
logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True)
573+
return []
574+
else:
575+
pref_memories_local = naive_mem_cube.pref_mem.get_memory(
576+
[add_req.messages],
577+
type="chat",
578+
info={
579+
"user_id": add_req.user_id,
580+
"session_id": target_session_id,
581+
},
582+
)
583+
pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
584+
logger.info(
585+
f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
586+
f"in session {add_req.session_id}: {pref_ids_local}"
587+
)
588+
return [
589+
{
590+
"memory": memory.memory,
591+
"memory_id": memory_id,
592+
"memory_type": memory.metadata.preference_type,
593+
}
594+
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
595+
]
532596

533597
with ContextThreadPoolExecutor(max_workers=2) as executor:
534598
text_future = executor.submit(_process_text_mem)
@@ -542,6 +606,155 @@ def _process_pref_mem() -> list[dict[str, str]]:
542606
)
543607

544608

609+
@router.get("/scheduler/status", summary="Get scheduler running task count")
610+
def scheduler_status():
611+
"""
612+
Return current running tasks count from scheduler dispatcher.
613+
Shape is consistent with /scheduler/wait.
614+
"""
615+
try:
616+
running = mem_scheduler.dispatcher.get_running_tasks()
617+
running_count = len(running)
618+
now_ts = time.time()
619+
620+
return {
621+
"message": "ok",
622+
"data": {
623+
"running_tasks": running_count,
624+
"timestamp": now_ts,
625+
},
626+
}
627+
628+
except Exception as err:
629+
logger.error("Failed to get scheduler status: %s", traceback.format_exc())
630+
631+
raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
632+
633+
634+
@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
635+
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
636+
"""
637+
Block until scheduler has no running tasks, or timeout.
638+
We return a consistent structured payload so callers can
639+
tell whether this was a clean flush or a timeout.
640+
641+
Args:
642+
timeout_seconds: max seconds to wait
643+
poll_interval: seconds between polls
644+
"""
645+
start = time.time()
646+
try:
647+
while True:
648+
running = mem_scheduler.dispatcher.get_running_tasks()
649+
running_count = len(running)
650+
elapsed = time.time() - start
651+
652+
# success -> scheduler is idle
653+
if running_count == 0:
654+
return {
655+
"message": "idle",
656+
"data": {
657+
"running_tasks": 0,
658+
"waited_seconds": round(elapsed, 3),
659+
"timed_out": False,
660+
},
661+
}
662+
663+
# timeout check
664+
if elapsed > timeout_seconds:
665+
return {
666+
"message": "timeout",
667+
"data": {
668+
"running_tasks": running_count,
669+
"waited_seconds": round(elapsed, 3),
670+
"timed_out": True,
671+
},
672+
}
673+
674+
time.sleep(poll_interval)
675+
676+
except Exception as err:
677+
logger.error(
678+
"Failed while waiting for scheduler: %s",
679+
traceback.format_exc(),
680+
)
681+
raise HTTPException(
682+
status_code=500,
683+
detail="Failed while waiting for scheduler",
684+
) from err
685+
686+
687+
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
688+
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
689+
"""
690+
Stream scheduler progress via Server-Sent Events (SSE).
691+
692+
Contract:
693+
- We emit periodic heartbeat frames while tasks are still running.
694+
- Each heartbeat frame is JSON, prefixed with "data: ".
695+
- On final frame, we include status = "idle" or "timeout" and timed_out flag,
696+
with the same semantics as /scheduler/wait.
697+
698+
Example curl:
699+
curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5"
700+
"""
701+
702+
def event_generator():
703+
start = time.time()
704+
try:
705+
while True:
706+
running = mem_scheduler.dispatcher.get_running_tasks()
707+
running_count = len(running)
708+
elapsed = time.time() - start
709+
710+
# heartbeat frame
711+
heartbeat_payload = {
712+
"running_tasks": running_count,
713+
"elapsed_seconds": round(elapsed, 3),
714+
"status": "running" if running_count > 0 else "idle",
715+
}
716+
yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n"
717+
718+
# scheduler is idle -> final frame + break
719+
if running_count == 0:
720+
final_payload = {
721+
"running_tasks": 0,
722+
"elapsed_seconds": round(elapsed, 3),
723+
"status": "idle",
724+
"timed_out": False,
725+
}
726+
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
727+
break
728+
729+
# timeout -> final frame + break
730+
if elapsed > timeout_seconds:
731+
final_payload = {
732+
"running_tasks": running_count,
733+
"elapsed_seconds": round(elapsed, 3),
734+
"status": "timeout",
735+
"timed_out": True,
736+
}
737+
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
738+
break
739+
740+
time.sleep(poll_interval)
741+
742+
except Exception as e:
743+
err_payload = {
744+
"status": "error",
745+
"detail": "stream_failed",
746+
"exception": str(e),
747+
}
748+
logger.error(
749+
"Failed streaming scheduler wait: %s: %s",
750+
e,
751+
traceback.format_exc(),
752+
)
753+
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
754+
755+
return StreamingResponse(event_generator(), media_type="text/event-stream")
756+
757+
545758
@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
546759
def chat_complete(chat_req: APIChatCompleteRequest):
547760
"""Chat with MemOS for a specific user. Returns complete response (non-streaming)."""

src/memos/graph_dbs/nebular.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -439,6 +439,7 @@ def remove_oldest_memory(
439439
Args:
440440
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
441441
keep_latest (int): Number of latest WorkingMemory entries to keep.
442+
user_name(str): optional user_name.
442443
"""
443444
try:
444445
user_name = user_name if user_name else self.config.user_name
@@ -685,8 +686,7 @@ def get_node(
685686
Returns:
686687
dict: Node properties as key-value pairs, or None if not found.
687688
"""
688-
user_name = user_name if user_name else self.config.user_name
689-
filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"'
689+
filter_clause = f'n.id = "{id}"'
690690
return_fields = self._build_return_fields(include_embedding)
691691
gql = f"""
692692
MATCH (n@Memory)
@@ -730,16 +730,13 @@ def get_nodes(
730730
"""
731731
if not ids:
732732
return []
733-
734-
user_name = user_name if user_name else self.config.user_name
735-
where_user = f" AND n.user_name = '{user_name}'"
736733
# Safe formatting of the ID list
737734
id_list = ",".join(f'"{_id}"' for _id in ids)
738735

739736
return_fields = self._build_return_fields(include_embedding)
740737
query = f"""
741738
MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
742-
WHERE n.id IN [{id_list}] {where_user}
739+
WHERE n.id IN [{id_list}]
743740
RETURN {return_fields}
744741
"""
745742
nodes = []
@@ -1497,10 +1494,10 @@ def _ensure_space_exists(cls, tmp_client, cfg):
14971494
return
14981495

14991496
try:
1500-
res = tmp_client.execute("SHOW GRAPHS;")
1497+
res = tmp_client.execute("SHOW GRAPHS")
15011498
existing = {row.values()[0].as_string() for row in res}
15021499
if db_name not in existing:
1503-
tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;")
1500+
tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
15041501
logger.info(f"✅ Graph `{db_name}` created before session binding.")
15051502
else:
15061503
logger.debug(f"Graph `{db_name}` already exists.")

src/memos/graph_dbs/neo4j.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,7 @@ def remove_oldest_memory(
149149
Args:
150150
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
151151
keep_latest (int): Number of latest WorkingMemory entries to keep.
152+
user_name(str): optional user_name.
152153
"""
153154
user_name = user_name if user_name else self.config.user_name
154155
query = f"""

0 commit comments

Comments
 (0)