Skip to content

Commit f4f2fa6

Browse files
committed
Initial setup of SAM architecture
1 parent 76f02f8 commit f4f2fa6

File tree

1 file changed

+73
-0
lines changed
  • tiatoolbox/models/architecture

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Define SAM architecture."""
2+
3+
from __future__ import annotations
4+
5+
from collections import OrderedDict
6+
7+
import cv2
8+
import numpy as np
9+
import torch
10+
import torch.nn.functional as F # noqa: N812
11+
from skimage import morphology
12+
from torch import nn
13+
14+
from tiatoolbox.utils import misc
15+
16+
from sam2.build_sam import build_sam2
17+
from sam2.sam2_image_predictor import SAM2ImagePredictor
18+
19+
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
20+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
21+
22+
sam_model = build_sam2(model_cfg, checkpoint)
23+
24+
predictor = SAM2ImagePredictor(sam_model)
25+
26+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
27+
predictor.set_image(<your_image>)
28+
masks, _, _ = predictor.predict(<input_prompts>)
29+
30+
class SAM(ModelABC):
31+
32+
def __init__(self: ModelABC) -> None:
33+
"""Initialize Abstract class ModelABC."""
34+
super().__init__()
35+
self._postproc = self.postproc
36+
self._preproc = self.preproc
37+
38+
def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
39+
"""Torch method, this contains logic for using layers defined in init."""
40+
... # pragma: no cover
41+
42+
def infer_batch(
43+
model: torch.nn.Module,
44+
batch_data: np.ndarray,
45+
*,
46+
on_gpu: bool,
47+
) -> None:
48+
"""Run inference on an input batch.
49+
50+
Contains logic for forward operation as well as I/O aggregation.
51+
52+
Args:
53+
model (nn.Module):
54+
PyTorch defined model.
55+
batch_data (np.ndarray):
56+
A batch of data generated by
57+
`torch.utils.data.DataLoader`.
58+
on_gpu (bool):
59+
Whether to run inference on a GPU.
60+
61+
"""
62+
... # pragma: no cover
63+
64+
@staticmethod
65+
def preproc(image: np.ndarray) -> np.ndarray:
66+
"""Define the pre-processing of this class of model."""
67+
return image
68+
69+
@staticmethod
70+
def postproc(image: np.ndarray) -> np.ndarray:
71+
"""Define the post-processing of this class of model."""
72+
return image
73+

0 commit comments

Comments
 (0)