2 changes: 1 addition & 1 deletion evaluation/scripts/utils/client.py
@@ -181,7 +181,7 @@ def search(self, query, user_id, top_k):
"mem_cube_id": user_id,
"conversation_id": "",
"top_k": top_k,
"mode": "mixture",
"mode": "fast",
"handle_pref_mem": False,
},
ensure_ascii=False,
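For reference, the evaluation client's search payload now pins retrieval to the fast path. Below is a minimal sketch (not part of the diff) of the request body around the changed line; the "query" and "user_id" fields and the function wrapper are assumptions based on the surrounding signature:

import json

# Hypothetical reconstruction of the client's search payload; field names
# mirror the diff, while the query/user values come from the function arguments.
def build_search_payload(query: str, user_id: str, top_k: int) -> str:
    return json.dumps(
        {
            "query": query,
            "user_id": user_id,
            "mem_cube_id": user_id,  # the cube id mirrors the user id in this client
            "conversation_id": "",
            "top_k": top_k,
            "mode": "fast",  # previously "mixture"
            "handle_pref_mem": False,
        },
        ensure_ascii=False,
    )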
3 changes: 2 additions & 1 deletion src/memos/api/config.py
@@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]:
return {
"model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"),
"temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")),
"max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")),
"top_p": float(os.getenv("MOS_TOP_P", "0.9")),
"top_k": int(os.getenv("MOS_TOP_K", "50")),
"remove_think_prefix": True,
@@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6),
"UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
},
"mode": os.getenv("ASYNC_MODE", "sync"),
},
},
"act_mem": {}
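This config change raises the default completion budget and threads a new ASYNC_MODE environment variable into the default cube config. A minimal sketch (not part of the diff) of how these env-driven defaults resolve; the launch command in the comment is illustrative:

import os

# Resolves the same environment variables the config reads above.
max_tokens = int(os.getenv("MOS_MAX_TOKENS", "8000"))  # default raised from 1024
text_mem_mode = os.getenv("ASYNC_MODE", "sync")  # "sync" (default) or "async"

# Example: enable async memory ingestion for a deployment (shell syntax, illustrative):
#   ASYNC_MODE=async MOS_MAX_TOKENS=8000 python -m memos.api ...
print(f"max_tokens={max_tokens}, text_mem_mode={text_mem_mode}")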
255 changes: 234 additions & 21 deletions src/memos/api/routers/server_router.py
@@ -1,9 +1,13 @@
import json
import os
import time
import traceback

from datetime import datetime
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from memos.api.config import APIConfig
from memos.api.product_models import (
@@ -32,8 +36,12 @@
from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.schemas.general_schemas import (
+    ADD_LABEL,
+    MEM_READ_LABEL,
+    PREF_ADD_LABEL,
SearchMode,
)
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.memories.textual.prefer_text_memory.config import (
AdderConfigFactory,
ExtractorConfigFactory,
@@ -233,6 +241,7 @@ def init_server():
chat_llm=llm,
process_llm=mem_reader.llm,
db_engine=BaseDBManager.create_default_sqlite_engine(),
+        mem_reader=mem_reader,
)
mem_scheduler.current_mem_cube = naive_mem_cube
mem_scheduler.start()
@@ -477,6 +486,13 @@ def add_memories(add_req: APIADDRequest):
if not target_session_id:
target_session_id = "default_session"

+    # If the text memory backend works in async mode, submit tasks to the scheduler
+    try:
+        sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync")
+    except Exception:
+        sync_mode = "sync"
+    logger.info(f"Add sync_mode is: {sync_mode}")

def _process_text_mem() -> list[dict[str, str]]:
memories_local = mem_reader.get_memory(
[add_req.messages],
@@ -485,6 +501,7 @@ def _process_text_mem() -> list[dict[str, str]]:
"user_id": add_req.user_id,
"session_id": target_session_id,
},
mode="fast" if sync_mode == "async" else "fine",
)
flattened_local = [mm for m in memories_local for mm in m]
logger.info(f"Memory extraction completed for user {add_req.user_id}")
@@ -496,6 +513,34 @@ def _process_text_mem() -> list[dict[str, str]]:
f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
f"in session {add_req.session_id}: {mem_ids_local}"
)
if sync_mode == "async":
try:
message_item_read = ScheduleMessageItem(
user_id=add_req.user_id,
session_id=target_session_id,
mem_cube_id=add_req.mem_cube_id,
mem_cube=naive_mem_cube,
label=MEM_READ_LABEL,
content=json.dumps(mem_ids_local),
timestamp=datetime.utcnow(),
user_name=add_req.mem_cube_id,
)
mem_scheduler.submit_messages(messages=[message_item_read])
logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}")
except Exception as e:
logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True)
else:
message_item_add = ScheduleMessageItem(
user_id=add_req.user_id,
session_id=target_session_id,
mem_cube_id=add_req.mem_cube_id,
mem_cube=naive_mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids_local),
timestamp=datetime.utcnow(),
user_name=add_req.mem_cube_id,
)
mem_scheduler.submit_messages(messages=[message_item_add])
return [
{
"memory": memory.memory,
@@ -508,27 +553,46 @@ def _process_text_mem() -> list[dict[str, str]]:
def _process_pref_mem() -> list[dict[str, str]]:
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
return []
-        pref_memories_local = naive_mem_cube.pref_mem.get_memory(
-            [add_req.messages],
-            type="chat",
-            info={
-                "user_id": add_req.user_id,
-                "session_id": target_session_id,
-            },
-        )
-        pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
-        logger.info(
-            f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
-            f"in session {add_req.session_id}: {pref_ids_local}"
-        )
-        return [
-            {
-                "memory": memory.memory,
-                "memory_id": memory_id,
-                "memory_type": memory.metadata.preference_type,
-            }
-            for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
-        ]
+        # Follow async behavior similar to core.py: enqueue when async
+        if sync_mode == "async":
+            try:
+                messages_list = [add_req.messages]
+                message_item_pref = ScheduleMessageItem(
+                    user_id=add_req.user_id,
+                    session_id=target_session_id,
+                    mem_cube_id=add_req.mem_cube_id,
+                    mem_cube=naive_mem_cube,
+                    label=PREF_ADD_LABEL,
+                    content=json.dumps(messages_list),
+                    timestamp=datetime.utcnow(),
+                )
+                mem_scheduler.submit_messages(messages=[message_item_pref])
+                logger.info("Submitted preference add to scheduler (async mode)")
+            except Exception as e:
+                logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True)
+            return []
+        else:
+            pref_memories_local = naive_mem_cube.pref_mem.get_memory(
+                [add_req.messages],
+                type="chat",
+                info={
+                    "user_id": add_req.user_id,
+                    "session_id": target_session_id,
+                },
+            )
+            pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
+            logger.info(
+                f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
+                f"in session {add_req.session_id}: {pref_ids_local}"
+            )
+            return [
+                {
+                    "memory": memory.memory,
+                    "memory_id": memory_id,
+                    "memory_type": memory.metadata.preference_type,
+                }
+                for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
+            ]

with ContextThreadPoolExecutor(max_workers=2) as executor:
text_future = executor.submit(_process_text_mem)
@@ -542,6 +606,155 @@ def _process_pref_mem() -> list[dict[str, str]]:
)


@router.get("/scheduler/status", summary="Get scheduler running task count")
def scheduler_status():
"""
    Return the current running task count from the scheduler dispatcher.
    The response shape is consistent with /scheduler/wait.
"""
try:
running = mem_scheduler.dispatcher.get_running_tasks()
running_count = len(running)
now_ts = time.time()

return {
"message": "ok",
"data": {
"running_tasks": running_count,
"timestamp": now_ts,
},
}

except Exception as err:
logger.error("Failed to get scheduler status: %s", traceback.format_exc())

raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err


@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
"""
    Block until the scheduler has no running tasks, or until the timeout elapses.
    Returns a consistent, structured payload so callers can
    tell whether this was a clean flush or a timeout.

Args:
timeout_seconds: max seconds to wait
poll_interval: seconds between polls
"""
start = time.time()
try:
while True:
running = mem_scheduler.dispatcher.get_running_tasks()
running_count = len(running)
elapsed = time.time() - start

# success -> scheduler is idle
if running_count == 0:
return {
"message": "idle",
"data": {
"running_tasks": 0,
"waited_seconds": round(elapsed, 3),
"timed_out": False,
},
}

# timeout check
if elapsed > timeout_seconds:
return {
"message": "timeout",
"data": {
"running_tasks": running_count,
"waited_seconds": round(elapsed, 3),
"timed_out": True,
},
}

time.sleep(poll_interval)

except Exception as err:
logger.error(
"Failed while waiting for scheduler: %s",
traceback.format_exc(),
)
raise HTTPException(
status_code=500,
detail="Failed while waiting for scheduler",
) from err


@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
"""
Stream scheduler progress via Server-Sent Events (SSE).

Contract:
- We emit periodic heartbeat frames while tasks are still running.
- Each heartbeat frame is JSON, prefixed with "data: ".
    - The final frame includes status = "idle" or "timeout" and a timed_out flag,
      with the same semantics as /scheduler/wait.

Example curl:
curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5"
"""

def event_generator():
start = time.time()
try:
while True:
running = mem_scheduler.dispatcher.get_running_tasks()
running_count = len(running)
elapsed = time.time() - start

# heartbeat frame
heartbeat_payload = {
"running_tasks": running_count,
"elapsed_seconds": round(elapsed, 3),
"status": "running" if running_count > 0 else "idle",
}
yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n"

# scheduler is idle -> final frame + break
if running_count == 0:
final_payload = {
"running_tasks": 0,
"elapsed_seconds": round(elapsed, 3),
"status": "idle",
"timed_out": False,
}
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
break

# timeout -> final frame + break
if elapsed > timeout_seconds:
final_payload = {
"running_tasks": running_count,
"elapsed_seconds": round(elapsed, 3),
"status": "timeout",
"timed_out": True,
}
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
break

time.sleep(poll_interval)

except Exception as e:
err_payload = {
"status": "error",
"detail": "stream_failed",
"exception": str(e),
}
logger.error(
"Failed streaming scheduler wait: %s: %s",
e,
traceback.format_exc(),
)
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"

return StreamingResponse(event_generator(), media_type="text/event-stream")


@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
def chat_complete(chat_req: APIChatCompleteRequest):
"""Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
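To exercise the new scheduler endpoints end to end, a small client can check /scheduler/status and then consume the SSE stream until the final frame. The sketch below (not part of the diff) assumes the requests package and the /product prefix shown in the docstring's curl example; the host is a placeholder:

import json

import requests  # assumed available; any streaming HTTP client works

BASE = "http://localhost:8000/product"  # placeholder host; prefix per the curl example

# One-shot status check; mirrors the /scheduler/status response shape.
status = requests.get(f"{BASE}/scheduler/status", timeout=10).json()
print("running tasks:", status["data"]["running_tasks"])

# Consume SSE frames until the final "idle" or "timeout" frame arrives.
resp = requests.get(
    f"{BASE}/scheduler/wait/stream",
    params={"timeout_seconds": 10, "poll_interval": 0.5},
    stream=True,
    timeout=30,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip the blank separators between SSE frames
    frame = json.loads(line[len("data: "):])
    print(frame)
    if frame.get("status") in ("idle", "timeout"):
        break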
13 changes: 5 additions & 8 deletions src/memos/graph_dbs/nebular.py
@@ -439,6 +439,7 @@ def remove_oldest_memory(
Args:
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
keep_latest (int): Number of latest WorkingMemory entries to keep.
+            user_name (str): Optional user name; defaults to the configured user.
"""
try:
user_name = user_name if user_name else self.config.user_name
@@ -685,8 +686,7 @@ def get_node(
Returns:
dict: Node properties as key-value pairs, or None if not found.
"""
-        user_name = user_name if user_name else self.config.user_name
-        filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"'
+        filter_clause = f'n.id = "{id}"'
return_fields = self._build_return_fields(include_embedding)
gql = f"""
MATCH (n@Memory)
@@ -730,16 +730,13 @@ def get_nodes(
"""
if not ids:
return []

-        user_name = user_name if user_name else self.config.user_name
-        where_user = f" AND n.user_name = '{user_name}'"
# Safe formatting of the ID list
id_list = ",".join(f'"{_id}"' for _id in ids)

return_fields = self._build_return_fields(include_embedding)
query = f"""
MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
WHERE n.id IN [{id_list}] {where_user}
+            WHERE n.id IN [{id_list}]
RETURN {return_fields}
"""
nodes = []
@@ -1497,10 +1494,10 @@ def _ensure_space_exists(cls, tmp_client, cfg):
return

try:
res = tmp_client.execute("SHOW GRAPHS;")
res = tmp_client.execute("SHOW GRAPHS")
existing = {row.values()[0].as_string() for row in res}
if db_name not in existing:
tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;")
tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
logger.info(f"✅ Graph `{db_name}` created before session binding.")
else:
logger.debug(f"Graph `{db_name}` already exists.")
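The net effect of the nebular.py changes is that node lookups now filter by id alone, relying on memory ids being globally unique rather than scoping by user_name. A simplified sketch (not part of the diff) of the post-change query construction; the build_* names and the default return_fields are stand-ins for the class's _build_return_fields machinery:

# Simplified reconstruction of the queries the revised methods emit.
def build_get_node_gql(node_id: str, return_fields: str = "n.id, n.memory") -> str:
    # user_name clause removed: the node id is assumed globally unique
    filter_clause = f'n.id = "{node_id}"'
    return f"MATCH (n@Memory) WHERE {filter_clause} RETURN {return_fields}"

def build_get_nodes_gql(ids: list[str], return_fields: str = "n.id, n.memory") -> str:
    id_list = ",".join(f'"{_id}"' for _id in ids)  # safe formatting of the ID list
    return (
        f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) "
        f"WHERE n.id IN [{id_list}] RETURN {return_fields}"
    )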
1 change: 1 addition & 0 deletions src/memos/graph_dbs/neo4j.py
@@ -149,6 +149,7 @@ def remove_oldest_memory(
Args:
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
keep_latest (int): Number of latest WorkingMemory entries to keep.
+            user_name (str): Optional user name; defaults to the configured user.
"""
user_name = user_name if user_name else self.config.user_name
query = f"""