Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/rig-integration/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,11 @@ anyhow = "1.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
toml = "0.8"
futures = "0.3"
futures = "0.3"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = [
"env-filter",
"std",
"fmt",
] }
tracing-appender = "0.2"
6 changes: 3 additions & 3 deletions examples/rig-integration/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ use futures::StreamExt;
use rig::{
agent::Agent,
message::Message,
providers::deepseek::DeepSeekCompletionModel,
streaming::{StreamingChat, StreamingChoice},
streaming::{StreamingChat, StreamingChoice, StreamingCompletionModel},
};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};

pub async fn cli_chatbot(chatbot: Agent<DeepSeekCompletionModel>) -> anyhow::Result<()> {
pub async fn cli_chatbot<M: StreamingCompletionModel>(chatbot: Agent<M>) -> anyhow::Result<()> {
let mut chat_log = vec![];

let mut output = BufWriter::new(tokio::io::stdout());
Expand All @@ -27,6 +26,7 @@ pub async fn cli_chatbot(chatbot: Agent<DeepSeekCompletionModel>) -> anyhow::Res
}
match chatbot.stream_chat(input, chat_log.clone()).await {
Ok(mut response) => {
tracing::info!(%input);
chat_log.push(Message::user(input));
stream_output_agent_start(&mut output).await?;
let mut message_buf = String::new();
Expand Down
42 changes: 23 additions & 19 deletions examples/rig-integration/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
use rig::{
embeddings::EmbeddingsBuilder,
providers::{
cohere,
deepseek::{self, DEEPSEEK_CHAT},
},
providers::{cohere, deepseek},
vector_store::in_memory_store::InMemoryVectorStore,
};
use tracing_appender::rolling::{RollingFileAppender, Rotation};
pub mod chat;
pub mod config;
pub mod mcp_adaptor;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let file_appender = RollingFileAppender::new(
Rotation::DAILY,
"logs",
format!("{}.log", env!("CARGO_CRATE_NAME")),
);
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive(tracing::Level::INFO.into()),
)
.with_writer(file_appender)
.with_file(false)
.with_ansi(false)
.init();

let config = config::Config::retrieve("config.toml").await?;
let openai_client = {
if let Some(key) = config.deepseek_key {
Expand All @@ -27,7 +41,7 @@ async fn main() -> anyhow::Result<()> {
}
};
let mcp_manager = config.mcp.create_manager().await?;
eprintln!(
tracing::info!(
"MCP Manager created, {} servers started",
mcp_manager.clients.len()
);
Expand All @@ -39,26 +53,16 @@ async fn main() -> anyhow::Result<()> {
.build()
.await?;
let store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |f| {
eprintln!("store tool {}", f.name);
tracing::info!("store tool {}", f.name);
f.name.clone()
});
let index = store.index(embedding_model);
let dpsk = openai_client
.agent(DEEPSEEK_CHAT)
.context(
r#"You are an assistant here to help the user to do some works.
When you need to use tools, you should select which tool is most appropriate to meet the user's requirement.
Follow these instructions closely.
1. Consider the user's request carefully and identify the core elements of the request.
2. Select which tool among those made available to you is appropriate given the context.
3. This is very important: never perform the operation yourself and never give me the direct result.
Always respond with the name of the tool that should be used and the appropriate inputs
in the following format:
Tool: <tool name>
Inputs: <list of inputs>"#,
)
.agent(deepseek::DEEPSEEK_CHAT)
.dynamic_tools(4, index, tool_set)
.build();

chat::cli_chatbot(dpsk).await?;

Ok(())
}
29 changes: 24 additions & 5 deletions examples/rig-integration/src/mcp_adaptor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use rig::tool::{ToolDyn as RigTool, ToolSet};
use rig::tool::{ToolDyn as RigTool, ToolEmbeddingDyn, ToolSet};
use rmcp::{
RoleClient,
model::{CallToolRequestParam, CallToolResult, Tool as McpTool},
Expand Down Expand Up @@ -49,13 +49,31 @@ impl RigTool for McpToolAdaptor {
.map_err(rig::tool::ToolError::JsonError)?,
})
.await
.inspect(|result| tracing::info!(?result))
.inspect_err(|error| tracing::error!(%error))
.map_err(|e| rig::tool::ToolError::ToolCallError(Box::new(e)))?;

Ok(convert_mcp_call_tool_result_to_string(call_mcp_tool_result))
})
}
}

impl ToolEmbeddingDyn for McpToolAdaptor {
fn context(&self) -> serde_json::Result<serde_json::Value> {
serde_json::to_value(self.tool.clone())
}

fn embedding_docs(&self) -> Vec<String> {
vec![
self.tool
.description
.as_deref()
.unwrap_or_default()
.to_string(),
]
}
}

pub struct McpManager {
pub clients: HashMap<String, RunningService<RoleClient, ()>>,
}
Expand All @@ -72,7 +90,7 @@ impl McpManager {
for result in results {
match result {
Err(e) => {
eprintln!("Failed to get tool set: {:?}", e);
tracing::error!(error = %e, "Failed to get tool set");
}
Ok(tools) => {
tool_set.add_tools(tools);
Expand All @@ -89,14 +107,15 @@ pub fn convert_mcp_call_tool_result_to_string(result: CallToolResult) -> String

pub async fn get_tool_set(server: ServerSink) -> anyhow::Result<ToolSet> {
let tools = server.list_all_tools().await?;
let mut tool_set = ToolSet::default();
let mut tool_builder = ToolSet::builder();
for tool in tools {
eprintln!("get tool: {}", tool.name);
tracing::info!("get tool: {}", tool.name);
let adaptor = McpToolAdaptor {
tool: tool.clone(),
server: server.clone(),
};
tool_set.add_tool(adaptor);
tool_builder = tool_builder.dynamic_tool(adaptor);
}
let tool_set = tool_builder.build();
Ok(tool_set)
}
Loading