Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
acf65f4
feat: merge dev into main and change api-server (#376)
fridayL Oct 22, 2025
db432ad
feat: update log context
Oct 22, 2025
9502acc
feat: update log context
Oct 22, 2025
d74e628
feat: update mcp
Oct 23, 2025
32b2ac1
feat: update mcp
Oct 23, 2025
e4c6b92
feat: add error log
Oct 23, 2025
c27bd61
feat: add error log
Oct 23, 2025
6769b4c
feat: add error log
Oct 23, 2025
01547e1
feat: update log
Oct 24, 2025
a19584f
feat: add chat_time
Oct 24, 2025
8dfa338
feat: add chat_time
Oct 24, 2025
a91e3e2
feat: add chat_time
Oct 24, 2025
5b962e2
feat: update log
Oct 24, 2025
69a6e9a
feat: update log
Oct 24, 2025
d325a31
feat: update log
Oct 24, 2025
f0e5f5c
feat: update log
Oct 24, 2025
7fc8c05
feat: update log
Oct 24, 2025
185ed93
feat: add arms
Oct 26, 2025
f641b70
feat: add arms
Oct 26, 2025
d5c59a0
fix: format
Oct 26, 2025
b144470
fix: format
Oct 26, 2025
33921b7
feat: add dockerfile
Oct 26, 2025
49a9079
feat: add dockerfile
Oct 26, 2025
27c49b6
feat: add arms config
Oct 26, 2025
60c5dd8
feat: update log
Oct 26, 2025
3096321
feat: add sleep time
Oct 26, 2025
204efef
feat: add sleep time
Oct 26, 2025
e2c9cbf
fix: conflict
Oct 28, 2025
33a41e8
feat: update log
Oct 28, 2025
cf23174
feat: delete dockerfile
Oct 28, 2025
18e2eda
feat: delete dockerfile
Oct 28, 2025
f9a18a5
feat: update dockerfile
Oct 28, 2025
399e200
fix: conflict
Oct 28, 2025
1d4f3d1
fix: conflict
Oct 28, 2025
92be50b
feat: replace ThreadPool to context
Oct 28, 2025
8a1fd64
feat: add timed log
Oct 28, 2025
7d7f731
fix: conflict
Oct 28, 2025
eac3954
Merge branch 'dev' into feat/arms
fridayL Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/memos/api/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.requests import Request
from fastapi.responses import JSONResponse

Expand All @@ -10,9 +11,24 @@
class APIExceptionHandler:
"""Centralized exception handling for MemOS APIs."""

@staticmethod
async def validation_error_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors.

    Logs the validation failures and replies with HTTP 422 using the
    ``code`` / ``message`` / ``detail`` / ``data`` response envelope.

    Args:
        request: The incoming request (not used by this handler).
        exc: The validation error raised for the request.

    Returns:
        A ``JSONResponse`` with status 422 whose ``detail`` holds the
        validation errors.
    """
    # Imported locally so this handler does not depend on a module-level import.
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()
    logger.error(f"Validation error: {errors}")
    return JSONResponse(
        status_code=422,
        content={
            "code": 422,
            "message": "Parameter validation error",
            # Error entries may carry values that are not JSON-native
            # (e.g. exception objects under "ctx"); encode them first so
            # building the 422 response cannot itself fail.
            "detail": jsonable_encoder(errors),
            "data": None,
        },
    )

@staticmethod
async def value_error_handler(request: Request, exc: ValueError):
"""Handle ValueError exceptions globally."""
logger.error(f"ValueError: {exc}")
return JSONResponse(
status_code=400,
content={"code": 400, "message": str(exc), "data": None},
Expand All @@ -21,8 +37,17 @@ async def value_error_handler(request: Request, exc: ValueError):
@staticmethod
async def global_exception_handler(request: Request, exc: Exception):
"""Handle all unhandled exceptions globally."""
logger.exception("Unhandled error:")
logger.error(f"Exception: {exc}")
return JSONResponse(
status_code=500,
content={"code": 500, "message": str(exc), "data": None},
)

@staticmethod
async def http_error_handler(request: Request, exc: HTTPException):
    """Handle HTTP exceptions globally.

    Logs the error and replies with the exception's own status code using
    the ``code`` / ``message`` / ``data`` response envelope.

    Args:
        request: The incoming request (not used by this handler).
        exc: The raised ``HTTPException``.

    Returns:
        A ``JSONResponse`` carrying ``exc.status_code``, the stringified
        ``exc.detail`` as the message, and any headers set on the exception.
    """
    logger.error(f"HTTP error {exc.status_code}: {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": str(exc.detail), "data": None},
        # Forward headers attached to the exception (e.g. WWW-Authenticate);
        # without this they are silently dropped from the response.
        headers=getattr(exc, "headers", None),
    )
41 changes: 32 additions & 9 deletions src/memos/api/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Request context middleware for automatic trace_id injection.
"""

import time

from collections.abc import Callable

from starlette.middleware.base import BaseHTTPMiddleware
Expand Down Expand Up @@ -38,8 +40,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Extract or generate trace_id
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()

env = request.headers.get("x-env")
user_type = request.headers.get("x-user-type")
user_name = request.headers.get("x-user-name")
start_time = time.time()

# Create and set request context
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
context = RequestContext(
trace_id=trace_id,
api_path=request.url.path,
env=env,
user_type=user_type,
user_name=user_name,
)
set_request_context(context)

# Log request start with parameters
Expand All @@ -49,15 +62,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
if request.query_params:
params_log["query_params"] = dict(request.query_params)

logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")

# Process the request
response = await call_next(request)

# Log request completion with output
logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")

# Add trace_id to response headers for debugging
response.headers["x-trace-id"] = trace_id
try:
response = await call_next(request)
end_time = time.time()
if response.status_code == 200:
logger.info(
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
else:
logger.error(
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
except Exception as e:
end_time = time.time()
logger.error(
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
raise e

return response
6 changes: 3 additions & 3 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import traceback

from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any

from fastapi import APIRouter, HTTPException
Expand All @@ -22,6 +21,7 @@
from memos.configs.mem_scheduler import SchedulerConfigFactory
from memos.configs.reranker import RerankerConfigFactory
from memos.configs.vec_db import VectorDBConfigFactory
from memos.context.context import ContextThreadPoolExecutor
from memos.embedders.factory import EmbedderFactory
from memos.graph_dbs.factory import GraphStoreFactory
from memos.llms.factory import LLMFactory
Expand Down Expand Up @@ -370,7 +370,7 @@ def _search_pref():
)
return [_format_memory_item(data) for data in results]

with ThreadPoolExecutor(max_workers=2) as executor:
with ContextThreadPoolExecutor(max_workers=2) as executor:
text_future = executor.submit(_search_text)
pref_future = executor.submit(_search_pref)
text_formatted_memories = text_future.result()
Expand Down Expand Up @@ -532,7 +532,7 @@ def _process_pref_mem() -> list[dict[str, str]]:
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
]

with ThreadPoolExecutor(max_workers=2) as executor:
with ContextThreadPoolExecutor(max_workers=2) as executor:
text_future = executor.submit(_process_text_mem)
pref_future = executor.submit(_process_pref_mem)
text_response_data = text_future.result()
Expand Down
10 changes: 8 additions & 2 deletions src/memos/api/server_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError

from memos.api.exceptions import APIExceptionHandler
from memos.api.middleware.request_context import RequestContextMiddleware
Expand All @@ -21,8 +22,13 @@
# Include routers
app.include_router(server_router)

# Exception handlers
# Request validation failed
app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
# Invalid business code parameters
app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
# Business layer manual exception
app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler)
# Fallback for unknown errors
app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)


Expand Down
96 changes: 89 additions & 7 deletions src/memos/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,19 @@ class RequestContext:
This provides a Flask g-like object for FastAPI applications.
"""

def __init__(self, trace_id: str | None = None, api_path: str | None = None):
def __init__(
self,
trace_id: str | None = None,
api_path: str | None = None,
env: str | None = None,
user_type: str | None = None,
user_name: str | None = None,
):
self.trace_id = trace_id or "trace-id"
self.api_path = api_path
self.env = env
self.user_type = user_type
self.user_name = user_name
self._data: dict[str, Any] = {}

def set(self, key: str, value: Any) -> None:
Expand All @@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any:
return self._data.get(key, default)

def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in ("trace_id", "api_path"):
if name.startswith("_") or name in (
"trace_id",
"api_path",
"env",
"user_type",
"user_name",
):
super().__setattr__(name, value)
else:
if not hasattr(self, "_data"):
Expand All @@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any:

def to_dict(self) -> dict[str, Any]:
"""Convert context to dictionary."""
return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()}
return {
"trace_id": self.trace_id,
"api_path": self.api_path,
"env": self.env,
"user_type": self.user_type,
"user_name": self.user_name,
"data": self._data.copy(),
}


def set_request_context(context: RequestContext) -> None:
Expand Down Expand Up @@ -93,6 +116,36 @@ def get_current_api_path() -> str | None:
return None


def get_current_env() -> str | None:
    """
    Return the ``env`` value of the active request context.

    When no request context is set (or it is falsy), ``"prod"`` is returned
    instead; otherwise the result is whatever is stored under ``"env"``.
    """
    active = _request_context.get()
    if not active:
        return "prod"
    return active.get("env")


def get_current_user_type() -> str | None:
    """
    Return the ``user_type`` value of the active request context.

    When no request context is set (or it is falsy), ``"opensource"`` is
    returned instead; otherwise the result is whatever is stored under
    ``"user_type"``.
    """
    active = _request_context.get()
    if not active:
        return "opensource"
    return active.get("user_type")


def get_current_user_name() -> str | None:
    """
    Return the ``user_name`` value of the active request context.

    When no request context is set (or it is falsy), ``"memos"`` is returned
    instead; otherwise the result is whatever is stored under ``"user_name"``.
    """
    active = _request_context.get()
    if not active:
        return "memos"
    return active.get("user_name")


def get_current_context() -> RequestContext | None:
"""
Get the current request context.
Expand All @@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None:
context_dict = _request_context.get()
if context_dict:
ctx = RequestContext(
trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path")
trace_id=context_dict.get("trace_id"),
api_path=context_dict.get("api_path"),
env=context_dict.get("env"),
user_type=context_dict.get("user_type"),
user_name=context_dict.get("user_name"),
)
ctx._data = context_dict.get("data", {}).copy()
return ctx
Expand Down Expand Up @@ -141,14 +198,21 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):

self.main_trace_id = get_current_trace_id()
self.main_api_path = get_current_api_path()
self.main_env = get_current_env()
self.main_user_type = get_current_user_type()
self.main_user_name = get_current_user_name()
self.main_context = get_current_context()

def run(self):
# Create a new RequestContext with the main thread's trace_id
if self.main_context:
# Copy the context data
child_context = RequestContext(
trace_id=self.main_trace_id, api_path=self.main_context.api_path
trace_id=self.main_trace_id,
api_path=self.main_api_path,
env=self.main_env,
user_type=self.main_user_type,
user_name=self.main_user_name,
)
child_context._data = self.main_context._data.copy()

Expand All @@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
"""
main_trace_id = get_current_trace_id()
main_api_path = get_current_api_path()
main_env = get_current_env()
main_user_type = get_current_user_type()
main_user_name = get_current_user_name()
main_context = get_current_context()

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if main_context:
# Create and set new context in worker thread
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
child_context = RequestContext(
trace_id=main_trace_id,
api_path=main_api_path,
env=main_env,
user_type=main_user_type,
user_name=main_user_name,
)
child_context._data = main_context._data.copy()
set_request_context(child_context)

Expand All @@ -198,13 +271,22 @@ def map(
"""
main_trace_id = get_current_trace_id()
main_api_path = get_current_api_path()
main_env = get_current_env()
main_user_type = get_current_user_type()
main_user_name = get_current_user_name()
main_context = get_current_context()

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if main_context:
# Create and set new context in worker thread
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
child_context = RequestContext(
trace_id=main_trace_id,
api_path=main_api_path,
env=main_env,
user_type=main_user_type,
user_name=main_user_name,
)
child_context._data = main_context._data.copy()
set_request_context(child_context)

Expand Down
23 changes: 16 additions & 7 deletions src/memos/embedders/universal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

from memos.configs.embedder import UniversalAPIEmbedderConfig
from memos.embedders.base import BaseEmbedder
from memos.log import get_logger
from memos.utils import timed


logger = get_logger(__name__)


class UniversalAPIEmbedder(BaseEmbedder):
Expand All @@ -19,14 +24,18 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
api_key=config.api_key,
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
raise ValueError(f"Embeddings unsupported provider: {self.provider}")

@timed(log=True, log_prefix="EmbedderAPI")
def embed(self, texts: list[str]) -> list[list[float]]:
if self.provider == "openai" or self.provider == "azure":
response = self.client.embeddings.create(
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
input=texts,
)
return [r.embedding for r in response.data]
try:
response = self.client.embeddings.create(
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
input=texts,
)
return [r.embedding for r in response.data]
except Exception as e:
raise Exception(f"Embeddings request ended with error: {e}") from e
else:
raise ValueError(f"Unsupported provider: {self.provider}")
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
Loading