@@ -818,7 +818,6 @@ def get_coordinate_extent(ds, x1_ascend, x2_ascend) -> tuple:
818818 ----------
819819 ds : Data object
820820 The dataset or data array to determine coordinate extent for.
821- Refer to `X_t` in `predict_patchwise()` for supported types.
822821
823822 x1_ascend : bool
824823 Whether the x1 coordinates ascend (increase) from top to bottom.
@@ -937,104 +936,102 @@ def stitch_clipped_predictions(
937936 data_x1_index , data_x2_index = get_index (data_x1_coords , data_x2_coords )
938937
939938 # Iterate through patchwise predictions and slice edges prior to stitchin.
940- patches_clipped = { var_name : [] for var_name in patch_preds [ 0 ]. keys ()}
939+ patches_clipped = []
941940 for i , patch_pred in enumerate (patch_preds ):
942- for var_name , data_array in patch_pred .items ():
943- if var_name in patch_pred :
944-
945- # Get row/col index values of each patch.
946- patch_x1_coords , patch_x2_coords = get_coordinate_extent (data_array , x1_ascend , x2_ascend )
947- patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
948-
949- # Calculate size of border to slice of each edge of patchwise predictions.
950- # Initially set the size of all borders to the size of the overlap.
951- b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
952- b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
953-
954- # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
955- if patch_x2_index [0 ] == data_x2_index [0 ]:
956- b_x2_min = 0
957- b_x2_max = b_x2_max
958-
959- # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
960- elif patch_x2_index [1 ] == data_x2_index [1 ]:
961- b_x2_max = 0
962- patch_row_prev = preds [i - 1 ]
963-
964- # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
965- # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
966- if x2_ascend :
967- prev_patch_x2_max = get_index (
968- patch_row_prev [var_name ].coords [orig_x2_name ].max (),
969- x1 = False ,
970- )
971- b_x2_min = (
972- prev_patch_x2_max - patch_x2_index [0 ]
973- ) - patch_overlap [1 ]
974-
975- # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
976- # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
977- else :
978- prev_patch_x2_min = get_index (
979- patch_row_prev [var_name ].coords [orig_x2_name ].min (),
980- x1 = False ,
981- )
982- b_x2_min = (
983- patch_x2_index [0 ] - prev_patch_x2_min
984- ) - patch_overlap [1 ]
985- else :
986- b_x2_max = b_x2_max
987-
988- # Repeat process as above for x1 coordinates.
989- if patch_x1_index [0 ] == data_x1_index [0 ]:
990- b_x1_min = 0
991-
992- elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
993- b_x1_max = 0
994- b_x1_max = b_x1_max
995- patch_prev = preds [i - patches_per_row ]
996- if x1_ascend :
997- prev_patch_x1_max = get_index (
998- patch_prev [var_name ].coords [orig_x1_name ].max (),
999- x1 = True ,
1000- )
1001- b_x1_min = (
1002- prev_patch_x1_max - patch_x1_index [0 ]
1003- ) - patch_overlap [0 ]
1004- else :
1005- prev_patch_x1_min = get_index (
1006- patch_prev [var_name ].coords [orig_x1_name ].min (),
1007- x1 = True ,
1008- )
1009-
1010- b_x1_min = (
1011- prev_patch_x1_min - patch_x1_index [0 ]
1012- ) - patch_overlap [0 ]
1013- else :
1014- b_x1_max = b_x1_max
941+ first_key , first_value = next (iter (patch_pred .items ()))
942+ # Get row/col index values of each patch.
943+ patch_x1_coords , patch_x2_coords = get_coordinate_extent (patch_pred [first_key ], x1_ascend , x2_ascend )
944+ patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
945+
946+ # Calculate size of border to slice of each edge of patchwise predictions.
947+ # Initially set the size of all borders to the size of the overlap.
948+ b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
949+ b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
950+
951+ # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
952+ if patch_x2_index [0 ] == data_x2_index [0 ]:
953+ b_x2_min = 0
954+ b_x2_max = b_x2_max
955+
956+ # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
957+ elif patch_x2_index [1 ] == data_x2_index [1 ]:
958+ b_x2_max = 0
959+ patch_row_prev = patch_preds [i - 1 ]
960+
961+ # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
962+ # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
963+ if x2_ascend :
964+ prev_patch_x2_max = get_index (
965+ patch_row_prev [first_key ].coords [orig_x2_name ].max (),
966+ x1 = False ,
967+ )
968+ b_x2_min = (
969+ prev_patch_x2_max - patch_x2_index [0 ]
970+ ) - patch_overlap [1 ]
1015971
1016- patch_clip_x1_min = int (b_x1_min )
1017- patch_clip_x1_max = int (
1018- data_array .sizes [orig_x1_name ] - b_x1_max
972+ # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
973+ # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
974+ else :
975+ prev_patch_x2_min = get_index (
976+ patch_row_prev [first_key ].coords [orig_x2_name ].min (),
977+ x1 = False ,
1019978 )
1020- patch_clip_x2_min = int (b_x2_min )
1021- patch_clip_x2_max = int (
1022- data_array .sizes [orig_x2_name ] - b_x2_max
979+ b_x2_min = (
980+ patch_x2_index [0 ] - prev_patch_x2_min
981+ ) - patch_overlap [1 ]
982+ else :
983+ b_x2_max = b_x2_max
984+
985+ # Repeat process as above for x1 coordinates.
986+ if patch_x1_index [0 ] == data_x1_index [0 ]:
987+ b_x1_min = 0
988+
989+ elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
990+ b_x1_max = 0
991+ b_x1_max = b_x1_max
992+ patch_prev = patch_preds [i - patches_per_row ]
993+ if x1_ascend :
994+ prev_patch_x1_max = get_index (
995+ patch_prev [first_key ].coords [orig_x1_name ].max (),
996+ x1 = True ,
1023997 )
1024-
1025- # Slice patchwise predictions
1026- patch_clip = data_array .isel (
1027- ** {
1028- orig_x1_name : slice (
1029- patch_clip_x1_min , patch_clip_x1_max
1030- ),
1031- orig_x2_name : slice (
1032- patch_clip_x2_min , patch_clip_x2_max
1033- ),
1034- }
998+ b_x1_min = (
999+ prev_patch_x1_max - patch_x1_index [0 ]
1000+ ) - patch_overlap [0 ]
1001+ else :
1002+ prev_patch_x1_min = get_index (
1003+ patch_prev [first_key ].coords [orig_x1_name ].min (),
1004+ x1 = True ,
10351005 )
10361006
1037- patches_clipped [var_name ].append (patch_clip )
1007+ b_x1_min = (
1008+ prev_patch_x1_min - patch_x1_index [0 ]
1009+ ) - patch_overlap [0 ]
1010+ else :
1011+ b_x1_max = b_x1_max
1012+
1013+ patch_clip_x1_min = int (b_x1_min )
1014+ patch_clip_x1_max = int (
1015+ patch_pred [first_key ].sizes [orig_x1_name ] - b_x1_max
1016+ )
1017+ patch_clip_x2_min = int (b_x2_min )
1018+ patch_clip_x2_max = int (
1019+ patch_pred [first_key ].sizes [orig_x2_name ] - b_x2_max
1020+ )
1021+
1022+ # Define slicing parameters
1023+ slicing_params = {
1024+ orig_x1_name : slice (patch_clip_x1_min , patch_clip_x1_max ),
1025+ orig_x2_name : slice (patch_clip_x2_min , patch_clip_x2_max ),
1026+ }
1027+
1028+ # Slice patchwise predictions
1029+ patch_clip = {
1030+ key : dataset .isel (** slicing_params )
1031+ for key , dataset in patch_pred .items ()
1032+ }
1033+
1034+ patches_clipped .append (patch_clip )
10381035
10391036 # Create blank prediction object to stitch prediction values onto.
10401037 stitched_prediction = copy .deepcopy (patch_preds [0 ])
@@ -1054,9 +1051,12 @@ def stitch_clipped_predictions(
10541051 blank_ds [data_var ][:] = np .nan
10551052 stitched_prediction [var_name ] = blank_ds
10561053
1054+ # Restructure prediction objects for merging
1055+ restructured_patches = {key : [item [key ] for item in patches_clipped ] for key in patches_clipped [0 ].keys ()}
1056+
10571057 # Merge patchwise predictions to create final stiched prediction.
10581058 # Iterate over each variable (key) in the prediction dictionary
1059- for var_name , patches in patches_clipped .items ():
1059+ for var_name , patches in restructured_patches .items ():
10601060 # Retrieve the blank dataset for the current variable
10611061 prediction_array = stitched_prediction [var_name ]
10621062
0 commit comments