1+ import json
12import os
3+ import time
24import traceback
35
6+ from datetime import datetime
47from typing import TYPE_CHECKING , Any
58
69from fastapi import APIRouter , HTTPException
10+ from fastapi .responses import StreamingResponse
711
812from memos .api .config import APIConfig
913from memos .api .product_models import (
3236from memos .mem_scheduler .orm_modules .base_model import BaseDBManager
3337from memos .mem_scheduler .scheduler_factory import SchedulerFactory
3438from memos .mem_scheduler .schemas .general_schemas import (
39+ ADD_LABEL ,
40+ MEM_READ_LABEL ,
41+ PREF_ADD_LABEL ,
3542 SearchMode ,
3643)
44+ from memos .mem_scheduler .schemas .message_schemas import ScheduleMessageItem
3745from memos .memories .textual .prefer_text_memory .config import (
3846 AdderConfigFactory ,
3947 ExtractorConfigFactory ,
@@ -233,6 +241,7 @@ def init_server():
233241 chat_llm = llm ,
234242 process_llm = mem_reader .llm ,
235243 db_engine = BaseDBManager .create_default_sqlite_engine (),
244+ mem_reader = mem_reader ,
236245 )
237246 mem_scheduler .current_mem_cube = naive_mem_cube
238247 mem_scheduler .start ()
@@ -477,6 +486,13 @@ def add_memories(add_req: APIADDRequest):
477486 if not target_session_id :
478487 target_session_id = "default_session"
479488
489+ # If text memory backend works in async mode, submit tasks to scheduler
490+ try :
491+ sync_mode = getattr (naive_mem_cube .text_mem , "mode" , "sync" )
492+ except Exception :
493+ sync_mode = "sync"
494+ logger .info (f"Add sync_mode mode is: { sync_mode } " )
495+
480496 def _process_text_mem () -> list [dict [str , str ]]:
481497 memories_local = mem_reader .get_memory (
482498 [add_req .messages ],
@@ -485,6 +501,7 @@ def _process_text_mem() -> list[dict[str, str]]:
485501 "user_id" : add_req .user_id ,
486502 "session_id" : target_session_id ,
487503 },
504+ mode = "fast" if sync_mode == "async" else "fine" ,
488505 )
489506 flattened_local = [mm for m in memories_local for mm in m ]
490507 logger .info (f"Memory extraction completed for user { add_req .user_id } " )
@@ -496,6 +513,34 @@ def _process_text_mem() -> list[dict[str, str]]:
496513 f"Added { len (mem_ids_local )} memories for user { add_req .user_id } "
497514 f"in session { add_req .session_id } : { mem_ids_local } "
498515 )
516+ if sync_mode == "async" :
517+ try :
518+ message_item_read = ScheduleMessageItem (
519+ user_id = add_req .user_id ,
520+ session_id = target_session_id ,
521+ mem_cube_id = add_req .mem_cube_id ,
522+ mem_cube = naive_mem_cube ,
523+ label = MEM_READ_LABEL ,
524+ content = json .dumps (mem_ids_local ),
525+ timestamp = datetime .utcnow (),
526+ user_name = add_req .mem_cube_id ,
527+ )
528+ mem_scheduler .submit_messages (messages = [message_item_read ])
529+ logger .info (f"2105Submit messages!!!!!: { json .dumps (mem_ids_local )} " )
530+ except Exception as e :
531+ logger .error (f"Failed to submit async memory tasks: { e } " , exc_info = True )
532+ else :
533+ message_item_add = ScheduleMessageItem (
534+ user_id = add_req .user_id ,
535+ session_id = target_session_id ,
536+ mem_cube_id = add_req .mem_cube_id ,
537+ mem_cube = naive_mem_cube ,
538+ label = ADD_LABEL ,
539+ content = json .dumps (mem_ids_local ),
540+ timestamp = datetime .utcnow (),
541+ user_name = add_req .mem_cube_id ,
542+ )
543+ mem_scheduler .submit_messages (messages = [message_item_add ])
499544 return [
500545 {
501546 "memory" : memory .memory ,
@@ -508,27 +553,46 @@ def _process_text_mem() -> list[dict[str, str]]:
508553 def _process_pref_mem () -> list [dict [str , str ]]:
509554 if os .getenv ("ENABLE_PREFERENCE_MEMORY" , "false" ).lower () != "true" :
510555 return []
511- pref_memories_local = naive_mem_cube .pref_mem .get_memory (
512- [add_req .messages ],
513- type = "chat" ,
514- info = {
515- "user_id" : add_req .user_id ,
516- "session_id" : target_session_id ,
517- },
518- )
519- pref_ids_local : list [str ] = naive_mem_cube .pref_mem .add (pref_memories_local )
520- logger .info (
521- f"Added { len (pref_ids_local )} preferences for user { add_req .user_id } "
522- f"in session { add_req .session_id } : { pref_ids_local } "
523- )
524- return [
525- {
526- "memory" : memory .memory ,
527- "memory_id" : memory_id ,
528- "memory_type" : memory .metadata .preference_type ,
529- }
530- for memory_id , memory in zip (pref_ids_local , pref_memories_local , strict = False )
531- ]
556+ # Follow async behavior similar to core.py: enqueue when async
557+ if sync_mode == "async" :
558+ try :
559+ messages_list = [add_req .messages ]
560+ message_item_pref = ScheduleMessageItem (
561+ user_id = add_req .user_id ,
562+ session_id = target_session_id ,
563+ mem_cube_id = add_req .mem_cube_id ,
564+ mem_cube = naive_mem_cube ,
565+ label = PREF_ADD_LABEL ,
566+ content = json .dumps (messages_list ),
567+ timestamp = datetime .utcnow (),
568+ )
569+ mem_scheduler .submit_messages (messages = [message_item_pref ])
570+ logger .info ("Submitted preference add to scheduler (async mode)" )
571+ except Exception as e :
572+ logger .error (f"Failed to submit PREF_ADD task: { e } " , exc_info = True )
573+ return []
574+ else :
575+ pref_memories_local = naive_mem_cube .pref_mem .get_memory (
576+ [add_req .messages ],
577+ type = "chat" ,
578+ info = {
579+ "user_id" : add_req .user_id ,
580+ "session_id" : target_session_id ,
581+ },
582+ )
583+ pref_ids_local : list [str ] = naive_mem_cube .pref_mem .add (pref_memories_local )
584+ logger .info (
585+ f"Added { len (pref_ids_local )} preferences for user { add_req .user_id } "
586+ f"in session { add_req .session_id } : { pref_ids_local } "
587+ )
588+ return [
589+ {
590+ "memory" : memory .memory ,
591+ "memory_id" : memory_id ,
592+ "memory_type" : memory .metadata .preference_type ,
593+ }
594+ for memory_id , memory in zip (pref_ids_local , pref_memories_local , strict = False )
595+ ]
532596
533597 with ContextThreadPoolExecutor (max_workers = 2 ) as executor :
534598 text_future = executor .submit (_process_text_mem )
@@ -542,6 +606,155 @@ def _process_pref_mem() -> list[dict[str, str]]:
542606 )
543607
544608
@router.get("/scheduler/status", summary="Get scheduler running task count")
def scheduler_status():
    """Report how many tasks the scheduler dispatcher is running right now.

    The payload shape deliberately mirrors /scheduler/wait so clients can
    reuse the same parsing for both endpoints.
    """
    try:
        task_count = len(mem_scheduler.dispatcher.get_running_tasks())
        payload = {
            "running_tasks": task_count,
            "timestamp": time.time(),
        }
        return {"message": "ok", "data": payload}
    except Exception as err:
        logger.error("Failed to get scheduler status: %s", traceback.format_exc())
        raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
632+
633+
@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
    """
    Block until scheduler has no running tasks, or timeout.
    We return a consistent structured payload so callers can
    tell whether this was a clean flush or a timeout.

    Args:
        timeout_seconds: max seconds to wait
        poll_interval: seconds between polls (values <= 0 are clamped,
            see note below)

    Returns:
        {"message": "idle"|"timeout", "data": {running_tasks,
        waited_seconds, timed_out}}

    Raises:
        HTTPException: 500 when polling the dispatcher itself fails.
    """
    # Robustness: poll_interval is caller-controlled; a zero or negative
    # value would busy-spin a worker thread. Clamp to a small positive floor.
    poll_interval = max(poll_interval, 0.01)
    start = time.time()
    try:
        while True:
            running = mem_scheduler.dispatcher.get_running_tasks()
            running_count = len(running)
            elapsed = time.time() - start

            # success -> scheduler is idle
            if running_count == 0:
                return {
                    "message": "idle",
                    "data": {
                        "running_tasks": 0,
                        "waited_seconds": round(elapsed, 3),
                        "timed_out": False,
                    },
                }

            # timeout check
            if elapsed > timeout_seconds:
                return {
                    "message": "timeout",
                    "data": {
                        "running_tasks": running_count,
                        "waited_seconds": round(elapsed, 3),
                        "timed_out": True,
                    },
                }

            time.sleep(poll_interval)

    except Exception as err:
        logger.error(
            "Failed while waiting for scheduler: %s",
            traceback.format_exc(),
        )
        raise HTTPException(
            status_code=500,
            detail="Failed while waiting for scheduler",
        ) from err
685+
686+
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
    """
    Stream scheduler progress via Server-Sent Events (SSE).

    Contract:
    - Periodic heartbeat frames are emitted while tasks are still running.
    - Every frame is a JSON object on a single "data: " line.
    - The final frame carries status = "idle" or "timeout" plus a
      timed_out flag, matching the semantics of /scheduler/wait.

    Example curl:
        curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5"
    """

    def _sse(payload):
        # SSE wire format: one "data:" line terminated by a blank line.
        return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"

    def event_generator():
        started_at = time.time()
        try:
            while True:
                active = len(mem_scheduler.dispatcher.get_running_tasks())
                elapsed = time.time() - started_at

                # heartbeat frame
                yield _sse(
                    {
                        "running_tasks": active,
                        "elapsed_seconds": round(elapsed, 3),
                        "status": "running" if active > 0 else "idle",
                    }
                )

                # scheduler is idle -> final frame, stop streaming
                if active == 0:
                    yield _sse(
                        {
                            "running_tasks": 0,
                            "elapsed_seconds": round(elapsed, 3),
                            "status": "idle",
                            "timed_out": False,
                        }
                    )
                    return

                # timeout -> final frame, stop streaming
                if elapsed > timeout_seconds:
                    yield _sse(
                        {
                            "running_tasks": active,
                            "elapsed_seconds": round(elapsed, 3),
                            "status": "timeout",
                            "timed_out": True,
                        }
                    )
                    return

                time.sleep(poll_interval)

        except Exception as e:
            logger.error(
                "Failed streaming scheduler wait: %s: %s",
                e,
                traceback.format_exc(),
            )
            yield _sse(
                {
                    "status": "error",
                    "detail": "stream_failed",
                    "exception": str(e),
                }
            )

    return StreamingResponse(event_generator(), media_type="text/event-stream")
756+
757+
545758@router .post ("/chat/complete" , summary = "Chat with MemOS (Complete Response)" )
546759def chat_complete (chat_req : APIChatCompleteRequest ):
547760 """Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
0 commit comments