23 | 23 | from typing import Any, Literal |
24 | 24 | from unittest.mock import patch |
25 | 25 |
| 26 | +import anthropic |
26 | 27 | import cloudpickle |
27 | 28 | import httpx |
28 | 29 | import openai |
@@ -294,6 +295,131 @@ def __exit__(self, exc_type, exc_value, traceback): |
294 | 295 |             self.proc.kill() |
295 | 296 |
296 | 297 |
| 298 | +class RemoteAnthropicServer: |
| 299 | +    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need an API key |
| 300 | + |
| 301 | +    def __init__( |
| 302 | +        self, |
| 303 | +        model: str, |
| 304 | +        vllm_serve_args: list[str], |
| 305 | +        *, |
| 306 | +        env_dict: dict[str, str] | None = None, |
| 307 | +        seed: int | None = 0, |
| 308 | +        auto_port: bool = True, |
| 309 | +        max_wait_seconds: float | None = None, |
| 310 | +    ) -> None: |
| 311 | +        if auto_port: |
| 312 | +            if "-p" in vllm_serve_args or "--port" in vllm_serve_args: |
| 313 | +                raise ValueError( |
| 314 | +                    "You have manually specified the port when `auto_port=True`." |
| 315 | +                ) |
| 316 | +
| 317 | +            # Don't mutate the input args |
| 318 | +            vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())] |
| 319 | +        if seed is not None: |
| 320 | +            if "--seed" in vllm_serve_args: |
| 321 | +                raise ValueError( |
| 322 | +                    f"You have manually specified the seed when `seed={seed}`." |
| 323 | +                ) |
| 324 | +
| 325 | +            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] |
| 326 | +
| 327 | +        parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.") |
| 328 | +        subparsers = parser.add_subparsers(required=False, dest="subparser") |
| 329 | +        parser = ServeSubcommand().subparser_init(subparsers) |
| 330 | +        args = parser.parse_args(["--model", model, *vllm_serve_args]) |
| 331 | +        self.host = str(args.host or "localhost") |
| 332 | +        self.port = int(args.port) |
| 333 | +
| 334 | +        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None |
| 335 | +
| 336 | +        # download the model before starting the server to avoid timeout |
| 337 | +        is_local = os.path.isdir(model) |
| 338 | +        if not is_local: |
| 339 | +            engine_args = AsyncEngineArgs.from_cli_args(args) |
| 340 | +            model_config = engine_args.create_model_config() |
| 341 | +            load_config = engine_args.create_load_config() |
| 342 | +
| 343 | +            model_loader = get_model_loader(load_config) |
| 344 | +            model_loader.download_model(model_config) |
| 345 | +
| 346 | +        env = os.environ.copy() |
| 347 | +        # the current process might initialize cuda, |
| 348 | +        # to be safe, we should use spawn method |
| 349 | +        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" |
| 350 | +        if env_dict is not None: |
| 351 | +            env.update(env_dict) |
| 352 | +        self.proc = subprocess.Popen( |
| 353 | +            [ |
| 354 | +                sys.executable, |
| 355 | +                "-m", |
| 356 | +                "vllm.entrypoints.anthropic.api_server", |
| 357 | +                model, |
| 358 | +                *vllm_serve_args, |
| 359 | +            ], |
| 360 | +            env=env, |
| 361 | +            stdout=sys.stdout, |
| 362 | +            stderr=sys.stderr, |
| 363 | +        ) |
| 364 | +        max_wait_seconds = max_wait_seconds or 240 |
| 365 | +        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) |
| 366 | +
| 367 | +    def __enter__(self): |
| 368 | +        return self |
| 369 | +
| 370 | +    def __exit__(self, exc_type, exc_value, traceback): |
| 371 | +        self.proc.terminate() |
| 372 | +        try: |
| 373 | +            self.proc.wait(8) |
| 374 | +        except subprocess.TimeoutExpired: |
| 375 | +            # force kill if needed |
| 376 | +            self.proc.kill() |
| 377 | +
| 378 | +    def _wait_for_server(self, *, url: str, timeout: float): |
| 379 | +        # run health check |
| 380 | +        start = time.time() |
| 381 | +        while True: |
| 382 | +            try: |
| 383 | +                if requests.get(url).status_code == 200: |
| 384 | +                    break |
| 385 | +            except Exception: |
| 386 | +                # this exception can only be raised by requests.get, |
| 387 | +                # which means the server is not ready yet. |
| 388 | +                # the stack trace is not useful, so we suppress it |
| 389 | +                # by using `raise from None`. |
| 390 | +                result = self.proc.poll() |
| 391 | +                if result is not None and result != 0: |
| 392 | +                    raise RuntimeError("Server exited unexpectedly.") from None |
| 393 | +
| 394 | +                time.sleep(0.5) |
| 395 | +                if time.time() - start > timeout: |
| 396 | +                    raise RuntimeError("Server failed to start in time.") from None |
| 397 | +
| 398 | +    @property |
| 399 | +    def url_root(self) -> str: |
| 400 | +        return f"http://{self.host}:{self.port}" |
| 401 | +
| 402 | +    def url_for(self, *parts: str) -> str: |
| 403 | +        return self.url_root + "/" + "/".join(parts) |
| 404 | +
| 405 | +    def get_client(self, **kwargs): |
| 406 | +        if "timeout" not in kwargs: |
| 407 | +            kwargs["timeout"] = 600 |
| 408 | +        return anthropic.Anthropic( |
| 409 | +            base_url=self.url_for(), |
| 410 | +            api_key=self.DUMMY_API_KEY, |
| 411 | +            max_retries=0, |
| 412 | +            **kwargs, |
| 413 | +        ) |
| 414 | +
| 415 | +    def get_async_client(self, **kwargs): |
| 416 | +        if "timeout" not in kwargs: |
| 417 | +            kwargs["timeout"] = 600 |
| 418 | +        return anthropic.AsyncAnthropic( |
| 419 | +            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs |
| 420 | +        ) |
| 421 | + |
| 422 | + |
297 | 423 | def _test_completion( |
298 | 424 |     client: openai.OpenAI, |
299 | 425 |     model: str, |
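
For reference, a minimal usage sketch of the new fixture (a hypothetical test, not part of this diff; the model name, extra serve args, and prompt are illustrative placeholders):

```python
# Hypothetical usage of RemoteAnthropicServer; model name, serve args, and
# prompt are placeholders, not values from this commit.
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"


def test_basic_anthropic_message():
    # Launches vllm.entrypoints.anthropic.api_server in a subprocess and waits
    # for its /health endpoint before yielding the server handle.
    with RemoteAnthropicServer(MODEL, ["--max-model-len", "2048"]) as server:
        client = server.get_client()
        message = client.messages.create(
            model=MODEL,
            max_tokens=32,
            messages=[{"role": "user", "content": "Say hello."}],
        )
        # The Anthropic SDK returns a list of content blocks; the first
        # block of a plain text reply exposes `.text`.
        assert message.content[0].text
```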