From 901723f58ba5cbd8201a3fbbcc21550857159ff9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Jun 2025 12:01:04 -0700 Subject: [PATCH 1/3] rfc --- torchtune/datasets/rfc_iterable_dataset.md | 404 +++++++++++++++++++++ 1 file changed, 404 insertions(+) create mode 100644 torchtune/datasets/rfc_iterable_dataset.md diff --git a/torchtune/datasets/rfc_iterable_dataset.md b/torchtune/datasets/rfc_iterable_dataset.md new file mode 100644 index 0000000000..a8efd553e6 --- /dev/null +++ b/torchtune/datasets/rfc_iterable_dataset.md @@ -0,0 +1,404 @@ +### Core issues: + 1) No support for iterative dataset: + - Dataset has to be fully loaded in memory; + - With map-style, no control over multi-sample operations, e.g. packing or skipping + - map-style is slower + - no support for streaming + 2) No support for weighted dataset: + - We have it in a single newly added dev recipe/config, but API needs polishing; + - We also support ConcatDataset, but its map style and there is no weighting; + 3) No support for on-the-fly data packing: It's done before training, taking a long time for large datasets; + +### UX issues: + 4) Unclear boundaries between HF and torchtune args + + ```python + def alpaca_dataset( + # --- torchtune specific args --- + tokenizer: ModelTokenizer, + train_on_input: bool = True, + packed: bool = False, + # --- HF loading args --- + source: str = "tatsu-lab/alpaca", + column_map: Optional[Dict[str, str]] = None, + split: str = "train", + **load_dataset_kwargs: Dict[str, Any], + # --- HF dataset method --- + filter_fn: Optional[Callable] = None, + ) -> Union[SFTDataset, PackedDataset]: + ``` + + 5) Lack of dataloader args: args are scattered in the config. Important args are not exposed, e.g. num_workers, pin_memory, etc. + ```yaml + dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + seed: null + batch_size: 8 + shuffle: True + collate_fn: torchtune.data.padded_collate_tiled_images_and_mask + ``` + + 6) Different datasets have different arguments, because their message transforms are different. + +### Principles: + - Common API signatures for all datasets + - Offload what we can to hf datasets methods directly + - Less protagonism from our functions. E.g. config manipulations, instantiation, etc. 
(not the focus of this diff) + +### Proposal: + +# config.yaml +```yaml + +########### +# tokenizer +########### +tokenizer: + _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform + path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model + image_size: 560 + max_seq_len: 8192 + +########## +# dataloader +# consolidate all dataloader args here, which are currently scattered +########## +dataloader: + _component_: torchdata.stateful_dataloader.StatefulDataLoader + batch_size: 4 + num_workers: 4 + pin_memory: true + collate_fn: torchtune.data.padded_collate + + +#-------------------------------- +######### +# dataset if the class is used directly, as in our current SFTDataset +######### +dataset: + - _component_: torchtune.datasets.HfIterableDataset + load_args: + path: "tatsu-lab/alpaca" + split: "train" + message_transform: + _component_: torchtune.datasets.alpaca_message_transform + masking_strategy: "output_only" + column_map: + input: "prompt" + output: "response" + system_prompt: "foo" + filter_args: + function: torchtune.datasets.filter_fn_even_indices + with_indices: True + weight: 0.8 + - _component_: torchtune.datasets.HfIterableDataset + load_args: + path: "tatsu-lab/gsm8k" + split: "train" + message_transform: + _component_: torchtune.datasets.gsm8k_message_transform + masking_strategy: "output_only" + column_map: + input: "prompt" + output: "response" + system_prompt: "bar" + weight: 0.2 + +######### +# OR with builders +# TODO: test indexing "`tune run config – dataset[0].load_arg.split=train`" +######### +dataset: + - _component_: torchtune.datasets.build_alpaca_dataset + load_args: + split: "valid" + weight: 0.8 + - _component_: torchtune.datasets.build_gsm8k_dataset + message_transform: + system_prompt: "bar" + weight: 0.2 + +######### +# OR for a single dataset +######### +dataset: + _component_: torchtune.datasets.build_alpaca_dataset +#-------------------------------- + +######### +# Place for args common for all datasets that will be passed to the dataset constructor +# useful for multidataset. Used as cfg = dataset_defaults.update(dataset_cfg) +######### +dataset_defaults: + shuffle_buffer_size: 1000 + num_shards_per_worker: 16 + seed: ${seed} + tokenizer: ${tokenizer} + recipe_transform: + _component_: torchtune.datasets.SFTTransform + +######### +# args used in the dataset setup. This is not dataset specific. 
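# e.g. on-the-fly packing and the multi-dataset stopping strategy, which act on the
# combined sample stream rather than on any single dataset.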
+######### +dataset_setup: + packing: + _component_: torchtune.datasets.packing.SFTPacking + max_seq_len: ${tokenizer.max_seq_len} + multidataset_stopping_strategy: "first_exhausted" # "all_exhausted" +``` + +# Builder example: torchtune/datasets/alpaca_dataset.py + +```python +def alpaca_dataset( + *, + load_args: optional[dict], + message_transform: optional[callable|dict], + tokenizer: ModelTokenizer, + recipe_transform: callable, + *args, **kwargs + ): + _load_args = { + source: "tatsu-lab/alpaca", + split: str = "train" + } + _message_transform_args = { + "train_on_input":False, + "column_map"={"input": "prompt", "output": "response"} + } + + # unify args + if load_args: + _load_args.update(**load_args) + + # unify args + if not message_transform or isinstance(message_transform, dict): + # remove component key, since we are using alpaca_message_transform as default + message_transform.pop("_component_", None) + + # instantiate the message transform + _message_transform_args.update(message_transform) + message_transform = alpaca_message_transform(**_message_transform_args) + + return HfIterableDataset(load_args, message_transform, tokenizer, recipe_transform, *args, **kwargs) +``` + +# Iterable dataset: Shared for all datasets and recipes (SFT, DPO, etc). Differences are in the transforms. +Location: torchtune/datasets/hf_iterable_dataset.py + +```python +class HfIterableDataset(IterableDataset, Stateful): + def __init__( + self, + *. + load_args: Dict, + message_transform: Callable, + tokenizer: Callable, + recipe_transform: Callable, + shuffle_buffer_size:Optional[int] = 1000, + seed:Optional[int] = 42 + num_shards_per_worker: int = 16, + weight:float = 1.0, + filter_args: Optional[Dict] = None, + *args, **kwargs + ): + """Initialize a single dataset with its specific transformations.""" + self.weight = weight + + world_size = 1 + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + + #TODO: Maybe # shards should be based on dataset size, if we know it + num_shards = world_size * num_shards_per_worker + ds = load_dataset(**load_args) + ds = ds.to_iterable_dataset(num_shards) + + if filter_args: + function = filter_args.get("function", None) + if function and not isinstance(function, Callable): + raise ValueError(f"filter_args['function'] must be a callable. Found {type(function)}") + # https://huggingface.co/docs/datasets/v3.6.0/en/stream#filter + ds = ds.filter(**filter_args) + + def _apply_transforms(sample): + sample = message_transform(sample) + sample = tokenizer(sample) + return recipe_transform(sample) + + ds = ds.map(_apply_transforms) #lazy + + if shuffle_buffer_size and shuffle_buffer_size > 0: + ds = ds.shuffle(shuffle_buffer_size, seed) + + # distribute + if world_size>1: + ds = split_dataset_by_node( + ds, + rank=torch.distributed.get_rank(), + world_size=world_size, + ) + + self.ds = ds + + def __iter__(self): + # Expose the for loop so extra logic can be added here, e.g. drop if no trainable tokens + # TODO: should we add try/except to handle/logerrors? 
+ for sample in self.ds: + yield sample + + def state_dict(self): + state_dict = self.ds.state_dict() + state_dict["weight"] = self.weight + return state_dict + + def load_state_dict(self, state_dict): + self.weight = state_dict.pop("weight") + self.ds.load_state_dict(state_dict) +``` + +# Setup Data +Method in recipes/full_distributed.py +OR utility used in the recipe + +```python +from datasets import interleave_datasets, split_dataset_by_node +from torchtune.models.tokenizers import ModelTokenizer +import torch + +#NOTE: I have mixed feelings about passing multiple configDict to setup_data. This feels hard for the user to know what they should contain. On the other hand, i) setup_data doesnt need to make assumptions about the configs ii) we already do it currently. Alternative: use dataclassses + +def setup_data( + dataset_cfg: ConfigDict, + dataset_defaults: ConfigDict, + data_setup_cfg: ConfigDict, + dataloader_cfg: ConfigDict, + seed: int, + pad_idx: int, + ignore_idx: int, + pad_to_multiple_of: int, + ) -> "IterableDataset": + """ + Equivalent to setup_data in the recipe + """ + iterable_datasets = [] + weights = [] + dataset_defaults = {} if dataset_defaults is None else dataset_defaults + + # add to a list just for processing + if not isinstance(dataset_cfg, list): + dataset_cfg = [dataset_cfg] + + for base_cfg in dataset_cfg: + weight = base_cfg.get("weight", 1.0) + weights.append(weight) + + base_cfg = OmegaConf.merge(dataset_defaults, base_cfg) + ds = instantiate(base_cfg) + iterable_datasets.append(ds) + + + # interleave for multidataset + if len(iterable_datasets) > 1: + weights = normalize_weights(weights) # sum to 1 + ds = interleave_datasets( + iterable_datasets, + probabilities=weights, + seed=seed, + # strategies: https://huggingface.co/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy + stopping_strategy=data_setup_cfg.multidataset_stopping_strategy, + ) + else: + ds = iterable_datasets[0] + + # FIXME: remove from config + if setup_cfg.packing: + # Subclass of IterableDataset, takes any iterator as input + ds = instantiate(data_setup_cfg.packing, + dataset=ds, + padding_idx=pad_id, #TODO: in the future, move padding to collate_fn + ) + + # Instantiate collate_fn + collate_fn = dataloader_cfg.pop("collate_fn", None) + #TODO: in the future, unify those two + if collate_fn is None: + collate_fn = "torchtune.data.padded_collate_packed" if packing else "torchtune.data.padded_collate_sft" + + collate_fn = _get_component_from_path(collate_fn) + collate_fn = partial(collate_fn, + padding_idx=pad_idx, + ignore_idx=ignore_id, + pad_to_multiple_of=pad_to_multiple_of + ) + + # dropping last avoids shape issues with compile + flex attention + if "drop_last" not in dataloader_cfg: + dataloader_cfg["drop_last"] = True + + dataloader = instantiate(dataloader_cfg, dataset=ds, collate_fn=collate_fn) + + return dataloader +``` + +# Recipe train loop +```python +for epoch in range(n_epochs): + my_iterable_dataset.set_epoch(epoch) + for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch` + pass +``` + +### Backward compatibility + +Options: +1. Make setup_data an utility, and have two utilities supporting the old and new config formats. After deprecation period, old utility is removed. + +Pros: modularize it and remove from the recipe. Future changes will be easier to implement. +Cons: Big change in how we handle recipe utilities. + +2. 
Create an adapter migrate_old_to_new_config: +Pros: Recipes still have method _setup_data exposing the logic +Cons: Hard to debug the migrated configs, edge cases not covered by the adapter, ConcatDataset is handled differently. + +3. No migration. Old config with old recipe will break. Users need to update +their configs. No idea how this affects llamastack / startups / others. + +**Implementation of option 1 (Make setup_data an utility)** + +# torchtune/training/data_utils.py or similar location +```python +@deprecated +def is_legacy_data_config(cfg: DictConfig) -> bool: + """ + Detect if config follows legacy format vs new iterable dataset format. + """ + # Check for new format indicators first + has_dataloader_section = "dataloader" in cfg + has_dataset_defaults = "dataset_defaults" in cfg + has_dataset_setup = "dataset_setup" in cfg + + return not (has_dataloader_section or has_dataset_defaults or has_dataset_setup) + +@deprecated +def setup_data_legacy( + ... +) -> StatefulDataLoader: + """ + Legacy data setup function to maintain backward compatibility. + This replicates the current behavior in full_finetune_distributed.py + """ + # same as current setup_data in the recipe.... + + return dataloader +``` + +In the recipe: +```python +def _setup(...): + ... + if is_legacy_data_config(cfg): + dataloader = setup_data_legacy(...) + else: + dataloader = setup_data(...) +``` From ab02d75ad65e8096787a8ec544fb7c5c90b2b61b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 4 Jun 2025 12:03:01 -0700 Subject: [PATCH 2/3] Revert "rfc" This reverts commit 901723f58ba5cbd8201a3fbbcc21550857159ff9. --- torchtune/datasets/rfc_iterable_dataset.md | 404 --------------------- 1 file changed, 404 deletions(-) delete mode 100644 torchtune/datasets/rfc_iterable_dataset.md diff --git a/torchtune/datasets/rfc_iterable_dataset.md b/torchtune/datasets/rfc_iterable_dataset.md deleted file mode 100644 index a8efd553e6..0000000000 --- a/torchtune/datasets/rfc_iterable_dataset.md +++ /dev/null @@ -1,404 +0,0 @@ -### Core issues: - 1) No support for iterative dataset: - - Dataset has to be fully loaded in memory; - - With map-style, no control over multi-sample operations, e.g. packing or skipping - - map-style is slower - - no support for streaming - 2) No support for weighted dataset: - - We have it in a single newly added dev recipe/config, but API needs polishing; - - We also support ConcatDataset, but its map style and there is no weighting; - 3) No support for on-the-fly data packing: It's done before training, taking a long time for large datasets; - -### UX issues: - 4) Unclear boundaries between HF and torchtune args - - ```python - def alpaca_dataset( - # --- torchtune specific args --- - tokenizer: ModelTokenizer, - train_on_input: bool = True, - packed: bool = False, - # --- HF loading args --- - source: str = "tatsu-lab/alpaca", - column_map: Optional[Dict[str, str]] = None, - split: str = "train", - **load_dataset_kwargs: Dict[str, Any], - # --- HF dataset method --- - filter_fn: Optional[Callable] = None, - ) -> Union[SFTDataset, PackedDataset]: - ``` - - 5) Lack of dataloader args: args are scattered in the config. Important args are not exposed, e.g. num_workers, pin_memory, etc. - ```yaml - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset - seed: null - batch_size: 8 - shuffle: True - collate_fn: torchtune.data.padded_collate_tiled_images_and_mask - ``` - - 6) Different datasets have different arguments, because their message transforms are different. 
- -### Principles: - - Common API signatures for all datasets - - Offload what we can to hf datasets methods directly - - Less protagonism from our functions. E.g. config manipulations, instantiation, etc. (not the focus of this diff) - -### Proposal: - -# config.yaml -```yaml - -########### -# tokenizer -########### -tokenizer: - _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform - path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model - image_size: 560 - max_seq_len: 8192 - -########## -# dataloader -# consolidate all dataloader args here, which are currently scattered -########## -dataloader: - _component_: torchdata.stateful_dataloader.StatefulDataLoader - batch_size: 4 - num_workers: 4 - pin_memory: true - collate_fn: torchtune.data.padded_collate - - -#-------------------------------- -######### -# dataset if the class is used directly, as in our current SFTDataset -######### -dataset: - - _component_: torchtune.datasets.HfIterableDataset - load_args: - path: "tatsu-lab/alpaca" - split: "train" - message_transform: - _component_: torchtune.datasets.alpaca_message_transform - masking_strategy: "output_only" - column_map: - input: "prompt" - output: "response" - system_prompt: "foo" - filter_args: - function: torchtune.datasets.filter_fn_even_indices - with_indices: True - weight: 0.8 - - _component_: torchtune.datasets.HfIterableDataset - load_args: - path: "tatsu-lab/gsm8k" - split: "train" - message_transform: - _component_: torchtune.datasets.gsm8k_message_transform - masking_strategy: "output_only" - column_map: - input: "prompt" - output: "response" - system_prompt: "bar" - weight: 0.2 - -######### -# OR with builders -# TODO: test indexing "`tune run config – dataset[0].load_arg.split=train`" -######### -dataset: - - _component_: torchtune.datasets.build_alpaca_dataset - load_args: - split: "valid" - weight: 0.8 - - _component_: torchtune.datasets.build_gsm8k_dataset - message_transform: - system_prompt: "bar" - weight: 0.2 - -######### -# OR for a single dataset -######### -dataset: - _component_: torchtune.datasets.build_alpaca_dataset -#-------------------------------- - -######### -# Place for args common for all datasets that will be passed to the dataset constructor -# useful for multidataset. Used as cfg = dataset_defaults.update(dataset_cfg) -######### -dataset_defaults: - shuffle_buffer_size: 1000 - num_shards_per_worker: 16 - seed: ${seed} - tokenizer: ${tokenizer} - recipe_transform: - _component_: torchtune.datasets.SFTTransform - -######### -# args used in the dataset setup. This is not dataset specific. 
-######### -dataset_setup: - packing: - _component_: torchtune.datasets.packing.SFTPacking - max_seq_len: ${tokenizer.max_seq_len} - multidataset_stopping_strategy: "first_exhausted" # "all_exhausted" -``` - -# Builder example: torchtune/datasets/alpaca_dataset.py - -```python -def alpaca_dataset( - *, - load_args: optional[dict], - message_transform: optional[callable|dict], - tokenizer: ModelTokenizer, - recipe_transform: callable, - *args, **kwargs - ): - _load_args = { - source: "tatsu-lab/alpaca", - split: str = "train" - } - _message_transform_args = { - "train_on_input":False, - "column_map"={"input": "prompt", "output": "response"} - } - - # unify args - if load_args: - _load_args.update(**load_args) - - # unify args - if not message_transform or isinstance(message_transform, dict): - # remove component key, since we are using alpaca_message_transform as default - message_transform.pop("_component_", None) - - # instantiate the message transform - _message_transform_args.update(message_transform) - message_transform = alpaca_message_transform(**_message_transform_args) - - return HfIterableDataset(load_args, message_transform, tokenizer, recipe_transform, *args, **kwargs) -``` - -# Iterable dataset: Shared for all datasets and recipes (SFT, DPO, etc). Differences are in the transforms. -Location: torchtune/datasets/hf_iterable_dataset.py - -```python -class HfIterableDataset(IterableDataset, Stateful): - def __init__( - self, - *. - load_args: Dict, - message_transform: Callable, - tokenizer: Callable, - recipe_transform: Callable, - shuffle_buffer_size:Optional[int] = 1000, - seed:Optional[int] = 42 - num_shards_per_worker: int = 16, - weight:float = 1.0, - filter_args: Optional[Dict] = None, - *args, **kwargs - ): - """Initialize a single dataset with its specific transformations.""" - self.weight = weight - - world_size = 1 - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - - #TODO: Maybe # shards should be based on dataset size, if we know it - num_shards = world_size * num_shards_per_worker - ds = load_dataset(**load_args) - ds = ds.to_iterable_dataset(num_shards) - - if filter_args: - function = filter_args.get("function", None) - if function and not isinstance(function, Callable): - raise ValueError(f"filter_args['function'] must be a callable. Found {type(function)}") - # https://huggingface.co/docs/datasets/v3.6.0/en/stream#filter - ds = ds.filter(**filter_args) - - def _apply_transforms(sample): - sample = message_transform(sample) - sample = tokenizer(sample) - return recipe_transform(sample) - - ds = ds.map(_apply_transforms) #lazy - - if shuffle_buffer_size and shuffle_buffer_size > 0: - ds = ds.shuffle(shuffle_buffer_size, seed) - - # distribute - if world_size>1: - ds = split_dataset_by_node( - ds, - rank=torch.distributed.get_rank(), - world_size=world_size, - ) - - self.ds = ds - - def __iter__(self): - # Expose the for loop so extra logic can be added here, e.g. drop if no trainable tokens - # TODO: should we add try/except to handle/logerrors? 
- for sample in self.ds: - yield sample - - def state_dict(self): - state_dict = self.ds.state_dict() - state_dict["weight"] = self.weight - return state_dict - - def load_state_dict(self, state_dict): - self.weight = state_dict.pop("weight") - self.ds.load_state_dict(state_dict) -``` - -# Setup Data -Method in recipes/full_distributed.py -OR utility used in the recipe - -```python -from datasets import interleave_datasets, split_dataset_by_node -from torchtune.models.tokenizers import ModelTokenizer -import torch - -#NOTE: I have mixed feelings about passing multiple configDict to setup_data. This feels hard for the user to know what they should contain. On the other hand, i) setup_data doesnt need to make assumptions about the configs ii) we already do it currently. Alternative: use dataclassses - -def setup_data( - dataset_cfg: ConfigDict, - dataset_defaults: ConfigDict, - data_setup_cfg: ConfigDict, - dataloader_cfg: ConfigDict, - seed: int, - pad_idx: int, - ignore_idx: int, - pad_to_multiple_of: int, - ) -> "IterableDataset": - """ - Equivalent to setup_data in the recipe - """ - iterable_datasets = [] - weights = [] - dataset_defaults = {} if dataset_defaults is None else dataset_defaults - - # add to a list just for processing - if not isinstance(dataset_cfg, list): - dataset_cfg = [dataset_cfg] - - for base_cfg in dataset_cfg: - weight = base_cfg.get("weight", 1.0) - weights.append(weight) - - base_cfg = OmegaConf.merge(dataset_defaults, base_cfg) - ds = instantiate(base_cfg) - iterable_datasets.append(ds) - - - # interleave for multidataset - if len(iterable_datasets) > 1: - weights = normalize_weights(weights) # sum to 1 - ds = interleave_datasets( - iterable_datasets, - probabilities=weights, - seed=seed, - # strategies: https://huggingface.co/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy - stopping_strategy=data_setup_cfg.multidataset_stopping_strategy, - ) - else: - ds = iterable_datasets[0] - - # FIXME: remove from config - if setup_cfg.packing: - # Subclass of IterableDataset, takes any iterator as input - ds = instantiate(data_setup_cfg.packing, - dataset=ds, - padding_idx=pad_id, #TODO: in the future, move padding to collate_fn - ) - - # Instantiate collate_fn - collate_fn = dataloader_cfg.pop("collate_fn", None) - #TODO: in the future, unify those two - if collate_fn is None: - collate_fn = "torchtune.data.padded_collate_packed" if packing else "torchtune.data.padded_collate_sft" - - collate_fn = _get_component_from_path(collate_fn) - collate_fn = partial(collate_fn, - padding_idx=pad_idx, - ignore_idx=ignore_id, - pad_to_multiple_of=pad_to_multiple_of - ) - - # dropping last avoids shape issues with compile + flex attention - if "drop_last" not in dataloader_cfg: - dataloader_cfg["drop_last"] = True - - dataloader = instantiate(dataloader_cfg, dataset=ds, collate_fn=collate_fn) - - return dataloader -``` - -# Recipe train loop -```python -for epoch in range(n_epochs): - my_iterable_dataset.set_epoch(epoch) - for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch` - pass -``` - -### Backward compatibility - -Options: -1. Make setup_data an utility, and have two utilities supporting the old and new config formats. After deprecation period, old utility is removed. - -Pros: modularize it and remove from the recipe. Future changes will be easier to implement. -Cons: Big change in how we handle recipe utilities. - -2. 
Create an adapter migrate_old_to_new_config: -Pros: Recipes still have method _setup_data exposing the logic -Cons: Hard to debug the migrated configs, edge cases not covered by the adapter, ConcatDataset is handled differently. - -3. No migration. Old config with old recipe will break. Users need to update -their configs. No idea how this affects llamastack / startups / others. - -**Implementation of option 1 (Make setup_data an utility)** - -# torchtune/training/data_utils.py or similar location -```python -@deprecated -def is_legacy_data_config(cfg: DictConfig) -> bool: - """ - Detect if config follows legacy format vs new iterable dataset format. - """ - # Check for new format indicators first - has_dataloader_section = "dataloader" in cfg - has_dataset_defaults = "dataset_defaults" in cfg - has_dataset_setup = "dataset_setup" in cfg - - return not (has_dataloader_section or has_dataset_defaults or has_dataset_setup) - -@deprecated -def setup_data_legacy( - ... -) -> StatefulDataLoader: - """ - Legacy data setup function to maintain backward compatibility. - This replicates the current behavior in full_finetune_distributed.py - """ - # same as current setup_data in the recipe.... - - return dataloader -``` - -In the recipe: -```python -def _setup(...): - ... - if is_legacy_data_config(cfg): - dataloader = setup_data_legacy(...) - else: - dataloader = setup_data(...) -``` From dd99edd1251b3501ee5e2ed3128916f156ecd008 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 13 Jun 2025 09:06:54 -0700 Subject: [PATCH 3/3] raise error if opt_in_bwd and cpu offload --- recipes/full_dpo_distributed.py | 5 +++++ recipes/full_finetune_distributed.py | 5 +++++ recipes/qat_distributed.py | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py index 08400067d1..c7e10708eb 100644 --- a/recipes/full_dpo_distributed.py +++ b/recipes/full_dpo_distributed.py @@ -177,6 +177,11 @@ def __init__(self, cfg: DictConfig) -> None: "Gradient accumulation is not supported with optimizer in bwd." "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." ) + if self.fsdp_cpu_offload: + raise RuntimeError( + "CPU offload is not supported with optimizer in bwd atm." + "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False." + ) # activation checkpointing/offloading self._enable_activation_checkpointing = cfg.get( diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index fc7254b3c4..3423390070 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -226,6 +226,11 @@ def __init__(self, cfg: DictConfig) -> None: "Gradient accumulation is not supported with optimizer in bwd." "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." ) + if self.fsdp_cpu_offload: + raise RuntimeError( + "CPU offload is not supported with optimizer in bwd atm." + "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False." + ) # activation checkpointing/offloading self._enable_activation_checkpointing = cfg.get( diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 1f6b91f163..0e8d49905d 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -232,6 +232,11 @@ def __init__(self, cfg: DictConfig) -> None: "Gradient accumulation is not supported with optimizer in bwd." "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." 
)
+        if self.fsdp_cpu_offload:
+            raise RuntimeError(
+                "CPU offload is not supported with optimizer in bwd atm."
+                "Please set fsdp_cpu_offload=False, or optimizer_in_bwd=False."
+            )

         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self._checkpoint_client = CheckpointClient(cfg)
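
A note on the RFC in PATCH 1/3: the proposal composes a handful of Hugging Face `datasets` primitives (`to_iterable_dataset`, weighted `interleave_datasets`, `split_dataset_by_node`, `set_epoch`, and iterable `state_dict`/`load_state_dict`). The sketch below shows how those pieces fit together outside of torchtune. The dataset path, split slices, weights, and the `build_weighted_stream` helper are illustrative only, not part of the proposal, and the checkpoint/resume step assumes a `datasets` release recent enough to ship `state_dict` support for iterable datasets.

```python
# Standalone sketch of the HF `datasets` building blocks the RFC relies on.
# Illustrative only: dataset/split choices and the helper name are not from the RFC.
from datasets import interleave_datasets, load_dataset
from datasets.distributed import split_dataset_by_node


def build_weighted_stream(specs, seed=42, rank=0, world_size=1, shuffle_buffer_size=1000):
    """specs: list of (load_dataset kwargs, weight) pairs; weights are normalized to sum to 1."""
    streams, weights = [], []
    for load_args, weight in specs:
        ds = load_dataset(**load_args)  # map-style dataset
        ds = ds.to_iterable_dataset(num_shards=world_size * 16)  # ~16 shards per worker
        ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer_size)  # windowed shuffle
        streams.append(ds)
        weights.append(weight)

    total = sum(weights)
    probabilities = [w / total for w in weights]  # what normalize_weights() would do
    mixed = interleave_datasets(
        streams,
        probabilities=probabilities,
        seed=seed,
        stopping_strategy="first_exhausted",  # or "all_exhausted"
    )
    if world_size > 1:
        mixed = split_dataset_by_node(mixed, rank=rank, world_size=world_size)
    return mixed


if __name__ == "__main__":
    stream = build_weighted_stream(
        [
            ({"path": "tatsu-lab/alpaca", "split": "train[:90%]"}, 0.8),
            ({"path": "tatsu-lab/alpaca", "split": "train[90%:]"}, 0.2),
        ]
    )

    # Epoch loop as in the RFC's "Recipe train loop" section.
    for epoch in range(2):
        stream.set_epoch(epoch)  # reshuffle with effective_seed = seed + epoch
        for step, _ in enumerate(stream):
            if step >= 5:
                break

    # Mid-stream checkpointing: capture the iterator state and restore it later.
    state = stream.state_dict()
    stream.load_state_dict(state)
```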