diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py
index 08400067d1..c7e10708eb 100644
--- a/recipes/full_dpo_distributed.py
+++ b/recipes/full_dpo_distributed.py
@@ -177,6 +177,11 @@ def __init__(self, cfg: DictConfig) -> None:
                     "Gradient accumulation is not supported with optimizer in bwd."
                     "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                 )
+            if self.fsdp_cpu_offload:
+                raise RuntimeError(
+                    "CPU offload is not supported with optimizer in bwd atm. "
+                    "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False."
+                )
 
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index fc7254b3c4..3423390070 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -226,6 +226,11 @@ def __init__(self, cfg: DictConfig) -> None:
                     "Gradient accumulation is not supported with optimizer in bwd."
                     "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                 )
+            if self.fsdp_cpu_offload:
+                raise RuntimeError(
+                    "CPU offload is not supported with optimizer in bwd atm. "
+                    "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False."
+                )
 
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 1f6b91f163..0e8d49905d 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -232,6 +232,11 @@ def __init__(self, cfg: DictConfig) -> None:
                     "Gradient accumulation is not supported with optimizer in bwd."
                     "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                 )
+            if self.fsdp_cpu_offload:
+                raise RuntimeError(
+                    "CPU offload is not supported with optimizer in bwd atm. "
+                    "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False."
+                )
 
         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self._checkpoint_client = CheckpointClient(cfg)
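For reference, a minimal standalone sketch of the guard these hunks add to all three recipes. The _check_optimizer_in_bwd_flags helper and its bare boolean arguments are illustrative assumptions, not part of the recipes, which read the corresponding flags from the parsed config:

def _check_optimizer_in_bwd_flags(optimizer_in_bwd: bool, fsdp_cpu_offload: bool) -> None:
    # Fusing the optimizer step into the backward pass is incompatible with
    # FSDP CPU offload, so fail fast during recipe initialization.
    if optimizer_in_bwd and fsdp_cpu_offload:
        raise RuntimeError(
            "CPU offload is not supported with optimizer in bwd atm. "
            "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False."
        )

_check_optimizer_in_bwd_flags(optimizer_in_bwd=True, fsdp_cpu_offload=False)  # passes
# _check_optimizer_in_bwd_flags(optimizer_in_bwd=True, fsdp_cpu_offload=True)  # raises RuntimeError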