 import torch.distributed as dist
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from deepspeed.ops.op_builder import UtilsBuilder
+from packaging import version as pkg_version
 
+from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors,
                                      clip_tensors_by_global_norm,
                                      get_grad_norm,
                                      is_model_parallel_parameter,
                                      see_memory_usage)
 
+from deepspeed.checkpoint.constants import (DS_VERSION,
+                                            PARTITION_COUNT,
+                                            BASE_OPTIMIZER_STATE,
+                                            SINGLE_PARTITION_OF_FP32_GROUPS,
+                                            CLIP_GRAD,
+                                            GROUPS_PADDING)
+
 
 class BF16_Optimizer:
     def __init__(self,
@@ -36,6 +45,10 @@ def __init__(self,
         self.real_dp_process_group = [
             dp_process_group for i in range(len(self.optimizer.param_groups))
         ]
+        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        self.partition_count = [
+            dp_world_size for i in range(len(self.optimizer.param_groups))
+        ]
 
         # Load pre-built or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -58,9 +71,9 @@ def __init__(self,
         self.fp32_groups_actual_gradients_flat = []
         self.fp32_groups_gradient_flat_partition = []
         self.fp32_groups_has_gradients = []
-        self.step_count = 0
 
-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        self.step_count = 0
+        self.groups_padding = []
 
         for i, param_group in enumerate(self.optimizer.param_groups):
             see_memory_usage(f'before initializing group {i}', force=True)
@@ -127,6 +140,15 @@ def __init__(self,
             # track fp32 gradient updates
             self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))
 
+            # Record padding required for alignment
+            if partition_id == dist.get_world_size(
+                    group=self.real_dp_process_group[i]) - 1:
+                padding = self.bf16_groups_flat[i].numel() - length_without_padding
+            else:
+                padding = 0
+
+            self.groups_padding.append(padding)
+
             # update optimizer param groups to reference fp32 params partition
             param_group['params'] = [self.fp32_groups_flat_partition[i]]
 
@@ -186,8 +208,8 @@ def step(self, closure=None):
         if self.clip_grad > 0.:
             clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(),
                                         max_norm=self.clip_grad,
-                                        mpu=self.mpu,
-                                        global_grad_norm=all_groups_norm)
+                                        global_norm=all_groups_norm,
+                                        mpu=self.mpu)
 
         self.optimizer.step()
 
@@ -278,18 +300,47 @@ def clear_lp_grads(self):
                 param.grad = None
 
     def state_dict(self):
-        # TODO capture all training state for checkpointing
         state_dict = {}
-        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
-        state_dict['clip_grad'] = self.clip_grad
+        state_dict[CLIP_GRAD] = self.clip_grad
+        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
+        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
+        state_dict[GROUPS_PADDING] = self.groups_padding
+        state_dict[PARTITION_COUNT] = self.partition_count
+        state_dict[DS_VERSION] = version
+
         return state_dict
 
-    def load_state_dict(self, state_dict, load_optimizer_states=True):
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
+        dp_rank = dist.get_rank(group=self.dp_process_group)
+        current_rank_sd = state_dict_list[dp_rank]
+
+        ckpt_version = current_rank_sd.get(DS_VERSION, False)
+        assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed"
+        ckpt_version = pkg_version.parse(ckpt_version)
+
+        self.clip_grad = current_rank_sd[CLIP_GRAD]
+
         if load_optimizer_states:
-            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
-        self.clip_grad = state_dict['clip_grad']
+            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
+
+        if load_from_fp32_weights:
+            for current, saved in zip(self.fp32_groups_flat_partition, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
+                src_tensor = _get_padded_tensor(saved, current.numel())
+                current.data.copy_(src_tensor.data)
 
     @property
     def param_groups(self):
         """Forward the wrapped optimizer's parameters."""
         return self.optimizer.param_groups
+
+
+def _get_padded_tensor(src_tensor, size):
+    if src_tensor.numel() >= size:
+        return src_tensor
+    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
+    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
+    slice_tensor.data.copy_(src_tensor.data)
+    return padded_tensor
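
Below is a minimal, standalone sketch (not part of the commit) of the padding behaviour that load_state_dict(..., load_from_fp32_weights=True) relies on. The helper mirrors _get_padded_tensor from the diff above; the tensor sizes are made up for illustration.

import torch


def _get_padded_tensor(src_tensor, size):
    # Return src_tensor unchanged if it already has enough elements,
    # otherwise right-pad it with zeros up to `size` elements.
    if src_tensor.numel() >= size:
        return src_tensor
    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    padded_tensor.narrow(0, 0, src_tensor.numel()).copy_(src_tensor)
    return padded_tensor


# A saved fp32 shard that is 3 elements shorter than the live partition.
saved = torch.arange(5, dtype=torch.float32)
current = torch.empty(8, dtype=torch.float32)

# Same copy pattern as in load_state_dict: pad the saved shard to the
# current partition size, then copy into the live partition in place.
current.data.copy_(_get_padded_tensor(saved, current.numel()).data)

print(current)  # tensor([0., 1., 2., 3., 4., 0., 0., 0.])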