Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public AgentServiceImpl(AgentDao agentDao) {
@Override
public PageData<AgentEntity> adminAgentList(Map<String, Object> params) {
IPage<AgentEntity> page = agentDao.selectPage(
getPage(params, "sort", true),
getPage(params, "agent_name", true),
new QueryWrapper<>());
return new PageData<>(page.getRecords(), page.getTotal());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ private void buildModuleConfig(
boolean isCache) {
Map<String, String> selectedModule = new HashMap<>();

String[] modelTypes = { "VAD", "ASR", "LLM", "TTS", "Memory", "Intent" };
String[] modelIds = { vadModelId, asrModelId, llmModelId, ttsModelId, memModelId, intentModelId };
String[] modelTypes = { "VAD", "ASR", "TTS", "Memory", "Intent", "LLM" };
String[] modelIds = { vadModelId, asrModelId, ttsModelId, memModelId, intentModelId, llmModelId };
String intentLLMModelId = null;

for (int i = 0; i < modelIds.length; i++) {
if (modelIds[i] == null) {
Expand All @@ -246,10 +247,27 @@ private void buildModuleConfig(
if ("TTS".equals(modelTypes[i]) && voice != null) {
((Map<String, Object>) model.getConfigJson()).put("private_voice", voice);
}
// 如果是Intent类型,且type=intent_llm,则给他添加附加模型
if ("Intent".equals(modelTypes[i])) {
Map<String, Object> map = (Map<String, Object>) model.getConfigJson();
if ("intent_llm".equals(map.get("type"))) {
intentLLMModelId = (String) map.get("llm");
if (intentLLMModelId != null && intentLLMModelId.equals(llmModelId)) {
intentLLMModelId = null;
}
}
}
// 如果是LLM类型,且intentLLMModelId不为空,则添加附加模型
if ("LLM".equals(modelTypes[i]) && intentLLMModelId != null) {
ModelConfigEntity intentLLM = modelConfigService.getModelById(intentLLMModelId, isCache);
typeConfig.put(intentLLM.getId(), intentLLM.getConfigJson());
}
}
result.put(modelTypes[i], typeConfig);

selectedModule.put(modelTypes[i], model.getId());
}

result.put("selected_module", selectedModule);
if (StringUtils.isNotBlank(prompt)) {
prompt = prompt.replace("{{assistant_name}}", "小智");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public PageData<UserShowDeviceListVO> page(DevicePageUserDTO dto) {
params.put(Constant.PAGE, dto.getPage());
params.put(Constant.LIMIT, dto.getLimit());
IPage<DeviceEntity> page = baseDao.selectPage(
getPage(params, "sort", true),
getPage(params, "mac_address", true),
// 定义查询条件
new QueryWrapper<DeviceEntity>()
// 必须设备关键词查找
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class SysParamsServiceImpl extends BaseServiceImpl<SysParamsDao, SysParam
@Override
public PageData<SysParamsDTO> page(Map<String, Object> params) {
IPage<SysParamsEntity> page = baseDao.selectPage(
getPage(params, Constant.CREATE_DATE, false),
getPage(params, null, false),
getWrapper(params));

return getPageData(page, SysParamsDTO.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public PageData<TimbreDetailsVO> page(TimbrePageDTO dto) {
params.put(Constant.PAGE, dto.getPage());
params.put(Constant.LIMIT, dto.getLimit());
IPage<TimbreEntity> page = baseDao.selectPage(
getPage(params, "sort", true),
getPage(params, null, true),
// 定义查询条件
new QueryWrapper<TimbreEntity>()
// 必须按照ttsID查找
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- 对0.3.0版本之前的参数进行修改
update `sys_params` set param_value = '.mp3;.wav;.p3' where param_code = 'plugins.play_music.music_ext';
update `ai_model_config` set config_json = '{\"type\": \"intent_llm\", \"llm\": \"LLM_ChatGLMLLM\"}' where id = 'Intent_intent_llm';
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,11 @@ databaseChangeLog:
changes:
- sqlFile:
encoding: utf8
path: classpath:db/changelog/202504112058.sql
path: classpath:db/changelog/202504112058.sql
- changeSet:
id: 202504131542
author: John
changes:
- sqlFile:
encoding: utf8
path: classpath:db/changelog/202504131542.sql
27 changes: 17 additions & 10 deletions main/xiaozhi-server/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ def __init__(

self.close_after_chat = False # 是否在聊天结束后关闭连接
self.use_function_call_mode = False
if self.config["selected_module"]["Intent"] == "function_call":
self.use_function_call_mode = True

async def handle_connection(self, ws):
try:
Expand Down Expand Up @@ -222,37 +220,37 @@ def _initialize_models(self):
)
if private_config.get("VAD", None) is not None:
init_vad = True
self.config["vad"] = private_config["VAD"]
self.config["VAD"] = private_config["VAD"]
self.config["selected_module"]["VAD"] = private_config["selected_module"][
"VAD"
]
if private_config.get("ASR", None) is not None:
init_asr = True
self.config["asr"] = private_config["ASR"]
self.config["ASR"] = private_config["ASR"]
self.config["selected_module"]["ASR"] = private_config["selected_module"][
"ASR"
]
if private_config.get("LLM", None) is not None:
init_llm = True
self.config["llm"] = private_config["LLM"]
self.config["LLM"] = private_config["LLM"]
self.config["selected_module"]["LLM"] = private_config["selected_module"][
"LLM"
]
if private_config.get("TTS", None) is not None:
init_tts = True
self.config["tts"] = private_config["TTS"]
self.config["TTS"] = private_config["TTS"]
self.config["selected_module"]["TTS"] = private_config["selected_module"][
"TTS"
]
if private_config.get("Memory", None) is not None:
init_memory = True
self.config["memory"] = private_config["Memory"]
self.config["Memory"] = private_config["Memory"]
self.config["selected_module"]["Memory"] = private_config[
"selected_module"
]["Memory"]
if private_config.get("Intent", None) is not None:
init_intent = True
self.config["intent"] = private_config["Intent"]
self.config["Intent"] = private_config["Intent"]
self.config["selected_module"]["Intent"] = private_config[
"selected_module"
]["Intent"]
Expand Down Expand Up @@ -287,17 +285,26 @@ def _initialize_memory(self):
self.memory.init_memory(device_id, self.llm)

def _initialize_intent(self):
if (
self.config["Intent"][self.config["selected_module"]["Intent"]]["type"]
== "function_call"
):
self.use_function_call_mode = True
"""初始化意图识别模块"""
# 获取意图识别配置
intent_config = self.config["Intent"]
intent_type = self.config["selected_module"]["Intent"]
intent_type = self.config["Intent"][self.config["selected_module"]["Intent"]][
"type"
]

# 如果使用 nointent,直接返回
if intent_type == "nointent":
return
# 使用 intent_llm 模式
elif intent_type == "intent_llm":
intent_llm_name = intent_config["intent_llm"]["llm"]
intent_llm_name = intent_config[self.config["selected_module"]["Intent"]][
"llm"
]

if intent_llm_name and intent_llm_name in self.config["LLM"]:
# 如果配置了专用LLM,则创建独立的LLM实例
Expand Down
4 changes: 3 additions & 1 deletion main/xiaozhi-server/core/handle/functionHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def register_nessary_functions(self):

def register_config_functions(self):
"""注册配置中的函数,可以不同客户端使用不同的配置"""
for func in self.config["Intent"]["function_call"].get("functions", []):
for func in self.config["Intent"][self.config["selected_module"]["Intent"]].get(
"functions", []
):
self.function_registry.register_function(func)

"""home assistant需要初始化提示词"""
Expand Down
5 changes: 5 additions & 0 deletions main/xiaozhi-server/core/handle/intentHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ async def process_intent_result(conn, intent_result, original_text):
if function_name == "continue_chat":
return False

if function_name == "play_music":
funcItem = conn.func_handler.get_function(function_name)
if not funcItem:
conn.func_handler.function_registry.register_function("play_music")

function_args = None
if "arguments" in intent_data["function_call"]:
function_args = intent_data["function_call"]["arguments"]
Expand Down
4 changes: 1 addition & 3 deletions main/xiaozhi-server/core/handle/iotHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,12 @@ def register_device_type(descriptor):

# 用于接受前端设备推送的搜索iot描述
async def handleIotDescriptors(conn, descriptors):
if not conn.use_function_call_mode:
return
wait_max_time = 5
while conn.func_handler is None or not conn.func_handler.finish_init:
await asyncio.sleep(1)
wait_max_time -= 1
if wait_max_time <= 0:
logger.bind(tag=TAG).error("连接对象没有func_handler")
logger.bind(tag=TAG).debug("连接对象没有func_handler")
return
"""处理物联网描述"""
functions_changed = False
Expand Down
10 changes: 9 additions & 1 deletion main/xiaozhi-server/core/providers/llm/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ def __init__(self, config):
self.base_url = config.get("base_url")
else:
self.base_url = config.get("url")
self.max_tokens = config.get("max_tokens", 500)
max_tokens = config.get("max_tokens")
if max_tokens is None or max_tokens == "":
max_tokens = 500

try:
max_tokens = int(max_tokens)
except (ValueError, TypeError):
max_tokens = 500
self.max_tokens = max_tokens

check_model_key("LLM", self.api_key)
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
Expand Down
3 changes: 2 additions & 1 deletion main/xiaozhi-server/core/providers/vad/silero.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def is_vad(self, conn, opus_packet):
audio_tensor = torch.from_numpy(audio_float32)

# 检测语音活动
speech_prob = self.model(audio_tensor, 16000).item()
with torch.no_grad():
speech_prob = self.model(audio_tensor, 16000).item()
client_have_voice = speech_prob >= self.vad_threshold

# 如果之前有声音,但本次没有声音,且与上次有声音的时间查已经超过了静默阈值,则认为已经说完一句话
Expand Down
49 changes: 29 additions & 20 deletions main/xiaozhi-server/plugins_func/functions/handle_exit_intent.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
from plugins_func.register import register_function,ToolType, ActionResponse, Action
from plugins_func.register import register_function, ToolType, ActionResponse, Action
from config.logger import setup_logging

TAG = __name__
logger = setup_logging()

handle_exit_intent_function_desc = {
"type": "function",
"function": {
"name": "handle_exit_intent",
"description": "当用户想结束对话或需要退出系统时调用",
"parameters": {
"type": "object",
"properties": {
"say_goodbye": {
"type": "string",
"description": "和用户友好结束对话的告别语"
}
},
"required": ["say_goodbye"]
}
"type": "function",
"function": {
"name": "handle_exit_intent",
"description": "当用户想结束对话或需要退出系统时调用",
"parameters": {
"type": "object",
"properties": {
"say_goodbye": {
"type": "string",
"description": "和用户友好结束对话的告别语",
}
}
},
"required": ["say_goodbye"],
},
},
}

@register_function('handle_exit_intent', handle_exit_intent_function_desc, ToolType.SYSTEM_CTL)
def handle_exit_intent(conn, say_goodbye: str):

@register_function(
"handle_exit_intent", handle_exit_intent_function_desc, ToolType.SYSTEM_CTL
)
def handle_exit_intent(conn, say_goodbye: str | None = None):
# 处理退出意图
try:
if say_goodbye is None:
say_goodbye = "再见,祝您生活愉快!"
conn.close_after_chat = True
logger.bind(tag=TAG).info(f"退出意图已处理:{say_goodbye}")
return ActionResponse(action=Action.RESPONSE, result="退出意图已处理", response=say_goodbye)
return ActionResponse(
action=Action.RESPONSE, result="退出意图已处理", response=say_goodbye
)
except Exception as e:
logger.bind(tag=TAG).error(f"处理退出意图错误: {e}")
return ActionResponse(action=Action.NONE, result="退出意图处理失败", response="")
return ActionResponse(
action=Action.NONE, result="退出意图处理失败", response=""
)
8 changes: 6 additions & 2 deletions main/xiaozhi-server/plugins_func/functions/hass_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

def append_devices_to_prompt(conn):
if conn.use_function_call_mode:
funcs = conn.config["Intent"]["function_call"].get("functions", [])
funcs = conn.config["Intent"][conn.config["selected_module"]["Intent"]].get(
"functions", []
)
if "hass_get_state" in funcs or "hass_set_state" in funcs:
prompt = "下面是我家智能设备,可以通过homeassistant控制\n"
devices = conn.config["plugins"]["home_assistant"].get("devices", [])
Expand All @@ -26,7 +28,9 @@ def initialize_hass_handler(conn):
global HASS_CACHE
if HASS_CACHE == {}:
if conn.use_function_call_mode:
funcs = conn.config["Intent"]["function_call"].get("functions", [])
funcs = conn.config["Intent"][conn.config["selected_module"]["Intent"]].get(
"functions", []
)
if "hass_get_state" in funcs or "hass_set_state" in funcs:
HASS_CACHE["base_url"] = conn.config["plugins"]["home_assistant"].get(
"base_url"
Expand Down