Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ public class OTAController {
public ResponseEntity<String> checkOTAVersion(
@RequestBody DeviceReportReqDTO deviceReportReqDTO,
@Parameter(name = "Device-Id", description = "设备唯一标识", required = true, in = ParameterIn.HEADER) @RequestHeader("Device-Id") String deviceId,
@Parameter(name = "Client-Id", description = "客户端标识", required = true, in = ParameterIn.HEADER) @RequestHeader("Client-Id") String clientId) {
if (StringUtils.isAnyBlank(deviceId, clientId)) {
@Parameter(name = "Client-Id", description = "客户端标识", required = false, in = ParameterIn.HEADER) @RequestHeader(value = "Client-Id", required = false) String clientId) {
if (StringUtils.isBlank(deviceId)) {
return createResponse(DeviceReportRespDTO.createError("Device ID is required"));
}
if (StringUtils.isBlank(clientId)) {
clientId = deviceId;
}
String macAddress = deviceReportReqDTO.getMacAddress();
boolean macAddressValid = NetworkUtil.isMacAddressValid(macAddress);
// 设备Id和Mac地址应是一致的, 并且必须需要application字段
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public Result<?> changePassword(@RequestBody PasswordDTO passwordDTO) {
/**
 * Builds the public configuration map and hands it to {@code Result.ok(...)}.
 * The map holds the keys "version" and "allowUserRegister"; the latter is read
 * from {@code sysUserService.getAllowUserRegister()}.
 *
 * @return the Result produced by {@code ok(config)}
 */
@Operation(summary = "公共配置")
public Result<Map<String, Object>> pubConfig() {
Map<String, Object> config = new HashMap<>();
// NOTE(review): "version" is put twice in a row; HashMap.put replaces the value of an
// existing key, so only "0.3.5" ends up in the map. The first put looks like a leftover
// old line — confirm and remove if so.
config.put("version", "0.3.4");
config.put("version", "0.3.5");
config.put("allowUserRegister", sysUserService.getAllowUserRegister());
return new Result<Map<String, Object>>().ok(config);
}
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/config/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from loguru import logger
from config.config_loader import load_config

SERVER_VERSION = "0.3.4"
SERVER_VERSION = "0.3.5"


def get_module_abbreviation(module_name, module_dict):
Expand Down
22 changes: 21 additions & 1 deletion main/xiaozhi-server/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ async def handle_connection(self, ws):
try:
# 获取并验证headers
self.headers = dict(ws.request.headers)

if self.headers.get("device-id", None) is None:
# 尝试从 URL 的查询参数中获取 device-id
from urllib.parse import parse_qs, urlparse

# 从 WebSocket 请求中获取路径
request_path = ws.request.path
if not request_path:
self.logger.bind(tag=TAG).error("无法获取请求路径")
return
parsed_url = urlparse(request_path)
query_params = parse_qs(parsed_url.query)
if "device-id" in query_params:
self.headers["device-id"] = query_params["device-id"][0]
self.headers["client-id"] = query_params["client-id"][0]
else:
self.logger.bind(tag=TAG).error(
"无法从请求头和URL查询参数中获取device-id"
)
return

# 获取客户端ip地址
self.client_ip = ws.remote_address[0]
self.logger.bind(tag=TAG).info(
Expand Down Expand Up @@ -184,7 +205,6 @@ def _initialize_components(self):
self._initialize_models()

"""加载提示词"""
self.prompt = self.config["prompt"]
self.dialogue.put(Message(role="system", content=self.prompt))

"""加载记忆"""
Expand Down
2 changes: 2 additions & 0 deletions main/xiaozhi-server/core/handle/helloHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ async def wakeupWordsResponse(conn):
"""唤醒词响应"""
wakeup_word = random.choice(WAKEUP_CONFIG["words"])
result = conn.llm.response_no_stream(conn.config["prompt"], wakeup_word)
if result is None or result == "":
return
tts_file = await asyncio.to_thread(conn.tts.to_tts, result)

if tts_file is not None and os.path.exists(tts_file):
Expand Down
52 changes: 27 additions & 25 deletions main/xiaozhi-server/core/handle/iotHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,33 +144,35 @@ def __init__(self, name, description, properties, methods):
self.methods = []

# 根据描述创建属性
for key, value in properties.items():
property_item = globals()[key] = {}
property_item["name"] = key
property_item["description"] = value["description"]
if value["type"] == "number":
property_item["value"] = 0
elif value["type"] == "boolean":
property_item["value"] = False
else:
property_item["value"] = ""
self.properties.append(property_item)
if properties is not None:
for key, value in properties.items():
property_item = globals()[key] = {}
property_item["name"] = key
property_item["description"] = value["description"]
if value["type"] == "number":
property_item["value"] = 0
elif value["type"] == "boolean":
property_item["value"] = False
else:
property_item["value"] = ""
self.properties.append(property_item)

# 根据描述创建方法
for key, value in methods.items():
method = globals()[key] = {}
method["description"] = value["description"]
method["name"] = key
for k, v in value["parameters"].items():
method[k] = {}
method[k]["description"] = v["description"]
if v["type"] == "number":
method[k]["value"] = 0
elif v["type"] == "boolean":
method[k]["value"] = False
else:
method[k]["value"] = ""
self.methods.append(method)
if methods is not None:
for key, value in methods.items():
method = globals()[key] = {}
method["description"] = value["description"]
method["name"] = key
for k, v in value["parameters"].items():
method[k] = {}
method[k]["description"] = v["description"]
if v["type"] == "number":
method[k]["value"] = 0
elif v["type"] == "boolean":
method[k]["value"] = False
else:
method[k]["value"] = ""
self.methods.append(method)


def register_device_type(descriptor):
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/handle/receiveAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def check_bind_device(conn):
digit = conn.bind_code[i]
num_path = f"config/assets/bind_code/{digit}.wav"
num_packets, _ = conn.tts.audio_to_opus_data(num_path)
conn.audio_play_queue.put((num_packets, text, i + 1))
conn.audio_play_queue.put((num_packets, None, i + 1))
except Exception as e:
logger.bind(tag=TAG).error(f"播放数字音频失败: {e}")
continue
Expand Down
1 change: 0 additions & 1 deletion main/xiaozhi-server/core/handle/sendAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import time
from core.utils.util import (
remove_punctuation_and_length,
get_string_no_punctuation_or_emoji,
)

Expand Down
23 changes: 14 additions & 9 deletions main/xiaozhi-server/core/providers/llm/coze/coze.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
import re
from core.providers.llm.base import LLMProviderBase
import os

# official coze sdk for Python [cozepy](https://github.com/coze-dev/coze-py)
from cozepy import COZE_CN_BASE_URL
from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageContentType, ChatEventType # noqa
from cozepy import (
Coze,
TokenAuth,
Message,
ChatStatus,
MessageContentType,
ChatEventType,
) # noqa

TAG = __name__
logger = setup_logging()
Expand All @@ -15,25 +23,22 @@
class LLMProvider(LLMProviderBase):
def __init__(self, config):
self.personal_access_token = config.get("personal_access_token")
self.bot_id = config.get("bot_id")
self.user_id = config.get("user_id")
self.bot_id = str(config.get("bot_id"))
self.user_id = str(config.get("user_id"))
self.session_conversation_map = {} # 存储session_id和conversation_id的映射

def response(self, session_id, dialogue):
coze_api_token = self.personal_access_token
coze_api_base = COZE_CN_BASE_URL

last_msg = next(m for m in reversed(dialogue) if m["role"] == "user")

coze = Coze(auth=TokenAuth(token=coze_api_token), base_url=coze_api_base)
conversation_id = self.session_conversation_map.get(session_id)

# 如果没有找到conversation_id,则创建新的对话
if not conversation_id:
conversation = coze.conversations.create(
messages=[
]
)
conversation = coze.conversations.create(messages=[])
conversation_id = conversation.id
self.session_conversation_map[session_id] = conversation_id # 更新映射

Expand All @@ -47,4 +52,4 @@ def response(self, session_id, dialogue):
):
if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
print(event.message.content, end="", flush=True)
yield event.message.content
yield event.message.content
14 changes: 9 additions & 5 deletions main/xiaozhi-server/core/providers/tts/fishspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Annotated
from datetime import datetime
from typing import Literal
from core.utils.util import check_model_key
from core.utils.util import check_model_key, parse_string_to_list
from core.providers.tts.base import TTSProviderBase
from config.logger import setup_logging

Expand Down Expand Up @@ -86,8 +86,8 @@ def __init__(self, config, delete_audio_file):
super().__init__(config, delete_audio_file)

self.reference_id = config.get("reference_id")
self.reference_audio = config.get("reference_audio", [])
self.reference_text = config.get("reference_text", [])
self.reference_audio = parse_string_to_list(config.get("reference_audio"))
self.reference_text = parse_string_to_list(config.get("reference_text"))
self.format = config.get("format", "wav")
self.channels = int(config.get("channels", 1))
self.rate = int(config.get("rate", 44100))
Expand All @@ -101,9 +101,13 @@ def __init__(self, config, delete_audio_file):
self.top_p = float(config.get("top_p", 0.7))
self.repetition_penalty = float(config.get("repetition_penalty", 1.2))
self.temperature = float(config.get("temperature", 0.7))
self.streaming = bool(config.get("streaming", False))
self.streaming = str(config.get("streaming", False)).lower() in (
"true",
"1",
"yes",
)
self.use_memory_cache = config.get("use_memory_cache", "on")
self.seed = config.get("seed")
self.seed = config.get("seed") or None
self.api_url = config.get("api_url", "http://127.0.0.1:8080/v1/tts")

def generate_filename(self, extension=".wav"):
Expand Down
30 changes: 25 additions & 5 deletions main/xiaozhi-server/core/providers/tts/gpt_sovits_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from config.logger import setup_logging
from datetime import datetime
from core.providers.tts.base import TTSProviderBase
from core.utils.util import parse_string_to_list

TAG = __name__
logger = setup_logging()
Expand All @@ -25,14 +26,33 @@ def __init__(self, config, delete_audio_file):
self.text_split_method = config.get("text_split_method", "cut0")
self.batch_size = int(config.get("batch_size", 1))
self.batch_threshold = float(config.get("batch_threshold", 0.75))
self.split_bucket = bool(config.get("split_bucket", True))
self.return_fragment = bool(config.get("return_fragment", False))

self.split_bucket = str(config.get("split_bucket", True)).lower() in (
"true",
"1",
"yes",
)
self.return_fragment = str(config.get("return_fragment", False)).lower() in (
"true",
"1",
"yes",
)
self.speed_factor = float(config.get("speed_factor", 1.0))
self.streaming_mode = bool(config.get("streaming_mode", False))
self.streaming_mode = str(config.get("streaming_mode", False)).lower() in (
"true",
"1",
"yes",
)
self.seed = int(config.get("seed", -1))
self.parallel_infer = bool(config.get("parallel_infer", True))
self.parallel_infer = str(config.get("parallel_infer", True)).lower() in (
"true",
"1",
"yes",
)
self.repetition_penalty = float(config.get("repetition_penalty", 1.35))
self.aux_ref_audio_paths = config.get("aux_ref_audio_paths", [])
self.aux_ref_audio_paths = parse_string_to_list(
config.get("aux_ref_audio_paths")
)

def generate_filename(self, extension=".wav"):
return os.path.join(
Expand Down
5 changes: 3 additions & 2 deletions main/xiaozhi-server/core/providers/tts/gpt_sovits_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from config.logger import setup_logging
from datetime import datetime
from core.providers.tts.base import TTSProviderBase
from core.utils.util import parse_string_to_list

TAG = __name__
logger = setup_logging()
Expand All @@ -22,9 +23,9 @@ def __init__(self, config, delete_audio_file):
self.temperature = float(config.get("temperature", 1.0))
self.cut_punc = config.get("cut_punc", "")
self.speed = float(config.get("speed", 1.0))
self.inp_refs = config.get("inp_refs", [])
self.inp_refs = parse_string_to_list(config.get("inp_refs"))
self.sample_steps = int(config.get("sample_steps", 32))
self.if_sr = bool(config.get("if_sr", False))
self.if_sr = str(config.get("if_sr", False)).lower() in ("true", "1", "yes")

def generate_filename(self, extension=".wav"):
return os.path.join(
Expand Down
3 changes: 2 additions & 1 deletion main/xiaozhi-server/core/providers/tts/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests
from datetime import datetime
from core.providers.tts.base import TTSProviderBase
from core.utils.util import parse_string_to_list


class TTSProvider(TTSProviderBase):
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, config, delete_audio_file):
**config.get("pronunciation_dict", {}),
}
self.audio_setting = {**defult_audio_setting, **config.get("audio_setting", {})}
self.timber_weights = config.get("timber_weights", [])
self.timber_weights = parse_string_to_list(config.get("timber_weights"))

if self.voice_id:
self.voice_setting["voice_id"] = self.voice_id
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/providers/tts/ttson.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, config, delete_audio_file):
self.to_lang = config.get("to_lang")
self.volume_change_dB = int(config.get("volume_change_dB", 0))
self.speed_factor = int(config.get("speed_factor", 1))
self.stream = bool(config.get("stream", False))
self.stream = str(config.get("stream", False)).lower() in ("true", "1", "yes")
self.output_file = config.get("output_dir")
self.pitch_factor = int(config.get("pitch_factor", 0))
self.format = config.get("format", "mp3")
Expand Down
22 changes: 20 additions & 2 deletions main/xiaozhi-server/core/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ def check_model_key(modelType, modelKey):
return True


def parse_string_to_list(value, separator=";"):
    """Normalize an input value into a list.

    Args:
        value: ``None``, a string, a list, or a tuple.
        separator: Delimiter used to split string input (default ``";"``).

    Returns:
        list: For a string, its non-empty, whitespace-stripped pieces.
        For a list or tuple, a new list holding the same items unmodified.
        For ``None``, an empty string, or any other type, an empty list.
    """
    if isinstance(value, str):
        # Strip each piece and drop blanks so "a; ;b;" yields ["a", "b"];
        # an empty string therefore yields [].
        pieces = (piece.strip() for piece in value.split(separator))
        return [piece for piece in pieces if piece]
    if isinstance(value, (list, tuple)):
        # Shallow copy: the caller gets its own list, so mutating the result
        # cannot alter the sequence that was passed in.
        return list(value)
    # None and unsupported types fall back to an empty list.
    return []


def check_ffmpeg_installed():
ffmpeg_installed = False
try:
Expand Down Expand Up @@ -231,7 +249,7 @@ def initialize_modules(
modules["tts"] = tts.create_instance(
tts_type,
config["TTS"][select_tts_module],
bool(config.get("delete_audio", True)),
str(config.get("delete_audio", True)).lower() in ("true", "1", "yes"),
)
logger.bind(tag=TAG).info(f"初始化组件: tts成功 {select_tts_module}")

Expand Down Expand Up @@ -302,7 +320,7 @@ def initialize_modules(
modules["asr"] = asr.create_instance(
asr_type,
config["ASR"][select_asr_module],
bool(config.get("delete_audio", True)),
str(config.get("delete_audio", True)).lower() in ("true", "1", "yes"),
)
logger.bind(tag=TAG).info(f"初始化组件: asr成功 {select_asr_module}")

Expand Down
Loading