@@ -818,7 +818,6 @@ def get_coordinate_extent(ds: Union[xr.DataArray, xr.Dataset], x1_ascend: bool,
818818 ----------
819819 ds : Data object
820820 The dataset or data array to determine coordinate extent for.
821- Refer to `X_t` in `predict_patchwise()` for supported types.
822821
823822 x1_ascend : bool
824823 Whether the x1 coordinates ascend (increase) from top to bottom.
@@ -937,104 +936,103 @@ def stitch_clipped_predictions(
937936 data_x1_index , data_x2_index = get_index (data_x1_coords , data_x2_coords )
938937
939938 # Iterate through patchwise predictions and slice edges prior to stitchin.
940- patches_clipped = { var_name : [] for var_name in patch_preds [ 0 ]. keys ()}
939+ patches_clipped = []
941940 for i , patch_pred in enumerate (patch_preds ):
942- for var_name , data_array in patch_pred .items ():
943- if var_name in patch_pred :
944-
945- # Get row/col index values of each patch.
946- patch_x1_coords , patch_x2_coords = get_coordinate_extent (data_array , x1_ascend , x2_ascend )
947- patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
948-
949- # Calculate size of border to slice of each edge of patchwise predictions.
950- # Initially set the size of all borders to the size of the overlap.
951- b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
952- b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
953-
954- # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
955- if patch_x2_index [0 ] == data_x2_index [0 ]:
956- b_x2_min = 0
957- b_x2_max = b_x2_max
958-
959- # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
960- elif patch_x2_index [1 ] == data_x2_index [1 ]:
961- b_x2_max = 0
962- patch_row_prev = preds [i - 1 ]
963-
964- # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
965- # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
966- if x2_ascend :
967- prev_patch_x2_max = get_index (
968- patch_row_prev [var_name ].coords [orig_x2_name ].max (),
969- x1 = False ,
970- )
971- b_x2_min = (
972- prev_patch_x2_max - patch_x2_index [0 ]
973- ) - patch_overlap [1 ]
974-
975- # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
976- # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
977- else :
978- prev_patch_x2_min = get_index (
979- patch_row_prev [var_name ].coords [orig_x2_name ].min (),
980- x1 = False ,
981- )
982- b_x2_min = (
983- patch_x2_index [0 ] - prev_patch_x2_min
984- ) - patch_overlap [1 ]
985- else :
986- b_x2_max = b_x2_max
987-
988- # Repeat process as above for x1 coordinates.
989- if patch_x1_index [0 ] == data_x1_index [0 ]:
990- b_x1_min = 0
991-
992- elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
993- b_x1_max = 0
994- b_x1_max = b_x1_max
995- patch_prev = preds [i - patches_per_row ]
996- if x1_ascend :
997- prev_patch_x1_max = get_index (
998- patch_prev [var_name ].coords [orig_x1_name ].max (),
999- x1 = True ,
1000- )
1001- b_x1_min = (
1002- prev_patch_x1_max - patch_x1_index [0 ]
1003- ) - patch_overlap [0 ]
1004- else :
1005- prev_patch_x1_min = get_index (
1006- patch_prev [var_name ].coords [orig_x1_name ].min (),
1007- x1 = True ,
1008- )
1009-
1010- b_x1_min = (
1011- prev_patch_x1_min - patch_x1_index [0 ]
1012- ) - patch_overlap [0 ]
1013- else :
1014- b_x1_max = b_x1_max
941+ # get one variable name to use for coordinates and extent
942+ first_key = list (patch_pred .keys ())[0 ]
943+ # Get row/col index values of each patch.
944+ patch_x1_coords , patch_x2_coords = get_coordinate_extent (patch_pred [first_key ], x1_ascend , x2_ascend )
945+ patch_x1_index , patch_x2_index = get_index (patch_x1_coords , patch_x2_coords )
946+
947+ # Calculate size of border to slice of each edge of patchwise predictions.
948+ # Initially set the size of all borders to the size of the overlap.
949+ b_x1_min , b_x1_max = patch_overlap [0 ], patch_overlap [0 ]
950+ b_x2_min , b_x2_max = patch_overlap [1 ], patch_overlap [1 ]
951+
952+ # Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.
953+ if patch_x2_index [0 ] == data_x2_index [0 ]:
954+ b_x2_min = 0
955+ b_x2_max = b_x2_max
956+
957+ # At end of row (when patch_x2_index = data_x2_index), calculate the number of pixels to remove from left hand side of patch.
958+ elif patch_x2_index [1 ] == data_x2_index [1 ]:
959+ b_x2_max = 0
960+ patch_row_prev = patch_preds [i - 1 ]
961+
962+ # If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
963+ # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
964+ if x2_ascend :
965+ prev_patch_x2_max = get_index (
966+ patch_row_prev [first_key ].coords [orig_x2_name ].max (),
967+ x1 = False ,
968+ )
969+ b_x2_min = (
970+ prev_patch_x2_max - patch_x2_index [0 ]
971+ ) - patch_overlap [1 ]
1015972
1016- patch_clip_x1_min = int (b_x1_min )
1017- patch_clip_x1_max = int (
1018- data_array .sizes [orig_x1_name ] - b_x1_max
973+ # If x2 is descending, subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
974+ # To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
975+ else :
976+ prev_patch_x2_min = get_index (
977+ patch_row_prev [first_key ].coords [orig_x2_name ].min (),
978+ x1 = False ,
1019979 )
1020- patch_clip_x2_min = int (b_x2_min )
1021- patch_clip_x2_max = int (
1022- data_array .sizes [orig_x2_name ] - b_x2_max
980+ b_x2_min = (
981+ patch_x2_index [0 ] - prev_patch_x2_min
982+ ) - patch_overlap [1 ]
983+ else :
984+ b_x2_max = b_x2_max
985+
986+ # Repeat process as above for x1 coordinates.
987+ if patch_x1_index [0 ] == data_x1_index [0 ]:
988+ b_x1_min = 0
989+
990+ elif abs (patch_x1_index [1 ] - data_x1_index [1 ]) < 2 :
991+ b_x1_max = 0
992+ b_x1_max = b_x1_max
993+ patch_prev = patch_preds [i - patches_per_row ]
994+ if x1_ascend :
995+ prev_patch_x1_max = get_index (
996+ patch_prev [first_key ].coords [orig_x1_name ].max (),
997+ x1 = True ,
1023998 )
1024-
1025- # Slice patchwise predictions
1026- patch_clip = data_array .isel (
1027- ** {
1028- orig_x1_name : slice (
1029- patch_clip_x1_min , patch_clip_x1_max
1030- ),
1031- orig_x2_name : slice (
1032- patch_clip_x2_min , patch_clip_x2_max
1033- ),
1034- }
999+ b_x1_min = (
1000+ prev_patch_x1_max - patch_x1_index [0 ]
1001+ ) - patch_overlap [0 ]
1002+ else :
1003+ prev_patch_x1_min = get_index (
1004+ patch_prev [first_key ].coords [orig_x1_name ].min (),
1005+ x1 = True ,
10351006 )
10361007
1037- patches_clipped [var_name ].append (patch_clip )
1008+ b_x1_min = (
1009+ prev_patch_x1_min - patch_x1_index [0 ]
1010+ ) - patch_overlap [0 ]
1011+ else :
1012+ b_x1_max = b_x1_max
1013+
1014+ patch_clip_x1_min = int (b_x1_min )
1015+ patch_clip_x1_max = int (
1016+ patch_pred [first_key ].sizes [orig_x1_name ] - b_x1_max
1017+ )
1018+ patch_clip_x2_min = int (b_x2_min )
1019+ patch_clip_x2_max = int (
1020+ patch_pred [first_key ].sizes [orig_x2_name ] - b_x2_max
1021+ )
1022+
1023+ # Define slicing parameters
1024+ slicing_params = {
1025+ orig_x1_name : slice (patch_clip_x1_min , patch_clip_x1_max ),
1026+ orig_x2_name : slice (patch_clip_x2_min , patch_clip_x2_max ),
1027+ }
1028+
1029+ # Slice patchwise predictions
1030+ patch_clip = {
1031+ key : dataset .isel (** slicing_params )
1032+ for key , dataset in patch_pred .items ()
1033+ }
1034+
1035+ patches_clipped .append (patch_clip )
10381036
10391037 # Create blank prediction object to stitch prediction values onto.
10401038 stitched_prediction = copy .deepcopy (patch_preds [0 ])
@@ -1054,9 +1052,12 @@ def stitch_clipped_predictions(
10541052 blank_ds [data_var ][:] = np .nan
10551053 stitched_prediction [var_name ] = blank_ds
10561054
1055+ # Restructure prediction objects for merging
1056+ restructured_patches = {key : [item [key ] for item in patches_clipped ] for key in patches_clipped [0 ].keys ()}
1057+
10571058 # Merge patchwise predictions to create final stiched prediction.
10581059 # Iterate over each variable (key) in the prediction dictionary
1059- for var_name , patches in patches_clipped .items ():
1060+ for var_name , patches in restructured_patches .items ():
10601061 # Retrieve the blank dataset for the current variable
10611062 prediction_array = stitched_prediction [var_name ]
10621063
0 commit comments