
Commit 27e5b95

bf16 gradient clipping fix
bf16 checkpoint save/load
1 parent 1529313 commit 27e5b95

File tree

5 files changed: +81 additions, -57 deletions

deepspeed/checkpoint/constants.py

Lines changed: 2 additions & 0 deletions

@@ -11,9 +11,11 @@

 BASE_OPTIMIZER_STATE = 'base_optimizer_state'
 SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
+GROUPS_PADDING = 'groups_padding'

 PARTITION_COUNT = 'partition_count'
 ZERO_STAGE = 'zero_stage'
+CLIP_GRAD = 'clip_gradient'

 #########################################
 # Module checkpoint keys

deepspeed/runtime/bf16_optimizer.py

Lines changed: 61 additions & 10 deletions

@@ -2,7 +2,9 @@
 import torch.distributed as dist
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from deepspeed.ops.op_builder import UtilsBuilder
+from packaging import version as pkg_version

+from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors,
                                      clip_tensors_by_global_norm,
                                      get_grad_norm,
@@ -13,6 +15,13 @@
                                      is_model_parallel_parameter,
                                      see_memory_usage)

+from deepspeed.checkpoint.constants import (DS_VERSION,
+                                            PARTITION_COUNT,
+                                            BASE_OPTIMIZER_STATE,
+                                            SINGLE_PARTITION_OF_FP32_GROUPS,
+                                            CLIP_GRAD,
+                                            GROUPS_PADDING)
+

 class BF16_Optimizer:
     def __init__(self,
@@ -36,6 +45,10 @@ def __init__(self,
         self.real_dp_process_group = [
             dp_process_group for i in range(len(self.optimizer.param_groups))
         ]
+        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        self.partition_count = [
+            dp_world_size for i in range(len(self.optimizer.param_groups))
+        ]

         # Load pre-built or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -58,9 +71,9 @@ def __init__(self,
         self.fp32_groups_actual_gradients_flat = []
         self.fp32_groups_gradient_flat_partition = []
         self.fp32_groups_has_gradients = []
-        self.step_count = 0

-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        self.step_count = 0
+        self.groups_padding = []

         for i, param_group in enumerate(self.optimizer.param_groups):
             see_memory_usage(f'before initializing group {i}', force=True)
@@ -127,6 +140,15 @@ def __init__(self,
             # track fp32 gradient updates
             self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

+            # Record padding required for alignment
+            if partition_id == dist.get_world_size(
+                    group=self.real_dp_process_group[i]) - 1:
+                padding = self.bf16_groups_flat[i].numel() - length_without_padding
+            else:
+                padding = 0
+
+            self.groups_padding.append(padding)
+
             # update optimizer param groups to reference fp32 params partition
             param_group['params'] = [self.fp32_groups_flat_partition[i]]

@@ -186,8 +208,8 @@ def step(self, closure=None):
         if self.clip_grad > 0.:
             clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(),
                                         max_norm=self.clip_grad,
-                                        mpu=self.mpu,
-                                        global_grad_norm=all_groups_norm)
+                                        global_norm=all_groups_norm,
+                                        mpu=self.mpu)

         self.optimizer.step()

@@ -278,18 +300,47 @@ def clear_lp_grads(self):
             param.grad = None

     def state_dict(self):
-        # TODO capture all training state for checkpointing
         state_dict = {}
-        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
-        state_dict['clip_grad'] = self.clip_grad
+        state_dict[CLIP_GRAD] = self.clip_grad
+        state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
+        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = self.fp32_groups_flat_partition
+        state_dict[GROUPS_PADDING] = self.groups_padding
+        state_dict[PARTITION_COUNT] = self.partition_count
+        state_dict[DS_VERSION] = version
+
         return state_dict

-    def load_state_dict(self, state_dict, load_optimizer_states=True):
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
+        dp_rank = dist.get_rank(group=self.dp_process_group)
+        current_rank_sd = state_dict_list[dp_rank]
+
+        ckpt_version = current_rank_sd.get(DS_VERSION, False)
+        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
+        ckpt_version = pkg_version.parse(ckpt_version)
+
+        self.clip_grad = current_rank_sd[CLIP_GRAD]
+
         if load_optimizer_states:
-            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
-            self.clip_grad = state_dict['clip_grad']
+            self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
+
+        if load_from_fp32_weights:
+            for current, saved in zip(self.fp32_groups_flat_partition, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
+                src_tensor = _get_padded_tensor(saved, current.numel())
+                current.data.copy_(src_tensor.data)

     @property
     def param_groups(self):
         """Forward the wrapped optimizer's parameters."""
         return self.optimizer.param_groups
+
+
+def _get_padded_tensor(src_tensor, size):
+    if src_tensor.numel() >= size:
+        return src_tensor
+    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
+    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
+    slice_tensor.data.copy_(src_tensor.data)
+    return padded_tensor
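
Note on the new load path: a partition restored from a checkpoint may be shorter than the flat fp32 partition currently allocated (the last rank's partition carries alignment padding), so _get_padded_tensor zero-pads the saved values up to the current size before the copy. Below is a minimal, self-contained sketch of that round trip in plain PyTorch; the helper name and sizes are hypothetical, not DeepSpeed's API.

import torch

def pad_to(src, size):
    # Zero-pad src up to `size` elements, mirroring _get_padded_tensor above.
    if src.numel() >= size:
        return src
    padded = torch.zeros(size, dtype=src.dtype, device=src.device)
    torch.narrow(padded, 0, 0, src.numel()).copy_(src)
    return padded

# Hypothetical last-rank partition: 10 real values plus 6 elements of alignment padding.
current_partition = torch.randn(16)
saved_partition = current_partition[:10].clone()  # checkpoint holds only the real values

restored = pad_to(saved_partition, current_partition.numel())
current_partition.copy_(restored)  # real values restored, padding region zeroed
assert torch.equal(current_partition[:10], saved_partition)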

deepspeed/runtime/engine.py

Lines changed: 15 additions & 9 deletions

@@ -807,7 +807,7 @@ def _configure_checkpointing(self, dist_init_required):
         self.save_non_zero_checkpoint = (
             dp_rank == 0) or self.zero_optimization_partition_weights()

-        if self.zero_optimization():
+        if self.zero_optimization() or self.bfloat16_enabled():
             param_rank = torch.distributed.get_rank(
                 group=self.optimizer.dp_process_group)

@@ -2370,7 +2370,8 @@ def load_module_state_dict(self, state_dict, strict=True):
         self.module.load_state_dict(state_dict, strict=strict)

     def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
-        filename = "zero_pp_rank_{}".format(dp_rank)
+        filename = "bf16_zero_pp_rank_{}".format(
+            dp_rank) if self.bfloat16_enabled() else "zero_pp_rank_{}".format(dp_rank)
         zero_ckpt_name = os.path.join(
             checkpoints_path,
             str(tag),
@@ -2495,7 +2496,8 @@ def load_checkpoint(self,
             load_lr_scheduler_states=load_lr_scheduler_states,
             load_module_only=load_module_only)

-        if self.zero_optimization() and load_path is not None:
+        load_zero_checkpoint = self.zero_optimization() or self.bfloat16_enabled()
+        if load_zero_checkpoint and load_path is not None:
             success = self._load_zero_checkpoint(
                 load_dir,
                 tag,
@@ -2567,8 +2569,9 @@ def _load_checkpoint(self,
         else:
             optim_checkpoint = checkpoint

-        if load_optimizer_states and self.optimizer is not None and not self.zero_optimization(
-        ):
+        has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled(
+        )
+        if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state:
             if self.fp16_enabled():
                 self.optimizer.load_state_dict(
                     optim_checkpoint['optimizer'],
@@ -2964,13 +2967,13 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
         # module_state_dict() and uses this path to save the model. module_state_dict()
         # then instead just returns None.
         self._curr_ckpt_path = os.path.join(save_dir, tag)
-
+        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
         state = dict(module=self.module_state_dict(),
                      buffer_names=self._get_buffer_names(),
                      optimizer=self.optimizer.state_dict()
-                     if self.optimizer and not self.zero_optimization() else None,
+                     if self.optimizer and not zero_optimizer_state else None,
                      param_shapes=self._get_zero_param_shapes()
-                     if self.optimizer and self.zero_optimization() else None,
+                     if self.optimizer and zero_optimizer_state else None,
                      lr_scheduler=self.lr_scheduler.state_dict()
                      if self.lr_scheduler is not None else None,
                      sparse_tensor_module_names=self.sparse_tensor_module_names,
@@ -3028,6 +3031,8 @@ def _get_zero_param_shapes(self):
         # if we don't use it, we get parameters ordered incorrectly
         if hasattr(self.optimizer, "round_robin_bit16_groups"):
             bit16_groups = self.optimizer.round_robin_bit16_groups
+        elif self.bfloat16_enabled() and not self.zero_optimization():
+            bit16_groups = self.optimizer.bf16_groups
         else:
             bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(
             ) == 2 else self.optimizer.fp16_groups
@@ -3068,7 +3073,8 @@ def _save_zero_checkpoint(self, save_path, tag):
         torch.save(zero_sd, zero_checkpoint_name)
         if self.global_rank == 0:
             self._copy_recovery_script(save_path)
-        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
+        ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
+        logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')

     def _zero3_consolidated_16bit_state_dict(self):
         """

deepspeed/runtime/utils.py

Lines changed: 2 additions & 2 deletions

@@ -944,7 +944,7 @@ def clip_tensors_by_global_norm(input_tensors,
     """Clip list of tensors by global norm.
     Args:
         input_tensors: List of tensors to be clipped
-        global_grad_norm (float, optional): Precomputed norm. Defaults to None.
+        global_norm (float, optional): Precomputed norm. Defaults to None.
         mpu (optional): model parallelism unit. Defaults to None.
         eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
     Returns:
@@ -953,7 +953,7 @@ def clip_tensors_by_global_norm(input_tensors,
     if global_norm is None:
         global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu)

-    clip_coef = max_norm / (global_grad_norm + eps)
+    clip_coef = max_norm / (global_norm + eps)

     if clip_coef < 1:
         for t in input_tensors:
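
The bug here was a stale name: the function's parameter is global_norm, but the clip coefficient (and the docstring) still referenced global_grad_norm, which matches no parameter and would fail whenever clipping ran. The clipping rule itself is unchanged; below is a small self-contained sketch of the same rule with illustrative names, not the DeepSpeed helper.

import torch

def clip_by_global_norm(tensors, max_norm, global_norm=None, eps=1e-6):
    # Compute the global L2 norm across all tensors unless a precomputed value is given.
    if global_norm is None:
        global_norm = torch.norm(torch.stack([t.norm(2) for t in tensors]), 2).item()
    clip_coef = max_norm / (global_norm + eps)
    if clip_coef < 1:
        for t in tensors:
            t.mul_(clip_coef)  # scale in place so the combined norm is at most max_norm
    return global_norm

grads = [torch.full((4,), 3.0), torch.full((4,), 4.0)]
norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)  # 10.0 = sqrt(4*9 + 4*16); grads are scaled down by ~10x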

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 1 addition & 36 deletions

@@ -23,7 +23,7 @@
 from deepspeed.utils import logger
 from deepspeed.moe.utils import is_moe_param
 from deepspeed.git_version_info import version
-
+from deepspeed.runtime.constants import PIPE_REPLICATED
 from deepspeed.checkpoint.constants import (DS_VERSION,
                                             PARTITION_COUNT,
                                             SINGLE_PARTITION_OF_FP32_GROUPS,
@@ -1747,41 +1747,6 @@ def step(self, closure=None):
                 start_alignment_factor=self.nccl_start_alignment_factor,
                 allgather_bucket_size=self.allgather_bucket_size)

-            # for group_id, partitioned_params in enumerate(self.parallel_partitioned_bit16_groups):
-            #
-            #     # Sequential AllGather Best of both worlds
-            #     dp_world_size = dist.get_world_size(
-            #         group=self.real_dp_process_group[group_id])
-            #     num_shards = max(
-            #         1,
-            #         partitioned_params[partition_id].numel() * dp_world_size //
-            #         self.allgather_bucket_size)
-            #
-            #     shard_size = partitioned_params[partition_id].numel() // num_shards
-            #
-            #     # Enforce nccl/rccl alignment of start location of each shard
-            #     shard_size = shard_size - (shard_size % self.nccl_start_alignment_factor)
-            #
-            #     num_elements = shard_size
-            #
-            #     assert shard_size * num_shards <= partitioned_params[partition_id].numel()
-            #
-            #     for shard_id in range(num_shards):
-            #
-            #         if shard_id == (num_shards - 1):
-            #             num_elements = partitioned_params[partition_id].numel(
-            #             ) - shard_id * shard_size
-            #
-            #         shard_list = []
-            #         for dp_id in range(dp_world_size):
-            #             curr_shard = partitioned_params[dp_id].narrow(
-            #                 0,
-            #                 shard_id * shard_size,
-            #                 num_elements).detach()
-            #             shard_list.append(curr_shard)
-            #         dist.all_gather(shard_list,
-            #                         shard_list[partition_id],
-            #                         group=self.real_dp_process_group[group_id])
         self.stop_timers([OPTIMIZER_ALLGATHER])

         # TODO: we probably don't need this? just to be safe
