Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"log"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"syscall"
)

// Request represents a JSON-RPC request
Expand Down Expand Up @@ -55,6 +57,9 @@ func main() {
// Start the MCP server in a separate process
cmd := exec.Command("go", "run", "./cmd/server/main.go")

// Set process group ID so we can kill the entire process group
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

// Connect stdin and stdout to the MCP server
stdin, err := cmd.StdinPipe()
if err != nil {
Expand All @@ -74,6 +79,50 @@ func main() {
log.Fatalf("Failed to start server: %v", err)
}

// Ensure server process is always terminated when client exits
defer func() {
if cmd.Process != nil {
// Kill the entire process group to ensure all child processes are terminated
pgid, err := syscall.Getpgid(cmd.Process.Pid)
if err == nil {
// Kill the process group (negative PID kills the process group)
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
log.Printf("Failed to kill process group: %v", err)
// Fallback to killing just the main process
cmd.Process.Kill()
} else {
log.Println("Server process group terminated")
}
} else {
// Fallback to killing just the main process
if err := cmd.Process.Kill(); err != nil {
log.Printf("Failed to kill server process: %v", err)
} else {
log.Println("Server process terminated")
}
}
cmd.Wait() // Wait for the process to actually exit
}
}()

// Set up signal handling to gracefully shutdown server on interruption
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigChan
log.Println("Received interrupt signal, shutting down...")
if cmd.Process != nil {
// Kill the entire process group
pgid, err := syscall.Getpgid(cmd.Process.Pid)
if err == nil {
syscall.Kill(-pgid, syscall.SIGTERM)
} else {
cmd.Process.Kill()
}
}
os.Exit(0)
}()

// Create a reader for the stdout
reader := bufio.NewReader(stdout)

Expand Down Expand Up @@ -162,40 +211,46 @@ func main() {
// Marshal the request to JSON
reqJSON, err := json.Marshal(req)
if err != nil {
log.Fatalf("Failed to marshal request: %v", err)
log.Printf("Failed to marshal request: %v", err)
return
}

fmt.Printf("Sending request: %s\n", string(reqJSON))

// Send the request to the server
_, err = stdin.Write(append(reqJSON, '\n'))
if err != nil {
log.Fatalf("Failed to send request: %v", err)
log.Printf("Failed to send request: %v", err)
return
}

// Read the response from the server
respJSON, err := reader.ReadBytes('\n')
if err != nil && err != io.EOF {
log.Fatalf("Failed to read response: %v", err)
log.Printf("Failed to read response: %v", err)
return
}

fmt.Printf("Received response: %s\n", string(respJSON))

// Unmarshal the response
var resp Response
if err := json.Unmarshal(respJSON, &resp); err != nil {
log.Fatalf("Failed to unmarshal response: %v", err)
log.Printf("Failed to unmarshal response: %v", err)
return
}

// Check for errors
if resp.Error != nil {
log.Fatalf("Error from server: %s", resp.Error.Message)
log.Printf("Error from server: %s", resp.Error.Message)
return
}

// Unmarshal the result
var result ToolResult
if err := json.Unmarshal(resp.Result, &result); err != nil {
log.Fatalf("Failed to unmarshal result: %v", err)
log.Printf("Failed to unmarshal result: %v", err)
return
}

// Process the result
Expand All @@ -207,11 +262,6 @@ func main() {
fmt.Println(item.Text)
}
}

// Terminate the server
if err := cmd.Process.Kill(); err != nil {
log.Printf("Failed to kill server process: %v", err)
}
}

func showUsage() {
Expand Down