Skip to content

Reduce startup time for SAM2 AMG by using torch.export #1358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions examples/sam2_amg_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from server import set_fast
from server import set_aot_fast
from server import load_aot_fast
from server import set_furious
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
Expand All @@ -22,17 +24,20 @@ def main_docstring():
"""


def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
if verbose:
print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
if fast:
set_fast(mask_generator)
if furious:
set_furious(mask_generator)
if load_fast:
load_aot_fast(mask_generator, load_fast)
if fast:
set_fast(mask_generator, load_fast)

image_tensor = file_bytes_to_image_tensor(input_bytes)
if verbose:
print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
Expand All @@ -50,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102
buf.seek(0)
return buf.getvalue()

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
input_bytes = bytearray(open(input_path, 'rb').read())
output_bytes = main_headless(checkpoint_path,
model_type,
Expand All @@ -59,7 +64,8 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=
output_format=output_format,
verbose=verbose,
fast=fast,
furious=furious)
furious=furious,
load_fast=load_fast)
with open(output_path, "wb") as file:
file.write(output_bytes)

Expand Down
196 changes: 183 additions & 13 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,175 @@ def model_type_to_paths(checkpoint_path, model_type):
model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
return sam2_checkpoint, model_cfg

def set_fast(mask_generator):
# TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
mask_generator.predictor.model.image_encoder,
mode="max-autotune",
fullgraph=True,
dynamic=False,

def aot_compile(model_directory, name, fn, sample_args):
path = Path(model_directory) / Path(f"{name}.pt2")
print(f"Saving at {path=}")
options = {
"max_autotune": True,
"triton.cudagraphs": True,
}

exported = torch.export.export_for_inference(fn, sample_args)
output_path = torch._inductor.aoti_compile_and_package(
exported,
package_path=str(path),
inductor_configs=options,
)
return output_path


def aot_load(path):
return torch._export.aot_load(path, "cuda")

class FunctionModel(torch.nn.Module):

def __init__(self, module, fn_name):
super().__init__()
self.module = module
self.fn_name = fn_name

def forward(self, *args):
return getattr(self.module, self.fn_name)(*args)


def set_aot_fast(mask_generator, model_directory):
example_input = torch.empty(1, 3, 1024, 1024)
example_input = example_input.to(mask_generator.predictor._image_dtype)
example_input = (example_input.to(mask_generator.predictor.device),)
aot_compile(model_directory,
"sam2_image_encoder",
mask_generator.predictor.model.image_encoder,
example_input)

# NOTE: THIS DOESN'T WORK YET!
# example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
# example_input = ([example_input_0_0, example_input_0_1],
# example_input_1,
# example_input_2,
# example_input_3,
# None,
# None,
# True,
# True,
# -1)
# mask_generator.forward = mask_generator.predictor._predict_masks_with_features
# mask_generator(*example_input)
# aot_compile("sam2__predict_masks_with_features",
# mask_generator,
# example_input)

# example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device)
# aot_compile("sam2_sam_prompt_encoder",
# mask_generator.predictor.model.sam_prompt_encoder,
# ((example_input_2, example_input_3),
# None,
# None))

# NOTE: THIS DOESN'T WORK YET!
# example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device)
# example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device)

# example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device)

# example_input = (example_input_0,
# example_input_1,
# example_input_2,
# example_input_3,
# True,
# True,
# [example_input_4_0, example_input_4_1])
# print("Example")
# mask_generator.predictor.model.sam_mask_decoder(*example_input)
# print("Example done")
# aot_compile("sam2_sam_mask_decoder",
# mask_generator.predictor.model.sam_mask_decoder,
# example_input)

# example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device)
# example_input = (example_input_0, example_input_1, example_input_2)

# mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input)
# aot_compile("sam2_sam_mask_decoder_transformer",
# mask_generator.predictor.model.sam_mask_decoder.transformer,
# example_input)




class LoadedModel(torch.nn.Module):

def __init__(self, aoti_compiled_model):
super().__init__()
self.aoti_compiled_model = aoti_compiled_model

def forward(self, *args):
return self.aoti_compiled_model(*args)

class LoadedDecoder(torch.nn.Module):

def __init__(self, aoti_compiled_model, other):
super().__init__()
self.aoti_compiled_model = aoti_compiled_model
self.other = other

def forward(self, *args):
return self.aoti_compiled_model(*args)

def get_dense_pe(self, *args, **kwargs) -> torch.Tensor:
return self.other.get_dense_pe(*args, **kwargs)

def load_aot_fast(mask_generator, model_directory):
t0 = time.time()
path = Path(model_directory) / Path(f"sam2_image_encoder.pt2")
assert path.exists(), f"Expected {path} to exist."
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.image_encoder = pkg_m

# NOTE: This doesn't work yet!
# pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
# pkg_m = LoadedModel(pkg)
# mask_generator.predictor._predict_masks_with_features = pkg_m.forward

# pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2"))
# pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder)
# mask_generator.predictor.model.sam_prompt_encoder = pkg_m

# NOTE: This doesn't work yet!
# pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2"))
# pkg_m = LoadedModel(pkg)
# pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0
# pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1
# mask_generator.predictor.model.sam_mask_decoder = pkg_m

# pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2"))
# pkg_m = LoadedModel(pkg)
# mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m

print(f"End load. Took {time.time() - t0}s")


def set_fast(mask_generator, load_fast=""):
if load_fast == "":
# TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
mask_generator.predictor.model.image_encoder,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)

mask_generator.predictor._predict_masks = torch.compile(
mask_generator.predictor._predict_masks,
Expand Down Expand Up @@ -381,7 +542,9 @@ def main(checkpoint_path,
port=5000,
host="127.0.0.1",
dry=False,
batch_size=1):
batch_size=1,
load_fast="",
save_fast=""):
if verbose:
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
Expand Down Expand Up @@ -410,25 +573,32 @@ def main(checkpoint_path,
logging.info(f"Using {points_per_batch} points_per_batch")
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")

if load_fast != "":
load_aot_fast(mask_generator, load_fast)

if save_fast != "":
assert load_fast == "", "Can't save compiled models while loading them with --load-fast."
assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
print(f"Saving compiled models under directory {save_fast}")
set_aot_fast(mask_generator, save_fast)

if fast:
assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
set_fast(mask_generator)
set_fast(mask_generator, load_fast)

if furious:
set_furious(mask_generator)

# since autoquant is replicating what furious mode is doing, don't use these two together
elif use_autoquant:
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)

mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
# mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
# NOTE: Not baseline feature
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')


with open('dog.jpg', 'rb') as f:
image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))

Expand Down
4 changes: 2 additions & 2 deletions torchao/_models/sam2/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)


class SAM2AutomaticMaskGenerator:
class SAM2AutomaticMaskGenerator(torch.nn.Module):
def __init__(
self,
model: SAM2Base,
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
multimask_output (bool): Whether to output multimask at each point of the grid.
"""

super().__init__()
assert (points_per_side is None) != (
point_grids is None
), "Exactly one of points_per_side or point_grid must be provided."
Expand Down
2 changes: 1 addition & 1 deletion torchao/_models/sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchao._models.sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictor:
class SAM2ImagePredictor(torch.nn.Module):
def __init__(
self,
sam_model: SAM2Base,
Expand Down