
Fix lints #262


Merged
merged 1 commit on Apr 18, 2024
10 changes: 6 additions & 4 deletions build/builder.py
@@ -37,7 +37,7 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False
is_chat_model: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
@@ -77,15 +77,15 @@ def from_args(cls, args): # -> BuilderArgs:
args.checkpoint_dir,
args.dso_path,
args.pte_path,
args.gguf_path
args.gguf_path,
]:
path = str(path)
if path.endswith('/'):
if path.endswith("/"):
path = path[:-1]
path_basename = os.path.basename(path)
if "chat" in path_basename:
is_chat_model = True

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
@@ -189,6 +189,7 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
if is_et:
builder_args.gguf_kwargs["load_as_quantized"] = False


def _unset_gguf_kwargs(builder_args):
builder_args.gguf_kwargs = None

@@ -264,6 +265,7 @@ def _load_model(builder_args):

if builder_args.use_tp:
from tp import apply_tp

print("Applying tensor parallel to model ...")
apply_tp(model)

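For context, the reformatted block in from_args normalizes each candidate model path and flags chat models by filename. A minimal sketch of that logic in isolation, using a hypothetical helper name (_looks_like_chat_model) that is not part of this PR:

import os

def _looks_like_chat_model(path) -> bool:
    # Mirror BuilderArgs.from_args: drop a trailing "/" and check whether
    # the final path component contains "chat".
    path = str(path)
    if path.endswith("/"):
        path = path[:-1]
    return "chat" in os.path.basename(path)

# Example: _looks_like_chat_model("checkpoints/llama-2-7b-chat/") -> True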
28 changes: 13 additions & 15 deletions build/gguf_loader.py
@@ -8,27 +8,21 @@
import copy
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import Any

import gguf

import torch
import torch.nn as nn

wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))

from gguf import GGUFValueType, ReaderTensor
from quantize import (
group_dequantize_tensor_from_qparams,
pack_scales_and_zeros,
WeightOnlyInt4Linear,
)

from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
from gguf import GGUFValueType
from model import ModelArgs, Transformer
from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear

from build.gguf_util import Q4_0, to_float

logger: logging.Logger = logging.getLogger(__name__)

@@ -116,9 +110,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
metadata = _get_metadata(reader)

arch = metadata["general.architecture"]
assert (
arch == "llama"
), "Only LLaMa models are supported by this converter."
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = ModelArgs(
dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +131,13 @@
return model


def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
def load_model_and_state_dict(
gguf_file: str,
*,
load_state_dict: bool = True,
load_as_quantized: bool = True,
inner_k_tiles=8,
) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
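The reformatted signature above makes the loader options keyword-only. A minimal usage sketch, assuming build.gguf_loader is importable from the repository root; the GGUF path is illustrative and not part of this PR:

from build.gguf_loader import load_model_and_state_dict

# Returns a Transformer on the meta device plus a state_dict that can be
# loaded into it (per the docstring above).
model, state_dict = load_model_and_state_dict(
    "models/llama-2-7b.Q4_0.gguf",  # illustrative path
    load_state_dict=True,
    load_as_quantized=True,
)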
1 change: 1 addition & 0 deletions build/model.py
@@ -248,6 +248,7 @@ def from_params(cls, params_path: str):
@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
from build.gguf_loader import load_model_and_state_dict

model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
if state_dict != {}:
model.load_state_dict(state_dict, assign=True)
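The from_gguf hunk loads the returned state_dict with assign=True, which replaces the module's meta-device parameters with the loaded tensors instead of copying into preallocated storage. A small, repo-independent illustration of that PyTorch behavior (assign=True is available in recent PyTorch releases, roughly 2.1+):

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)  # parameters exist only as meta tensors

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state, assign=True)  # assign swaps in the real tensors
print(layer.weight.device)  # cpu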
15 changes: 7 additions & 8 deletions generate.py
@@ -7,7 +7,6 @@
import itertools

import logging
import os
import sys
import time
from dataclasses import dataclass
@@ -109,12 +108,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):


def prefill(
model: Transformer,
x: torch.Tensor,
input_pos: torch.Tensor,
*,
sequential_prefill = True,
**sampling_kwargs
model: Transformer,
x: torch.Tensor,
input_pos: torch.Tensor,
*,
sequential_prefill=True,
**sampling_kwargs,
) -> torch.Tensor:
logging.debug(f"x: {x}, input_pos: {input_pos}")
width = x.size(1)
@@ -348,7 +347,7 @@ def _main(
is_speculative = speculative_builder_args.checkpoint_path is not None

if generator_args.chat_mode and not builder_args.is_chat_model:
# This is not a log message, it's a dangerous condition message
# This is not a log message, it's a dangerous condition message
# that we must ensure is displayed
print(
"""