From e5a1fa840b66e7628ef1a3ef19629026c2dd98ee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 04:16:04 +0000 Subject: [PATCH 01/15] Fix a bug in 1D shape --- cacheflow/models/attention.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 34edeec02cbc..8b2fb3f85b2a 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -47,9 +47,8 @@ def multi_query_kv_attention( max_s=max_prompt_len, causal=True, )[0] - num_tokens = prefix_sum[-1] # FIXME(woosuk): Unnecessary copy. Optimize this. - output[:num_tokens].copy_(out, non_blocking=True) + output.copy_(out, non_blocking=True) def single_query_cached_kv_attention( self, @@ -108,8 +107,13 @@ def forward( # Compute the attention op for prompts. if input_metadata.num_prompts > 0: + num_prompt_tokens = sum(input_metadata.prompt_lens) self.multi_query_kv_attention( - output, query, key, value, input_metadata.prompt_lens) + output[:num_prompt_tokens], + query[:num_prompt_tokens], + key[:num_prompt_tokens], + value[:num_prompt_tokens], + input_metadata.prompt_lens) # Wait until the cache op is done. if cache_event is not None: From 342275fdcd6bc2ba5332335bdb0a53e46b2011e0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 04:16:15 +0000 Subject: [PATCH 02/15] Minor --- cacheflow/models/input_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index 86cc2e8f1f5a..77f25054e38a 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -24,7 +24,7 @@ def __init__( self.num_prompts = len(prompt_lens) self.num_generation_tokens = context_lens.shape[0] - self.num_valid_tokens = len(slot_mapping) + self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: self.max_num_blocks_per_seq = block_tables.shape[1] else: From b91a2fada7090cd6b7cdf9d1f26ef1e32f2737b7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 04:19:52 +0000 Subject: [PATCH 03/15] Minor --- cacheflow/models/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 8b2fb3f85b2a..7c77db5a819b 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -113,7 +113,8 @@ def forward( query[:num_prompt_tokens], key[:num_prompt_tokens], value[:num_prompt_tokens], - input_metadata.prompt_lens) + input_metadata.prompt_lens, + ) # Wait until the cache op is done. 
if cache_event is not None: From d78e2fb637472dd55be51b8766ca132e6ded61e9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 09:32:17 +0000 Subject: [PATCH 04/15] [WIP] Add memory analyzer --- cacheflow/models/__init__.py | 6 ++- cacheflow/models/memory_analyzer.py | 58 +++++++++++++++++++++++++++++ cacheflow/models/model_utils.py | 13 +++++-- 3 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 cacheflow/models/memory_analyzer.py diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py index 498101b53fdd..b42d06e74341 100644 --- a/cacheflow/models/__init__.py +++ b/cacheflow/models/__init__.py @@ -1,8 +1,12 @@ from cacheflow.models.input_metadata import InputMetadata +from cacheflow.models.memory_analyzer import compute_max_num_cpu_blocks +from cacheflow.models.memory_analyzer import compute_max_num_gpu_blocks from cacheflow.models.model_utils import get_model __all__ = [ - 'get_model', 'InputMetadata', + 'compute_max_num_cpu_blocks', + 'compute_max_num_gpu_blocks', + 'get_model', ] diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py new file mode 100644 index 000000000000..e4751a6949b4 --- /dev/null +++ b/cacheflow/models/memory_analyzer.py @@ -0,0 +1,58 @@ +from typing import Union + +import torch +from transformers import AutoConfig + +from cacheflow.models.model_utils import get_torch_dtype + +GB = 1 << 30 + + +def compute_max_num_gpu_blocks( + model_name: str, + max_num_batched_tokens: int, + block_size: int, + dtype: Union[torch.dtype, str], +) -> int: + torch_dtype = get_torch_dtype(dtype) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + + # FIXME(woosuk) + config = AutoConfig.from_pretrained(model_name) + num_layers = config.num_hidden_layers + num_heads = config.num_attention_heads + hidden_size = config.hidden_size + vocab_size = config.vocab_size + + total_memory = torch.cuda.get_device_properties(0).total_memory + total_memory = int(0.975 * total_memory) + + param_size = (num_layers * 12 * hidden_size * hidden_size * dtype_size + + vocab_size * hidden_size * dtype_size) + mha_act_size = num_heads * max_num_batched_tokens * max_num_batched_tokens * dtype_size + ffn_act_size = 4 * hidden_size * max_num_batched_tokens * dtype_size + # Conservative estimate of the peak activation size. 
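+    # (For example, with illustrative OPT-13B-scale values -- num_heads=40,
+    #  hidden_size=5120, max_num_batched_tokens=2048, fp16 -- this evaluates to
+    #  mha_act_size = 320 MiB and ffn_act_size = 80 MiB, so act_size = 960 MiB.)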
+ act_size = 3 * max(mha_act_size, ffn_act_size) + workspace_size = 1 * GB + + max_cache_size = total_memory - (param_size + act_size + workspace_size) + max_num_blocks = max_cache_size // (num_layers * 2 * block_size * hidden_size * dtype_size) + return max_num_blocks + + +def compute_max_num_cpu_blocks( + swap_space: int, + model_name: str, + block_size: int, + dtype: Union[torch.dtype, str], +) -> int: + torch_dtype = get_torch_dtype(dtype) + dtype_size = torch.tensor([], dtype=torch_dtype).element_size() + + config = AutoConfig.from_pretrained(model_name) + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + + max_cache_size = swap_space * GB + max_num_blocks = max_cache_size // (num_layers * 2 * block_size * hidden_size * dtype_size) + return max_num_blocks diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 0e7e4d3b2dd0..560b8eb84be8 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -17,14 +17,19 @@ } -def get_model( - model_name: str, - dtype: Union[torch.dtype, str], -) -> nn.Module: +def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: if isinstance(dtype, str): torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] else: torch_dtype = dtype + return torch_dtype + + +def get_model( + model_name: str, + dtype: Union[torch.dtype, str], +) -> nn.Module: + torch_dtype = get_torch_dtype(dtype) for model_class, hf_model in MODEL_CLASSES.items(): if model_class in model_name: model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype) From 2649eb5befefe424c7a1bcc9415036823b299aca Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 09:34:49 +0000 Subject: [PATCH 05/15] Automatically config GPU/CPU blocks --- server.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index d70dab01abd4..368e51ce4862 100644 --- a/server.py +++ b/server.py @@ -3,6 +3,8 @@ from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler +from cacheflow.models import compute_max_num_cpu_blocks +from cacheflow.models import compute_max_num_gpu_blocks from cacheflow.worker.controller import Controller parser = argparse.ArgumentParser(description='CacheFlow server') @@ -11,14 +13,28 @@ parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks. -parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)') -parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)') +parser.add_argument('--swap-space', type=int, default=16, + help='The CPU memory space in GiB pinned for swapping (per GPU)') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. 
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') args = parser.parse_args() def main(): + num_gpu_blocks = compute_max_num_gpu_blocks( + model_name=args.model, + max_num_batched_tokens=2048, + block_size=args.block_size, + dtype=args.dtype, + ) + num_cpu_blocks = compute_max_num_cpu_blocks( + swap_space=args.swap_space, + model_name=args.model, + block_size=args.block_size, + dtype=args.dtype, + ) + print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') + # Create a controller for each node. controllers: List[Controller] = [] for i in range(args.num_nodes): @@ -27,8 +43,8 @@ def main(): num_workers=args.num_workers, model_name=args.model, block_size=args.block_size, - num_gpu_blocks=args.num_gpu_blocks, - num_cpu_blocks=args.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, dtype=args.dtype, ) controllers.append(controller) @@ -44,8 +60,8 @@ def main(): frontend=frontend, controllers=controllers, block_size=args.block_size, - num_gpu_blocks=args.num_gpu_blocks, - num_cpu_blocks=args.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, ) # Connect the controllers. for i in range(len(controllers) - 1): From 1ae7420f20cf6c859989de6aa6f3892e2863543c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 3 Mar 2023 09:35:21 +0000 Subject: [PATCH 06/15] Remove TODO --- server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server.py b/server.py index 368e51ce4862..477da5b3ff97 100644 --- a/server.py +++ b/server.py @@ -12,7 +12,6 @@ parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes') parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') -# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks. parser.add_argument('--swap-space', type=int, default=16, help='The CPU memory space in GiB pinned for swapping (per GPU)') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. From 350ed273bfec55a3127f3e4dc37ddd2f79f992e1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 05:37:20 +0000 Subject: [PATCH 07/15] Add max_num_batched_tokens argument --- cacheflow/master/scheduler.py | 6 +++--- server.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 7f2ca1455fc4..0d1b8f9c36ab 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -9,8 +9,6 @@ from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceStatus -_MAX_NUM_BATCHED_TOKENS = 2048 - class Scheduler: @@ -21,12 +19,14 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + max_num_batched_tokens: int, ) -> None: self.frontend = frontend self.controllers = controllers self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks self.num_cpu_blocks = num_cpu_blocks + self.max_num_batched_tokens = max_num_batched_tokens # Create the block space manager. 
self.block_manager = BlockSpaceManager( @@ -164,7 +164,7 @@ def step(self) -> None: num_prompt_tokens = seq_group.seqs[0].get_len() if self.block_manager.can_allocate(seq_group): if (num_batched_tokens + num_prompt_tokens - <= _MAX_NUM_BATCHED_TOKENS): + <= self.max_num_batched_tokens): self._allocate(seq_group) num_batched_tokens += num_prompt_tokens continue diff --git a/server.py b/server.py index 483441b82a09..1246880f995d 100644 --- a/server.py +++ b/server.py @@ -18,6 +18,7 @@ parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') +parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens') args = parser.parse_args() @@ -64,6 +65,7 @@ def main(): block_size=args.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, + max_num_batched_tokens=args.max_batch_size, ) # Connect the controllers. for i in range(len(controllers) - 1): From 6f5b41bb27ae1c15d482320d356e39d9db6221fa Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 05:47:29 +0000 Subject: [PATCH 08/15] Minor --- cacheflow/models/model_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 642ee23d4cce..61f657313759 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -7,11 +7,11 @@ from cacheflow.models.opt import OPTForCausalLM -MODEL_CLASSES = { +_MODEL_CLASSES = { 'opt': OPTForCausalLM, } -STR_DTYPE_TO_TORCH_DTYPE = { +_STR_DTYPE_TO_TORCH_DTYPE = { 'half': torch.half, 'float': torch.float, 'float16': torch.float16, @@ -21,7 +21,7 @@ def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: if isinstance(dtype, str): - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] else: torch_dtype = dtype return torch_dtype @@ -32,9 +32,10 @@ def get_model( dtype: Union[torch.dtype, str], ) -> nn.Module: torch_dtype = get_torch_dtype(dtype) - for model_class, hf_model in MODEL_CLASSES.items(): + for model_class, hf_model in _MODEL_CLASSES.items(): if model_class in model_name: - model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype) + model = hf_model.from_pretrained( + model_name, torch_dtype=torch_dtype) return model.eval() raise ValueError(f'Invalid model name: {model_name}') From 2d03918d2fc234f2f8131521ff8ed44323bef484 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 06:42:30 +0000 Subject: [PATCH 09/15] Minor --- cacheflow/models/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 61f657313759..f73352bb45f5 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -37,7 +37,7 @@ def get_model( model = hf_model.from_pretrained( model_name, torch_dtype=torch_dtype) return model.eval() - raise ValueError(f'Invalid model name: {model_name}') + raise ValueError(f'Unsupported model name: {model_name}') def set_seed(seed: int) -> None: From 8ec00fe41b779e0d0ec247080e7afd53c445faea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:11:23 +0000 Subject: [PATCH 10/15] Refactor model utils --- cacheflow/models/utils.py | 43 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) 
create mode 100644 cacheflow/models/utils.py diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py new file mode 100644 index 000000000000..4b705bf7d969 --- /dev/null +++ b/cacheflow/models/utils.py @@ -0,0 +1,43 @@ +from typing import Union + +import random + +import numpy as np +import psutil +import torch + +_STR_DTYPE_TO_TORCH_DTYPE = { + 'half': torch.half, + 'float': torch.float, + 'float16': torch.float16, + 'float32': torch.float32, +} + + +def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: + if isinstance(dtype, str): + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] + else: + torch_dtype = dtype + return torch_dtype + + +def get_dtype_size(dtype: Union[torch.dtype, str]) -> int: + torch_dtype = get_torch_dtype(dtype) + return torch.tensor([], dtype=torch_dtype).element_size() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def get_gpu_memory(gpu: int = 0) -> int: + return torch.cuda.get_device_properties(gpu).total_memory + + +def get_cpu_memory() -> int: + return psutil.virtual_memory().total From 84203fc325facd30bc456f333113b98209af6563 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:11:43 +0000 Subject: [PATCH 11/15] Re-implement memory analyzer --- cacheflow/models/memory_analyzer.py | 177 +++++++++++++++++++--------- cacheflow/models/model_utils.py | 42 +++---- 2 files changed, 141 insertions(+), 78 deletions(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index e4751a6949b4..ed8c69a0e0ae 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -1,58 +1,125 @@ -from typing import Union - import torch from transformers import AutoConfig -from cacheflow.models.model_utils import get_torch_dtype - -GB = 1 << 30 - - -def compute_max_num_gpu_blocks( - model_name: str, - max_num_batched_tokens: int, - block_size: int, - dtype: Union[torch.dtype, str], -) -> int: - torch_dtype = get_torch_dtype(dtype) - dtype_size = torch.tensor([], dtype=torch_dtype).element_size() - - # FIXME(woosuk) - config = AutoConfig.from_pretrained(model_name) - num_layers = config.num_hidden_layers - num_heads = config.num_attention_heads - hidden_size = config.hidden_size - vocab_size = config.vocab_size - - total_memory = torch.cuda.get_device_properties(0).total_memory - total_memory = int(0.975 * total_memory) - - param_size = (num_layers * 12 * hidden_size * hidden_size * dtype_size - + vocab_size * hidden_size * dtype_size) - mha_act_size = num_heads * max_num_batched_tokens * max_num_batched_tokens * dtype_size - ffn_act_size = 4 * hidden_size * max_num_batched_tokens * dtype_size - # Conservative estimate of the peak activation size. 
- act_size = 3 * max(mha_act_size, ffn_act_size) - workspace_size = 1 * GB - - max_cache_size = total_memory - (param_size + act_size + workspace_size) - max_num_blocks = max_cache_size // (num_layers * 2 * block_size * hidden_size * dtype_size) - return max_num_blocks - - -def compute_max_num_cpu_blocks( - swap_space: int, - model_name: str, - block_size: int, - dtype: Union[torch.dtype, str], -) -> int: - torch_dtype = get_torch_dtype(dtype) - dtype_size = torch.tensor([], dtype=torch_dtype).element_size() - - config = AutoConfig.from_pretrained(model_name) - num_layers = config.num_hidden_layers - hidden_size = config.hidden_size - - max_cache_size = swap_space * GB - max_num_blocks = max_cache_size // (num_layers * 2 * block_size * hidden_size * dtype_size) - return max_num_blocks +from cacheflow.models.utils import get_cpu_memory +from cacheflow.models.utils import get_dtype_size +from cacheflow.models.utils import get_gpu_memory + +_GiB = 1 << 30 + + +class CacheFlowMemoryAnalyzer: + + def get_max_num_gpu_blocks( + self, + max_num_batched_tokens: int, + memory_utilization: float, + ) -> int: + raise NotImplementedError() + + def get_max_num_cpu_blocks( + self, + memory_utilization: float, + ) -> int: + raise NotImplementedError() + + +class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): + + def __init__( + self, + model_name: str, + block_size: int, + dtype: torch.dtype, + ) -> None: + self.model_name = model_name + self.block_size = block_size + self.dtype = dtype + + # TODO(woosuk): Support tensor parallelism. + config = AutoConfig.from_pretrained(model_name) + self.num_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // self.num_heads + self.ffn_size = config.ffn_dim + self.embedding_size = config.word_embed_proj_dim + self.vocab_size = config.vocab_size + self.max_position = config.max_position_embeddings + + def _get_param_size(self) -> int: + # TODO(woosuk): Support tensor parallelism. + word_embedding = self.vocab_size * self.embedding_size + if self.embedding_size != self.vocab_size: + # Project in/out. + word_embedding += 2 * self.embedding_size * self.vocab_size + position_embedding = self.max_position * self.hidden_size + + ln1 = 2 * self.hidden_size + q = self.hidden_size * self.hidden_size + self.hidden_size + k = self.hidden_size * self.hidden_size + self.hidden_size + v = self.hidden_size * self.hidden_size + self.hidden_size + out = self.hidden_size * self.hidden_size + self.hidden_size + mha = ln1 + q + k + v + out + + ln2 = 2 * self.hidden_size + ffn1 = self.hidden_size * self.ffn_size + self.ffn_size + ffn2 = self.ffn_size * self.hidden_size + self.hidden_size + ffn = ln2 + ffn1 + ffn2 + + total = (word_embedding + position_embedding + + self.num_layers * (mha + ffn)) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def _get_max_act_size( + self, + max_num_batched_tokens: int, + ) -> int: + # TODO(woosuk): Support tensor parallelism. + # NOTE: We approxmiately calculate the maximum activation size by + # 1) finding a maximum activation tensor size, and + # 2) multiplying it by 4. + # Here, we assume FlashAttention is used, and thus the attention maps + # are not materialized in GPU DRAM. 
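+        # (For example, with illustrative OPT-13B-scale values --
+        #  hidden_size=5120, ffn_size=20480, max_num_batched_tokens=2048, fp16 --
+        #  the FFN term dominates and this bound comes to
+        #  4 * 2048 * 20480 * 2 bytes = 320 MiB.)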
+ qkv = 3 * (max_num_batched_tokens * self.hidden_size) + ffn = max_num_batched_tokens * self.ffn_size + max_act = 4 * max(qkv, ffn) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * max_act + + def _get_workspace_size(self) -> int: + return 1 * _GiB + + def _get_cache_block_size(self) -> int: + key_cache_block = self.block_size * self.num_heads * self.head_size + value_cache_block = self.block_size * self.num_heads * self.head_size + total = self.num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def get_max_num_gpu_blocks( + self, + max_num_batched_tokens: int, + memory_utilization: float = 0.95, + ) -> int: + # NOTE(woosuk): This assumes that the machine has homogeneous GPUs. + gpu_memory = get_gpu_memory() + usable_memory = int(memory_utilization * gpu_memory) + + param_size = self._get_param_size() + act_size = self._get_max_act_size(max_num_batched_tokens) + workspace_size = self._get_workspace_size() + + max_cache_size = usable_memory - (param_size + act_size + workspace_size) + max_num_blocks = max_cache_size // self._get_cache_block_size() + return max_num_blocks + + def get_max_num_cpu_blocks( + self, + memory_utilization: float = 0.25, + ) -> int: + cpu_memory = get_cpu_memory() + usable_memory = int(memory_utilization * cpu_memory) + max_num_blocks = usable_memory // self._get_cache_block_size() + return max_num_blocks diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index f73352bb45f5..98ff6d44ebb0 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,38 +1,29 @@ -import random from typing import Union -import numpy as np import torch import torch.nn as nn +from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer +from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer from cacheflow.models.opt import OPTForCausalLM +from cacheflow.models.utils import get_torch_dtype -_MODEL_CLASSES = { + +_MODELS = { 'opt': OPTForCausalLM, } -_STR_DTYPE_TO_TORCH_DTYPE = { - 'half': torch.half, - 'float': torch.float, - 'float16': torch.float16, - 'float32': torch.float32, +_MEMORY_ANALYZERS = { + 'opt': OPTMemoryAnalyzer, } -def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: - if isinstance(dtype, str): - torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] - else: - torch_dtype = dtype - return torch_dtype - - def get_model( model_name: str, dtype: Union[torch.dtype, str], ) -> nn.Module: torch_dtype = get_torch_dtype(dtype) - for model_class, hf_model in _MODEL_CLASSES.items(): + for model_class, hf_model in _MODELS.items(): if model_class in model_name: model = hf_model.from_pretrained( model_name, torch_dtype=torch_dtype) @@ -40,9 +31,14 @@ def get_model( raise ValueError(f'Unsupported model name: {model_name}') -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) +def get_memory_analyzer( + model_name: str, + block_size: int, + dtype: Union[torch.dtype, str], +) -> CacheFlowMemoryAnalyzer: + torch_dtype = get_torch_dtype(dtype) + for model_class, memory_analyzer in _MEMORY_ANALYZERS.items(): + if model_class in model_name: + return memory_analyzer( + model_name, block_size, torch_dtype) + raise ValueError(f'Unsupported model name: {model_name}') From 96b216cf20a85ddc487c0be54c8b1f2fbea98d42 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:11:56 +0000 Subject: [PATCH 
12/15] Fix __init__ --- cacheflow/models/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py index 52ce7f374e5c..cd8f134a5a74 100644 --- a/cacheflow/models/__init__.py +++ b/cacheflow/models/__init__.py @@ -1,14 +1,12 @@ from cacheflow.models.input_metadata import InputMetadata -from cacheflow.models.memory_analyzer import compute_max_num_cpu_blocks -from cacheflow.models.memory_analyzer import compute_max_num_gpu_blocks +from cacheflow.models.model_utils import get_memory_analyzer from cacheflow.models.model_utils import get_model -from cacheflow.models.model_utils import set_seed +from cacheflow.models.utils import set_seed __all__ = [ 'InputMetadata', - 'compute_max_num_cpu_blocks', - 'compute_max_num_gpu_blocks', + 'get_memory_analyzer', 'get_model', 'set_seed', ] From c89d44029b715062d1e890f4db0da83c3d52842e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:12:27 +0000 Subject: [PATCH 13/15] Use memory analyzer in server.py --- server.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/server.py b/server.py index 1246880f995d..b740724c373f 100644 --- a/server.py +++ b/server.py @@ -3,8 +3,7 @@ from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler -from cacheflow.models import compute_max_num_cpu_blocks -from cacheflow.models import compute_max_num_gpu_blocks +from cacheflow.models import get_memory_analyzer from cacheflow.worker.controller import Controller parser = argparse.ArgumentParser(description='CacheFlow server') @@ -12,8 +11,6 @@ parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes') parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') -parser.add_argument('--swap-space', type=int, default=16, - help='The CPU memory space in GiB pinned for swapping (per GPU)') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). @@ -23,18 +20,14 @@ def main(): - num_gpu_blocks = compute_max_num_gpu_blocks( - model_name=args.model, - max_num_batched_tokens=2048, - block_size=args.block_size, - dtype=args.dtype, - ) - num_cpu_blocks = compute_max_num_cpu_blocks( - swap_space=args.swap_space, + memory_analyzer = get_memory_analyzer( model_name=args.model, block_size=args.block_size, dtype=args.dtype, ) + num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks( + max_num_batched_tokens=args.max_batch_size) + num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks() print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') # Create a controller for each node. From f5d1e2cdaf224bdf501d49f64696abc82d1b96e3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:12:38 +0000 Subject: [PATCH 14/15] Add psutil to README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ae1ff6aa665..7cce45b9efbc 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## Installation ```bash -pip install cmake torch transformers +pip install psutil numpy torch transformers pip install flash-attn # This may take up to 10 mins. pip install -e . 
``` From cc63c24a48addae3f5321db344bdd0d8dfbf586b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 12 Mar 2023 07:22:16 +0000 Subject: [PATCH 15/15] Fix comment --- cacheflow/models/memory_analyzer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index ed8c69a0e0ae..6af7b25f60b3 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -78,10 +78,10 @@ def _get_max_act_size( ) -> int: # TODO(woosuk): Support tensor parallelism. # NOTE: We approxmiately calculate the maximum activation size by - # 1) finding a maximum activation tensor size, and + # 1) estimating the maximum activation tensor size during inference, and # 2) multiplying it by 4. - # Here, we assume FlashAttention is used, and thus the attention maps - # are not materialized in GPU DRAM. + # Here, we assume that FlashAttention is used and + # thus the attention maps are never materialized in GPU DRAM. qkv = 3 * (max_num_batched_tokens * self.hidden_size) ffn = max_num_batched_tokens * self.ffn_size max_act = 4 * max(qkv, ffn)
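
The memory analyzer introduced in [PATCH 11/15] and wired into server.py in [PATCH 13/15] sizes the KV cache automatically instead of relying on hand-tuned block counts. The standalone sketch below reproduces the GPU block-count arithmetic from OPTMemoryAnalyzer so the numbers can be checked by hand; the OPT-13B-scale configuration and the 40 GiB GPU are illustrative assumptions, not values taken from the patches.

    # Back-of-the-envelope version of OPTMemoryAnalyzer.get_max_num_gpu_blocks(),
    # using illustrative numbers only.
    GiB = 1 << 30
    dtype_size = 2                     # fp16
    block_size = 8                     # server.py default
    max_num_batched_tokens = 2048      # --max-batch-size default

    # Assumed OPT-13B-scale configuration (illustrative, not read from the patches).
    num_layers, hidden_size, num_heads = 40, 5120, 40
    head_size = hidden_size // num_heads
    ffn_size, vocab_size, max_position = 20480, 50272, 2048

    # Parameters: embeddings plus per-layer attention and FFN weights/biases.
    embeddings = (vocab_size + max_position) * hidden_size
    mha = 2 * hidden_size + 4 * (hidden_size * hidden_size + hidden_size)
    ffn = (2 * hidden_size + hidden_size * ffn_size + ffn_size
           + ffn_size * hidden_size + hidden_size)
    param_size = dtype_size * (embeddings + num_layers * (mha + ffn))

    # Peak activations: 4x the larger of the QKV and FFN intermediate tensors.
    act_size = dtype_size * 4 * max(3 * max_num_batched_tokens * hidden_size,
                                    max_num_batched_tokens * ffn_size)

    # One cache block holds keys and values for block_size tokens in every layer.
    cache_block_size = dtype_size * num_layers * 2 * block_size * num_heads * head_size

    gpu_memory = 40 * GiB              # assumed 40 GiB device, illustration only
    free = int(0.95 * gpu_memory) - param_size - act_size - 1 * GiB  # 1 GiB workspace
    print('# GPU blocks:', free // cache_block_size)  # ~2000 blocks, ~16K cached tokens

In server.py these numbers come from get_memory_analyzer(model_name=args.model, block_size=args.block_size, dtype=args.dtype) followed by get_max_num_gpu_blocks(max_num_batched_tokens=args.max_batch_size) and get_max_num_cpu_blocks(), exactly as [PATCH 13/15] wires it up.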