8 changes: 8 additions & 0 deletions hloc/extract_features.py
@@ -95,6 +95,14 @@
            "resize_max": 1600,
        },
    },
    "orb": {
        "output": "feats-orb",
        "model": {"name": "orb"},
        "preprocessing": {
            "grayscale": True,
            "resize_max": 1600,
        },
    },
    "sosnet": {
        "output": "feats-sosnet",
        "model": {"name": "dog", "descriptor": "sosnet"},
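For reference, the new entry plugs into hloc's standard extraction entry point. A minimal sketch, assuming the usual `extract_features.main(conf, image_dir, export_dir)` signature and hypothetical paths:

```python
from pathlib import Path

from hloc import extract_features

images = Path("datasets/my_scene/images")  # hypothetical image directory
outputs = Path("outputs/my_scene")         # hypothetical output directory

# Select the "orb" configuration added above.
conf = extract_features.confs["orb"]
feature_path = extract_features.main(conf, images, outputs)
print(f"ORB features written to {feature_path}")
```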
100 changes: 100 additions & 0 deletions hloc/extractors/orb.py
@@ -0,0 +1,100 @@
import cv2
import numpy as np
import torch

from ..utils.base_model import BaseModel

EPS = 1e-6


class ORB(BaseModel):
    default_conf = {
        "options": {
            "nfeatures": 5000,
            "scaleFactor": 1.2,
            "nlevels": 8,
            "edgeThreshold": 31,
            "firstLevel": 0,
            "WTA_K": 2,
            "scoreType": cv2.ORB_HARRIS_SCORE,  # or cv2.ORB_FAST_SCORE
            "patchSize": 31,
            "fastThreshold": 20,
        },
        "descriptor": "orb",
        "max_keypoints": -1,
    }
    required_inputs = ["image"]
    detection_noise = 1.0
    max_batch_size = 4096

    def _init(self, conf):
        if conf["descriptor"] != "orb":
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')
        # The cv2.ORB object is created lazily on first forward
        # (OpenCV objects cannot be pickled).
        self.orb = None
        # Empty parameter so the module reports a device like other extractors.
        self.dummy_param = torch.nn.Parameter(torch.empty(0))

    def _make_orb(self):
        opts = self.conf["options"]

        self.orb = cv2.ORB_create(
            nfeatures=int(opts.get("nfeatures", 5000)),
            scaleFactor=float(opts.get("scaleFactor", 1.2)),
            nlevels=int(opts.get("nlevels", 8)),
            edgeThreshold=int(opts.get("edgeThreshold", 31)),
            firstLevel=int(opts.get("firstLevel", 0)),
            WTA_K=int(opts.get("WTA_K", 2)),
            scoreType=int(opts.get("scoreType", cv2.ORB_HARRIS_SCORE)),
            patchSize=int(opts.get("patchSize", 31)),
            fastThreshold=int(opts.get("fastThreshold", 20)),
        )

    def _forward(self, data):
        image = data["image"]
        assert image.shape[1] == 1, "ORB expects a single-channel image"

        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        if self.orb is None:
            self._make_orb()

        # Convert the [0, 1] float image to uint8 for OpenCV, rounding to nearest.
        img_u8 = np.clip(image_np * 255.0 + 0.5, 0, 255).astype(np.uint8)

        keypoints, descriptors = self.orb.detectAndCompute(img_u8, None)

        # reshape(-1, 2) keeps the shape (0, 2) when no keypoints are found.
        pts = np.array([kp.pt for kp in keypoints], dtype=np.float32).reshape(-1, 2)
        sizes = np.array([kp.size for kp in keypoints], dtype=np.float32)
        scales = sizes / 2.0  # cv2.KeyPoint.size is a diameter
        angles = np.array([kp.angle for kp in keypoints], dtype=np.float32)
        responses = np.array([kp.response for kp in keypoints], dtype=np.float32)

        # [N, 32] binary ORB descriptors; None when no keypoints were detected.
        if descriptors is None:
            descriptors = np.empty((0, 32), dtype=np.uint8)

        keypoints = torch.from_numpy(pts)
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(angles)
        scores = torch.from_numpy(responses)
        descriptors = torch.from_numpy(descriptors)  # (N, 32) uint8

        if (
            self.conf["max_keypoints"] != -1
            and len(keypoints) > self.conf["max_keypoints"]
        ):
            k = int(self.conf["max_keypoints"])
            vals, idxs = torch.topk(scores, k)
            keypoints = keypoints[idxs]
            scales = scales[idxs]
            oris = oris[idxs]
            scores = vals
            descriptors = descriptors[idxs]

        return {
            "keypoints": keypoints[None],  # [1, N, 2] (x, y)
            "scales": scales[None],  # [1, N]
            "oris": oris[None],  # [1, N]
            "scores": scores[None],  # [1, N]
            "descriptors": descriptors.T[None],  # [1, 32, N]
        }
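A quick way to sanity-check the extractor in isolation. A minimal sketch, assuming hloc's `BaseModel` merges a (possibly empty) conf dict with `default_conf` and validates `required_inputs` on call, as elsewhere in the repo:

```python
import torch

from hloc.extractors.orb import ORB

model = ORB({}).eval()
# Grayscale image batch in [0, 1], shape [B=1, C=1, H, W].
image = torch.rand(1, 1, 480, 640)
with torch.no_grad():
    pred = model({"image": image})
# N may be 0 on random noise; the shapes below still hold.
print(pred["keypoints"].shape)    # [1, N, 2]
print(pred["descriptors"].shape)  # [1, 32, N]
```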
7 changes: 7 additions & 0 deletions hloc/match_features.py
@@ -81,6 +81,13 @@
            "do_mutual_check": True,
        },
    },
    "orb": {
        "output": "matches-orb",
        "model": {
            "name": "orb_match",
            "do_mutual_check": True,
        },
    },
    "adalam": {
        "output": "matches-adalam",
        "model": {"name": "adalam"},
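Correspondingly, the matcher config plugs into hloc's standard matching entry point. A sketch assuming the usual `match_features.main(conf, pairs, features, export_dir)` signature and hypothetical paths:

```python
from pathlib import Path

from hloc import extract_features, match_features

outputs = Path("outputs/my_scene")          # hypothetical output directory
pairs = Path("outputs/my_scene/pairs.txt")  # hypothetical pairs file

feature_conf = extract_features.confs["orb"]
matcher_conf = match_features.confs["orb"]
match_path = match_features.main(
    matcher_conf, pairs, feature_conf["output"], outputs
)
print(f"ORB matches written to {match_path}")
```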
83 changes: 83 additions & 0 deletions hloc/matchers/orb_match.py
@@ -0,0 +1,83 @@
import cv2
import numpy as np
import torch

from ..utils.base_model import BaseModel


def tens_to_cv(x):
    """Convert a (1, D, N), (D, N), or (N, D) descriptor tensor or array
    into a contiguous (N, D) uint8 array for OpenCV."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[0] == 1:
        x = x[0]
    # hloc stores descriptors as (D, N); OpenCV expects (N, D).
    if x.ndim == 2 and x.shape[0] in (8, 16, 32, 64, 128, 256):
        x = x.T
    return np.ascontiguousarray(x, dtype=np.uint8)


class BinaryNearestNeighbor(BaseModel):
    default_conf = {
        "ratio_threshold": None,
        "distance_threshold_bits": None,
        "do_mutual_check": True,
    }

    required_inputs = ["descriptors0", "scores0", "descriptors1", "scores1"]

    def _init(self, conf):
        # 8-bit popcount lookup table (not used by the OpenCV matcher;
        # see the note below).
        lut = torch.arange(256, dtype=torch.uint8)
        lut = (
            (lut & 1)
            + ((lut >> 1) & 1)
            + ((lut >> 2) & 1)
            + ((lut >> 3) & 1)
            + ((lut >> 4) & 1)
            + ((lut >> 5) & 1)
            + ((lut >> 6) & 1)
            + ((lut >> 7) & 1)
        )
        self.register_buffer("_popcnt8", lut.to(torch.uint8), persistent=False)

        # crossCheck implements the mutual check declared in the conf.
        self.matcher = cv2.BFMatcher(
            cv2.NORM_HAMMING, crossCheck=bool(conf["do_mutual_check"])
        )
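    # For reference: the popcount LUT above would support a torch-only
    # Hamming distance, without OpenCV, roughly as
    #   xor = (d0[:, None, :] ^ d1[None, :, :]).long()  # (N0, N1, nbytes)
    #   dists = self._popcnt8[xor].sum(-1)              # (N0, N1) bit counts
    # where d0 and d1 are uint8 tensors of shape (N0, nbytes) and (N1, nbytes).
    # The forward pass below uses the OpenCV matcher instead.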

    def _forward(self, data):
        d0 = tens_to_cv(data["descriptors0"])
        d1 = tens_to_cv(data["descriptors1"])

        n0, nbytes = d0.shape
        n1, _ = d1.shape
        nbits = 8 * nbytes

        matches0 = torch.full((n0,), -1, dtype=torch.long)
        matching_scores0 = torch.zeros((n0,), dtype=torch.float32)

        if n0 > 0 and n1 > 0:
            if self.conf["ratio_threshold"] is not None:
                # Lowe's ratio test needs the two nearest neighbours; cv2 only
                # returns them with cross-checking disabled, so the mutual
                # check is skipped on this path.
                knn = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
                pairs = knn.knnMatch(d0, d1, k=2)
                matches = [
                    m
                    for m, n in (p for p in pairs if len(p) == 2)
                    if m.distance < self.conf["ratio_threshold"] * n.distance
                ]
            else:
                matches = self.matcher.match(d0, d1)

            max_bits = self.conf["distance_threshold_bits"]
            for m in matches:
                dist = float(m.distance)
                if max_bits is not None and dist > max_bits:
                    continue
                matches0[m.queryIdx] = m.trainIdx
                # Map the Hamming distance to a similarity in [0, 1].
                matching_scores0[m.queryIdx] = 1.0 - dist / nbits

        return {
            "matches0": matches0.unsqueeze(0),  # [1, N0]
            "matching_scores0": matching_scores0.unsqueeze(0),  # [1, N0]
        }
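And a direct smoke test of the matcher on random binary descriptors, again assuming the hloc `BaseModel` call convention:

```python
import torch

from hloc.matchers.orb_match import BinaryNearestNeighbor

matcher = BinaryNearestNeighbor({"do_mutual_check": True}).eval()
d0 = torch.randint(0, 256, (1, 32, 100), dtype=torch.uint8)  # [1, D, N0]
d1 = torch.randint(0, 256, (1, 32, 120), dtype=torch.uint8)  # [1, D, N1]
data = {
    "descriptors0": d0,
    "descriptors1": d1,
    "scores0": torch.zeros(1, 100),
    "scores1": torch.zeros(1, 120),
}
with torch.no_grad():
    pred = matcher(data)
print(pred["matches0"].shape)  # [1, 100]; -1 marks unmatched keypoints
```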