Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 45 additions & 37 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@
system_config = SystemConfig(radius=6.0, max_num_neighbors=20)

# Handle batch information if present
if state.batch is not None:
n_node = torch.bincount(state.batch)
else:
n_node = torch.tensor([len(state.positions)])
n_node = torch.bincount(state.batch)

Check warning on line 107 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L107

Added line #L107 was not covered by tests

# Set default dtype if not provided
output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype
Expand Down Expand Up @@ -148,45 +145,46 @@
if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)):
positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node)

# Compute edges of the graph
edge_index, edge_vectors, unit_shifts, batch_num_edges = (
feat_util.batch_compute_pbc_radius_graph(
positions=positions,
cells=row_vector_cell,
pbc=pbc.unsqueeze(0).repeat(len(n_node), 1),
radius=system_config.radius,
n_node=n_node,
max_number_neighbors=torch.tensor([max_num_neighbors] * len(n_node)),
edge_method=edge_method,
half_supercell=half_supercell,
device=device,
)
)
senders, receivers = edge_index[0], edge_index[1]

n_systems = state.batch.max().item() + 1

# Prepare lists to collect data from each system
all_edges = []
all_vectors = []
all_unit_shifts = []
num_edges = []

Check warning on line 154 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L151-L154

Added lines #L151 - L154 were not covered by tests
node_feats_list = []
edge_feats_list = []
graph_feats_list = []
system_edges = torch.repeat_interleave(
torch.arange(n_systems, device=state.device), batch_num_edges
)

# Process each system in a single loop
offset = 0

Check warning on line 160 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L160

Added line #L160 was not covered by tests
for i in range(n_systems):
batch_mask = state.batch == i
system_edge_mask = system_edges == i
try:
positions_per_system = positions[batch_mask]
atomic_numbers_per_system = atomic_numbers[batch_mask]
atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask]
edge_vectors_per_system = edge_vectors[system_edge_mask]
unit_shifts_per_system = unit_shifts[system_edge_mask]
except Exception: # noqa: BLE001
import pdb # noqa: T100

pdb.set_trace() # noqa: T100

positions_per_system = positions[batch_mask]
atomic_numbers_per_system = atomic_numbers[batch_mask]
atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask]

Check warning on line 165 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L163-L165

Added lines #L163 - L165 were not covered by tests
cell_per_system = row_vector_cell[i]
pbc_per_system = pbc

# Compute edges directly for this system
edges, vectors, unit_shifts = feat_util.compute_pbc_radius_graph(

Check warning on line 170 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L170

Added line #L170 was not covered by tests
positions=positions_per_system,
cell=cell_per_system,
pbc=pbc_per_system,
radius=system_config.radius,
max_number_neighbors=max_num_neighbors,
edge_method=edge_method,
half_supercell=half_supercell,
device=device,
)

# Adjust indices for the global batch
all_edges.append(edges + offset)
all_vectors.append(vectors)
all_unit_shifts.append(unit_shifts)
num_edges.append(len(edges[0]))

Check warning on line 185 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L182-L185

Added lines #L182 - L185 were not covered by tests

# Calculate lattice parameters
lattice_per_system = torch.from_numpy(
cell_to_cellpar(cell_per_system.squeeze(0).cpu().numpy())
)
Expand All @@ -202,8 +200,8 @@
}

edge_feats = {
"vectors": edge_vectors_per_system,
"unit_shifts": unit_shifts_per_system,
"vectors": vectors,
"unit_shifts": unit_shifts,
}

graph_feats = {
Expand All @@ -221,6 +219,16 @@
edge_feats_list.append(edge_feats)
graph_feats_list.append(graph_feats)

# Update offset for next system
offset += len(positions_per_system)

Check warning on line 223 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L223

Added line #L223 was not covered by tests

# Concatenate all the edge data
edge_index = torch.cat(all_edges, dim=1)
unit_shifts = torch.cat(all_unit_shifts, dim=0)
batch_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device)

Check warning on line 228 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L226-L228

Added lines #L226 - L228 were not covered by tests

senders, receivers = edge_index[0], edge_index[1]

Check warning on line 230 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L230

Added line #L230 was not covered by tests

# Create and return AtomGraphs object
return AtomGraphs(
senders=senders,
Expand Down