diff --git a/build/builder.py b/build/builder.py
index 10d6c3717..6f893f94a 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -37,7 +37,7 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
     is_chat_model: bool = False
-    
+
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -77,15 +77,15 @@ def from_args(cls, args): # -> BuilderArgs:
            args.checkpoint_dir,
            args.dso_path,
            args.pte_path,
-            args.gguf_path
+            args.gguf_path,
        ]:
            path = str(path)
-            if path.endswith('/'):
+            if path.endswith("/"):
                path = path[:-1]
            path_basename = os.path.basename(path)
            if "chat" in path_basename:
                is_chat_model = True
-        
+
        return cls(
            checkpoint_path=args.checkpoint_path,
            checkpoint_dir=args.checkpoint_dir,
@@ -189,6 +189,7 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
     if is_et:
         builder_args.gguf_kwargs["load_as_quantized"] = False
 
+
 def _unset_gguf_kwargs(builder_args):
     builder_args.gguf_kwargs = None
 
@@ -264,6 +265,7 @@ def _load_model(builder_args):
 
     if builder_args.use_tp:
         from tp import apply_tp
+
         print("Applying tensor parallel to model ...")
         apply_tp(model)
 
diff --git a/build/gguf_loader.py b/build/gguf_loader.py
index 43603d2a7..bd40ba642 100644
--- a/build/gguf_loader.py
+++ b/build/gguf_loader.py
@@ -8,27 +8,21 @@
 import copy
 import logging
 import sys
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any
 
 import gguf
 
 import torch
-import torch.nn as nn
 
 wd = Path(__file__).parent.resolve()
 sys.path.append(str(wd))
 
-from gguf import GGUFValueType, ReaderTensor
-from quantize import (
-    group_dequantize_tensor_from_qparams,
-    pack_scales_and_zeros,
-    WeightOnlyInt4Linear,
-)
-
-from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
+from gguf import GGUFValueType
 from model import ModelArgs, Transformer
+from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear
+
+from build.gguf_util import Q4_0, to_float
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -116,9 +110,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     metadata = _get_metadata(reader)
 
     arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), "Only LLaMa models are supported by this converter."
+    assert arch == "llama", "Only LLaMa models are supported by this converter."
 
     model_args = ModelArgs(
         dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +131,13 @@
     return model
 
 
-def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(
+    gguf_file: str,
+    *,
+    load_state_dict: bool = True,
+    load_as_quantized: bool = True,
+    inner_k_tiles=8,
+) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with
     a state_dict that can be loaded into it.
diff --git a/build/model.py b/build/model.py
index ae6038707..e4bea09b6 100644
--- a/build/model.py
+++ b/build/model.py
@@ -248,6 +248,7 @@ def from_params(cls, params_path: str):
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
+
         model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
         if state_dict != {}:
             model.load_state_dict(state_dict, assign=True)
diff --git a/generate.py b/generate.py
index 735ee6e43..87efcd15e 100644
--- a/generate.py
+++ b/generate.py
@@ -7,7 +7,6 @@
 
 import itertools
 import logging
-import os
 import sys
 import time
 from dataclasses import dataclass
@@ -109,12 +108,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 
 
 def prefill(
-        model: Transformer,
-        x: torch.Tensor,
-        input_pos: torch.Tensor,
-        *,
-        sequential_prefill = True,
-        **sampling_kwargs
+    model: Transformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    *,
+    sequential_prefill=True,
+    **sampling_kwargs,
 ) -> torch.Tensor:
     logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
@@ -348,7 +347,7 @@ def _main(
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
     if generator_args.chat_mode and not builder_args.is_chat_model:
-        # This is not a log message, it's a dangerous condition message 
+        # This is not a log message, it's a dangerous condition message
         # that we must ensure is displayed
         print(
             """