diff --git a/cmd/client/main.go b/cmd/client/main.go index 364b9d5..64f5814 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -8,8 +8,10 @@ import ( "log" "os" "os/exec" + "os/signal" "strconv" "strings" + "syscall" ) // Request represents a JSON-RPC request @@ -55,6 +57,9 @@ func main() { // Start the MCP server in a separate process cmd := exec.Command("go", "run", "./cmd/server/main.go") + // Set process group ID so we can kill the entire process group + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // Connect stdin and stdout to the MCP server stdin, err := cmd.StdinPipe() if err != nil { @@ -74,6 +79,50 @@ func main() { log.Fatalf("Failed to start server: %v", err) } + // Ensure server process is always terminated when client exits + defer func() { + if cmd.Process != nil { + // Kill the entire process group to ensure all child processes are terminated + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err == nil { + // Kill the process group (negative PID kills the process group) + if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { + log.Printf("Failed to kill process group: %v", err) + // Fallback to killing just the main process + cmd.Process.Kill() + } else { + log.Println("Server process group terminated") + } + } else { + // Fallback to killing just the main process + if err := cmd.Process.Kill(); err != nil { + log.Printf("Failed to kill server process: %v", err) + } else { + log.Println("Server process terminated") + } + } + cmd.Wait() // Wait for the process to actually exit + } + }() + + // Set up signal handling to gracefully shutdown server on interruption + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigChan + log.Println("Received interrupt signal, shutting down...") + if cmd.Process != nil { + // Kill the entire process group + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err == nil { + syscall.Kill(-pgid, syscall.SIGTERM) + } else { + cmd.Process.Kill() + } + } + os.Exit(0) + }() + // Create a reader for the stdout reader := bufio.NewReader(stdout) @@ -162,7 +211,8 @@ func main() { // Marshal the request to JSON reqJSON, err := json.Marshal(req) if err != nil { - log.Fatalf("Failed to marshal request: %v", err) + log.Printf("Failed to marshal request: %v", err) + return } fmt.Printf("Sending request: %s\n", string(reqJSON)) @@ -170,13 +220,15 @@ func main() { // Send the request to the server _, err = stdin.Write(append(reqJSON, '\n')) if err != nil { - log.Fatalf("Failed to send request: %v", err) + log.Printf("Failed to send request: %v", err) + return } // Read the response from the server respJSON, err := reader.ReadBytes('\n') if err != nil && err != io.EOF { - log.Fatalf("Failed to read response: %v", err) + log.Printf("Failed to read response: %v", err) + return } fmt.Printf("Received response: %s\n", string(respJSON)) @@ -184,18 +236,21 @@ func main() { // Unmarshal the response var resp Response if err := json.Unmarshal(respJSON, &resp); err != nil { - log.Fatalf("Failed to unmarshal response: %v", err) + log.Printf("Failed to unmarshal response: %v", err) + return } // Check for errors if resp.Error != nil { - log.Fatalf("Error from server: %s", resp.Error.Message) + log.Printf("Error from server: %s", resp.Error.Message) + return } // Unmarshal the result var result ToolResult if err := json.Unmarshal(resp.Result, &result); err != nil { - log.Fatalf("Failed to unmarshal result: %v", err) + log.Printf("Failed to unmarshal result: %v", err) + return } // Process the result @@ -207,11 +262,6 @@ func main() { fmt.Println(item.Text) } } - - // Terminate the server - if err := cmd.Process.Kill(); err != nil { - log.Printf("Failed to kill server process: %v", err) - } } func showUsage() {