
Commit a996fe3

mergennachin authored and malfet committed
Fix lints (#262)
1 parent 4cfda2e commit a996fe3

File tree

build/builder.py
build/gguf_loader.py
build/model.py
generate.py

4 files changed: +27 −27 lines changed

build/builder.py

Lines changed: 6 additions & 4 deletions
@@ -37,7 +37,7 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
     is_chat_model: bool = False
-
+
     def __post_init__(self):
         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -77,15 +77,15 @@ def from_args(cls, args): # -> BuilderArgs:
             args.checkpoint_dir,
             args.dso_path,
             args.pte_path,
-            args.gguf_path
+            args.gguf_path,
         ]:
             path = str(path)
-            if path.endswith('/'):
+            if path.endswith("/"):
                 path = path[:-1]
             path_basename = os.path.basename(path)
             if "chat" in path_basename:
                 is_chat_model = True
-
+
         return cls(
             checkpoint_path=args.checkpoint_path,
             checkpoint_dir=args.checkpoint_dir,
@@ -189,6 +189,7 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
     if is_et:
         builder_args.gguf_kwargs["load_as_quantized"] = False

+
 def _unset_gguf_kwargs(builder_args):
     builder_args.gguf_kwargs = None

@@ -264,6 +265,7 @@ def _load_model(builder_args):

     if builder_args.use_tp:
         from tp import apply_tp
+
         print("Applying tensor parallel to model ...")
         apply_tp(model)

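For context, the from_args hunk above is pure lint cleanup (a trailing comma and double quotes), but the loop it touches is what flags chat models from path basenames. A minimal runnable sketch of that detection logic follows; the placeholder paths and the None guard are assumptions, not part of the diff:

import os

# Placeholder stand-ins for args.checkpoint_path, args.gguf_path, etc.
candidate_paths = ["/models/llama-2-7b-chat/", None, "model.pte"]

is_chat_model = False
for path in candidate_paths:
    if path is None:  # assumption: the real loop skips unset arguments
        continue
    path = str(path)
    if path.endswith("/"):
        path = path[:-1]
    path_basename = os.path.basename(path)
    if "chat" in path_basename:
        is_chat_model = True

print(is_chat_model)  # True: "chat" appears in "llama-2-7b-chat"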

build/gguf_loader.py

Lines changed: 13 additions & 15 deletions
@@ -8,27 +8,21 @@
 import copy
 import logging
 import sys
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any

 import gguf

 import torch
-import torch.nn as nn

 wd = Path(__file__).parent.resolve()
 sys.path.append(str(wd))

-from gguf import GGUFValueType, ReaderTensor
-from quantize import (
-    group_dequantize_tensor_from_qparams,
-    pack_scales_and_zeros,
-    WeightOnlyInt4Linear,
-)
-
-from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
+from gguf import GGUFValueType
 from model import ModelArgs, Transformer
+from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear
+
+from build.gguf_util import Q4_0, to_float

 logger: logging.Logger = logging.getLogger(__name__)

@@ -116,9 +110,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     metadata = _get_metadata(reader)

     arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), "Only LLaMa models are supported by this converter."
+    assert arch == "llama", "Only LLaMa models are supported by this converter."

     model_args = ModelArgs(
         dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +131,13 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     return model


-def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(
+    gguf_file: str,
+    *,
+    load_state_dict: bool = True,
+    load_as_quantized: bool = True,
+    inner_k_tiles=8,
+) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
     that can be loaded into it.
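The reflowed load_model_and_state_dict signature reads naturally at the call site. A usage sketch mirroring the from_gguf hunk in build/model.py below; the file path is a placeholder, and per the assert above only llama-architecture GGUF files are accepted:

from build.gguf_loader import load_model_and_state_dict

# "llama-2.gguf" is a placeholder path, not a file shipped with the repo.
model, state_dict = load_model_and_state_dict(
    "llama-2.gguf",
    load_state_dict=True,
    load_as_quantized=True,  # keep Q4_0 weights packed for WeightOnlyInt4Linear
    inner_k_tiles=8,
)
if state_dict != {}:
    model.load_state_dict(state_dict, assign=True)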

build/model.py

Lines changed: 1 addition & 0 deletions
@@ -248,6 +248,7 @@ def from_params(cls, params_path: str):
     @classmethod
     def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
+
         model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
         if state_dict != {}:
             model.load_state_dict(state_dict, assign=True)

generate.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import itertools
88

99
import logging
10-
import os
1110
import sys
1211
import time
1312
from dataclasses import dataclass
@@ -109,12 +108,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
109108

110109

111110
def prefill(
112-
model: Transformer,
113-
x: torch.Tensor,
114-
input_pos: torch.Tensor,
115-
*,
116-
sequential_prefill = True,
117-
**sampling_kwargs
111+
model: Transformer,
112+
x: torch.Tensor,
113+
input_pos: torch.Tensor,
114+
*,
115+
sequential_prefill=True,
116+
**sampling_kwargs,
118117
) -> torch.Tensor:
119118
logging.debug(f"x: {x}, input_pos: {input_pos}")
120119
width = x.size(1)
@@ -348,7 +347,7 @@ def _main(
348347
is_speculative = speculative_builder_args.checkpoint_path is not None
349348

350349
if generator_args.chat_mode and not builder_args.is_chat_model:
351-
# This is not a log message, it's a dangerous condition message
350+
# This is not a log message, it's a dangerous condition message
352351
# that we must ensure is displayed
353352
print(
354353
"""
