Skip to content

Commit 58e9076

Browse files
author
Martin Rogers
committed
Reduce for loops and keep predictions as deepsensor.prediction objects
1 parent 358b884 commit 58e9076

File tree

1 file changed

+94
-94
lines changed

1 file changed

+94
-94
lines changed

deepsensor/model/model.py

Lines changed: 94 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,6 @@ def get_coordinate_extent(ds, x1_ascend, x2_ascend) -> tuple:
818818
----------
819819
ds : Data object
820820
The dataset or data array to determine coordinate extent for.
821-
Refer to `X_t` in `predict_patchwise()` for supported types.
822821
823822
x1_ascend : bool
824823
Whether the x1 coordinates ascend (increase) from top to bottom.
@@ -937,104 +936,102 @@ def stitch_clipped_predictions(
937936
data_x1_index, data_x2_index = get_index(data_x1_coords, data_x2_coords)
938937

939938
# Iterate through patchwise predictions and slice edges prior to stitchin.
940-
patches_clipped = {var_name: [] for var_name in patch_preds[0].keys()}
939+
patches_clipped = []
941940
for i, patch_pred in enumerate(patch_preds):
942-
for var_name, data_array in patch_pred.items():
943-
if var_name in patch_pred:
944-
945-
# Get row/col index values of each patch.
946-
patch_x1_coords, patch_x2_coords= get_coordinate_extent(data_array, x1_ascend, x2_ascend)
947-
patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords)
948-
949-
# Calculate size of border to slice of each edge of patchwise predictions.
950-
# Initially set the size of all borders to the size of the overlap.
951-
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
952-
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]
953-
954-
# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
955-
if patch_x2_index[0] == data_x2_index[0]:
956-
b_x2_min = 0
957-
b_x2_max = b_x2_max
958-
959-
# At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
960-
elif patch_x2_index[1] == data_x2_index[1]:
961-
b_x2_max = 0
962-
patch_row_prev = preds[i - 1]
963-
964-
# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
965-
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
966-
if x2_ascend:
967-
prev_patch_x2_max = get_index(
968-
patch_row_prev[var_name].coords[orig_x2_name].max(),
969-
x1=False,
970-
)
971-
b_x2_min = (
972-
prev_patch_x2_max - patch_x2_index[0]
973-
) - patch_overlap[1]
974-
975-
# If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
976-
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
977-
else:
978-
prev_patch_x2_min = get_index(
979-
patch_row_prev[var_name].coords[orig_x2_name].min(),
980-
x1=False,
981-
)
982-
b_x2_min = (
983-
patch_x2_index[0] - prev_patch_x2_min
984-
) - patch_overlap[1]
985-
else:
986-
b_x2_max = b_x2_max
987-
988-
# Repeat process as above for x1 coordinates.
989-
if patch_x1_index[0] == data_x1_index[0]:
990-
b_x1_min = 0
991-
992-
elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
993-
b_x1_max = 0
994-
b_x1_max = b_x1_max
995-
patch_prev = preds[i - patches_per_row]
996-
if x1_ascend:
997-
prev_patch_x1_max = get_index(
998-
patch_prev[var_name].coords[orig_x1_name].max(),
999-
x1=True,
1000-
)
1001-
b_x1_min = (
1002-
prev_patch_x1_max - patch_x1_index[0]
1003-
) - patch_overlap[0]
1004-
else:
1005-
prev_patch_x1_min = get_index(
1006-
patch_prev[var_name].coords[orig_x1_name].min(),
1007-
x1=True,
1008-
)
1009-
1010-
b_x1_min = (
1011-
prev_patch_x1_min - patch_x1_index[0]
1012-
) - patch_overlap[0]
1013-
else:
1014-
b_x1_max = b_x1_max
941+
first_key, first_value = next(iter(patch_pred.items()))
942+
# Get row/col index values of each patch.
943+
patch_x1_coords, patch_x2_coords= get_coordinate_extent(patch_pred[first_key], x1_ascend, x2_ascend)
944+
patch_x1_index, patch_x2_index = get_index(patch_x1_coords, patch_x2_coords)
945+
946+
# Calculate size of border to slice of each edge of patchwise predictions.
947+
# Initially set the size of all borders to the size of the overlap.
948+
b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
949+
b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]
950+
951+
# Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
952+
if patch_x2_index[0] == data_x2_index[0]:
953+
b_x2_min = 0
954+
b_x2_max = b_x2_max
955+
956+
# At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
957+
elif patch_x2_index[1] == data_x2_index[1]:
958+
b_x2_max = 0
959+
patch_row_prev = patch_preds[i - 1]
960+
961+
# If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
962+
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
963+
if x2_ascend:
964+
prev_patch_x2_max = get_index(
965+
patch_row_prev[first_key].coords[orig_x2_name].max(),
966+
x1=False,
967+
)
968+
b_x2_min = (
969+
prev_patch_x2_max - patch_x2_index[0]
970+
) - patch_overlap[1]
1015971

1016-
patch_clip_x1_min = int(b_x1_min)
1017-
patch_clip_x1_max = int(
1018-
data_array.sizes[orig_x1_name] - b_x1_max
972+
# If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
973+
# To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
974+
else:
975+
prev_patch_x2_min = get_index(
976+
patch_row_prev[first_key].coords[orig_x2_name].min(),
977+
x1=False,
1019978
)
1020-
patch_clip_x2_min = int(b_x2_min)
1021-
patch_clip_x2_max = int(
1022-
data_array.sizes[orig_x2_name] - b_x2_max
979+
b_x2_min = (
980+
patch_x2_index[0] - prev_patch_x2_min
981+
) - patch_overlap[1]
982+
else:
983+
b_x2_max = b_x2_max
984+
985+
# Repeat process as above for x1 coordinates.
986+
if patch_x1_index[0] == data_x1_index[0]:
987+
b_x1_min = 0
988+
989+
elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
990+
b_x1_max = 0
991+
b_x1_max = b_x1_max
992+
patch_prev = patch_preds[i - patches_per_row]
993+
if x1_ascend:
994+
prev_patch_x1_max = get_index(
995+
patch_prev[first_key].coords[orig_x1_name].max(),
996+
x1=True,
1023997
)
1024-
1025-
# Slice patchwise predictions
1026-
patch_clip = data_array.isel(
1027-
**{
1028-
orig_x1_name: slice(
1029-
patch_clip_x1_min, patch_clip_x1_max
1030-
),
1031-
orig_x2_name: slice(
1032-
patch_clip_x2_min, patch_clip_x2_max
1033-
),
1034-
}
998+
b_x1_min = (
999+
prev_patch_x1_max - patch_x1_index[0]
1000+
) - patch_overlap[0]
1001+
else:
1002+
prev_patch_x1_min = get_index(
1003+
patch_prev[first_key].coords[orig_x1_name].min(),
1004+
x1=True,
10351005
)
10361006

1037-
patches_clipped[var_name].append(patch_clip)
1007+
b_x1_min = (
1008+
prev_patch_x1_min - patch_x1_index[0]
1009+
) - patch_overlap[0]
1010+
else:
1011+
b_x1_max = b_x1_max
1012+
1013+
patch_clip_x1_min = int(b_x1_min)
1014+
patch_clip_x1_max = int(
1015+
patch_pred[first_key].sizes[orig_x1_name] - b_x1_max
1016+
)
1017+
patch_clip_x2_min = int(b_x2_min)
1018+
patch_clip_x2_max = int(
1019+
patch_pred[first_key].sizes[orig_x2_name] - b_x2_max
1020+
)
1021+
1022+
# Define slicing parameters
1023+
slicing_params = {
1024+
orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max),
1025+
orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max),
1026+
}
1027+
1028+
# Slice patchwise predictions
1029+
patch_clip = {
1030+
key: dataset.isel(**slicing_params)
1031+
for key, dataset in patch_pred.items()
1032+
}
1033+
1034+
patches_clipped.append(patch_clip)
10381035

10391036
# Create blank prediction object to stitch prediction values onto.
10401037
stitched_prediction = copy.deepcopy(patch_preds[0])
@@ -1054,9 +1051,12 @@ def stitch_clipped_predictions(
10541051
blank_ds[data_var][:] = np.nan
10551052
stitched_prediction[var_name] = blank_ds
10561053

1054+
# Restructure prediction objects for merging
1055+
restructured_patches = {key: [item[key] for item in patches_clipped] for key in patches_clipped[0].keys()}
1056+
10571057
# Merge patchwise predictions to create final stiched prediction.
10581058
# Iterate over each variable (key) in the prediction dictionary
1059-
for var_name, patches in patches_clipped.items():
1059+
for var_name, patches in restructured_patches.items():
10601060
# Retrieve the blank dataset for the current variable
10611061
prediction_array = stitched_prediction[var_name]
10621062

0 commit comments

Comments
 (0)