From e7e512df58987b7250fa752ecdcad82ac8f90eff Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Mon, 6 Oct 2025 22:40:26 +0530
Subject: [PATCH 1/7] feat: introduce new callback `on_before_optimizer_setup`

---
 src/lightning/pytorch/callbacks/callback.py   |  8 ++++
 src/lightning/pytorch/core/hooks.py           | 23 +++++++++++
 src/lightning/pytorch/trainer/trainer.py      |  5 +++
 ...hook_outputs.py => test_callback_hooks.py} | 41 +++++++++++++++++++
 4 files changed, 77 insertions(+)
 rename tests/tests_pytorch/callbacks/{test_callback_hook_outputs.py => test_callback_hooks.py} (58%)

diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 3bfb609465a83..a1fdf0b8857fd 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
+
+        """
+
     def on_before_optimizer_step(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..a979232d7ac72 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
 
         """
 
+    def on_before_optimizer_setup(self) -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
+        It’s particularly useful for callbacks such as
+        :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
+        prior to optimizer setup.
+
+        This hook runs once in fit stage, after the model
+        has been fully instantiated by ``configure_model``, but before optimizers are created by
+        ``configure_optimizers``.
+
+        Example::
+
+            class MyFinetuneCallback(Callback):
+                def on_before_optimizer_setup(self, trainer, pl_module):
+                    # freeze the backbone before optimizers are created
+                    for param in pl_module.backbone.parameters():
+                        param.requires_grad = False
+
+        """
+
     def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 5768c507e2e3f..c0d6f18166652 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
 
+        # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
+        if self.state.fn == TrainerFn.FITTING:
+            call._call_callback_hooks(self, "on_before_optimizer_setup")
+            call._call_lightning_module_hook(self, "on_before_optimizer_setup")
+
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hooks.py
similarity index 58%
rename from tests/tests_pytorch/callbacks/test_callback_hook_outputs.py
rename to tests/tests_pytorch/callbacks/test_callback_hooks.py
index 366a924a5867c..a85b014e2a86e 100644
--- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py
+++ b/tests/tests_pytorch/callbacks/test_callback_hooks.py
@@ -56,3 +56,44 @@ def on_test_batch_end(self, outputs, *_):
     assert any(isinstance(c, CB) for c in trainer.callbacks)
 
     trainer.fit(model)
+
+
+def test_callback_on_before_optimizer_setup(tmp_path):
+    """Tests that on_before_optimizer_setup is called as expected."""
+
+    class CB(Callback):
+        def setup(self, trainer, pl_module, stage=None):
+            assert len(trainer.optimizers) == 0
+            assert pl_module.layer is None  # setup is called before `LightningModule.configure_model`
+
+        def on_before_optimizer_setup(self, trainer, pl_module):
+            assert len(trainer.optimizers) == 0
+            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+
+        def on_fit_start(self, trainer, pl_module):
+            assert len(trainer.optimizers) == 1
+            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+
+    class DemoModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = None  # initialize layer in `configure_model`
+
+        def configure_model(self):
+            import torch.nn as nn
+
+            self.layer = nn.Linear(32, 2)
+
+    model = DemoModel()
+
+    trainer = Trainer(
+        callbacks=CB(),
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)

From cdb51e202f873461c88725b78bd8a0e10a717d50 Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Mon, 6 Oct 2025 22:58:39 +0530
Subject: [PATCH 2/7] update

---
 docs/source-pytorch/common/hooks.rst                  | 8 +++++++-
 tests/tests_pytorch/callbacks/test_callback_hooks.py  | 4 ++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst
index 89c1c15d0413f..82ecc99c6810a 100644
--- a/docs/source-pytorch/common/hooks.rst
+++ b/docs/source-pytorch/common/hooks.rst
@@ -88,13 +88,19 @@ with the source of each hook indicated:
 │   ├── [LightningModule]
 │   ├── [LightningModule.configure_shared_model()]
 │   ├── [LightningModule.configure_model()]
+│   │
+│   ├── on_before_optimizer_setup()
+│   │   ├── [Callbacks]
+│   │   └── [LightningModule]
+│   │
 │   ├── Strategy.restore_checkpoint_before_setup
 │   │   ├── [LightningModule.on_load_checkpoint()]
 │   │   ├── [LightningModule.load_state_dict()]
 │   │   ├── [LightningDataModule.load_state_dict()]
 │   │   ├── [Callbacks.on_load_checkpoint()]
 │   │   └── [Callbacks.load_state_dict()]
-│   └── [Strategy]
+│   │
+│   └── [Strategy] (configures optimizers, lr schedulers, precision, etc.)
 │
 ├── on_fit_start()
 │   ├── [Callbacks]
diff --git a/tests/tests_pytorch/callbacks/test_callback_hooks.py b/tests/tests_pytorch/callbacks/test_callback_hooks.py
index a85b014e2a86e..3e1debfa71754 100644
--- a/tests/tests_pytorch/callbacks/test_callback_hooks.py
+++ b/tests/tests_pytorch/callbacks/test_callback_hooks.py
@@ -64,10 +64,10 @@ def test_callback_on_before_optimizer_setup(tmp_path):
     class CB(Callback):
         def setup(self, trainer, pl_module, stage=None):
             assert len(trainer.optimizers) == 0
-            assert pl_module.layer is None  # setup is called before `LightningModule.configure_model`
+            assert pl_module.layer is None  # called before `LightningModule.configure_model`
 
         def on_before_optimizer_setup(self, trainer, pl_module):
-            assert len(trainer.optimizers) == 0
+            assert len(trainer.optimizers) == 0  # `LightningModule.configure_optimizers` hasn't been called yet
             assert pl_module.layer is not None  # called after `LightningModule.configure_model`
 
         def on_fit_start(self, trainer, pl_module):

From 44ebae5979924341588cb75be6f7816f3e08db15 Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Mon, 6 Oct 2025 23:07:55 +0530
Subject: [PATCH 3/7] update

---
 .../callbacks/test_callback_hooks.py          | 33 ++++++++++++-------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_callback_hooks.py b/tests/tests_pytorch/callbacks/test_callback_hooks.py
index 3e1debfa71754..5e4bf1096a4fb 100644
--- a/tests/tests_pytorch/callbacks/test_callback_hooks.py
+++ b/tests/tests_pytorch/callbacks/test_callback_hooks.py
@@ -58,42 +58,53 @@ def on_test_batch_end(self, outputs, *_):
     trainer.fit(model)
 
 
-def test_callback_on_before_optimizer_setup(tmp_path):
-    """Tests that on_before_optimizer_setup is called as expected."""
+def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
+    """Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""
 
-    class CB(Callback):
+    order = []
+
+    class TestCallback(Callback):
         def setup(self, trainer, pl_module, stage=None):
+            order.append("setup")
+            assert pl_module.layer is None
             assert len(trainer.optimizers) == 0
-            assert pl_module.layer is None  # called before `LightningModule.configure_model`
 
         def on_before_optimizer_setup(self, trainer, pl_module):
-            assert len(trainer.optimizers) == 0  # `LightningModule.configure_optimizers` hasn't been called yet
-            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+            order.append("on_before_optimizer_setup")
+            # configure_model should already have been called
+            assert pl_module.layer is not None
+            # but optimizers are not yet created
+            assert len(trainer.optimizers) == 0
 
         def on_fit_start(self, trainer, pl_module):
+            order.append("on_fit_start")
+            # optimizers should now exist
             assert len(trainer.optimizers) == 1
-            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+            assert pl_module.layer is not None
 
     class DemoModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.layer = None  # initialize layer in `configure_model`
+            self.layer = None
 
         def configure_model(self):
-            import torch.nn as nn
+            from torch import nn
 
             self.layer = nn.Linear(32, 2)
 
     model = DemoModel()
 
     trainer = Trainer(
-        callbacks=CB(),
+        callbacks=TestCallback(),
         default_root_dir=tmp_path,
limit_train_batches=2, limit_val_batches=2, max_epochs=1, - log_every_n_steps=1, enable_model_summary=False, + log_every_n_steps=1, ) trainer.fit(model) + + # Verify call order + assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"] From 050f5834878a58e1988147653db321248b017c71 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Mon, 6 Oct 2025 23:32:34 +0530 Subject: [PATCH 4/7] update --- src/lightning/pytorch/callbacks/lambda_function.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py index f04b2d777deb3..f8348ccb72a53 100644 --- a/src/lightning/pytorch/callbacks/lambda_function.py +++ b/src/lightning/pytorch/callbacks/lambda_function.py @@ -69,6 +69,7 @@ def __init__( on_load_checkpoint: Optional[Callable] = None, on_before_backward: Optional[Callable] = None, on_after_backward: Optional[Callable] = None, + on_before_optimizer_setup: Optional[Callable] = None, on_before_optimizer_step: Optional[Callable] = None, on_before_zero_grad: Optional[Callable] = None, on_predict_start: Optional[Callable] = None, From a19b3cb66117463107fd8b2caf6920e6c0a36482 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 7 Oct 2025 00:17:30 +0530 Subject: [PATCH 5/7] fx validator fix --- src/lightning/pytorch/core/hooks.py | 34 +++++++++---------- .../logger_connector/fx_validator.py | 1 + src/lightning/pytorch/trainer/trainer.py | 2 +- .../trainer/logging_/test_logger_connector.py | 3 ++ 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index a979232d7ac72..7621990f5fc95 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -295,28 +295,28 @@ def on_after_backward(self) -> None: """ - def on_before_optimizer_setup(self) -> None: - """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before - :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`. + # def on_before_optimizer_setup(self) -> None: + # """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before + # :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`. - This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created. - It’s particularly useful for callbacks such as - :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen - prior to optimizer setup. + # This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created. + # It’s particularly useful for callbacks such as + # :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen + # prior to optimizer setup. - This hook runs once in fit stage, after the model - has been fully instantiated by ``configure_model``, but before optimizers are created by - ``configure_optimizers``. + # This hook runs once in fit stage, after the model + # has been fully instantiated by ``configure_model``, but before optimizers are created by + # ``configure_optimizers``. 
- Example:: + # Example:: - class MyFinetuneCallback(Callback): - def on_before_optimizer_setup(self, trainer, pl_module): - # freeze the backbone before optimizers are created - for param in pl_module.backbone.parameters(): - param.requires_grad = False + # class MyFinetuneCallback(Callback): + # def on_before_optimizer_setup(self, trainer, pl_module): + # # freeze the backbone before optimizers are created + # for param in pl_module.backbone.parameters(): + # param.requires_grad = False - """ + # """ def on_before_optimizer_step(self, optimizer: Optimizer) -> None: """Called before ``optimizer.step()``. diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index c1ee0013bfa19..01264d750ccc0 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -35,6 +35,7 @@ class _LogOptions(TypedDict): "on_after_backward": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "on_before_optimizer_setup": None, "on_before_optimizer_step": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index c0d6f18166652..fd96e488cbc42 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -992,7 +992,7 @@ def _run( # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_before_optimizer_setup") - call._call_lightning_module_hook(self, "on_before_optimizer_setup") + # call._call_lightning_module_hook(self, "on_before_optimizer_setup") # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index d3d355edb003b..8166113974566 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -43,6 +43,7 @@ def test_fx_validator(): callbacks_func = { "on_before_backward", "on_after_backward", + "on_before_optimizer_setup", "on_before_optimizer_step", "on_before_zero_grad", "on_fit_end", @@ -83,6 +84,7 @@ def test_fx_validator(): } not_supported = { + "on_before_optimizer_setup", "on_fit_end", "on_fit_start", "on_exception", @@ -198,6 +200,7 @@ def test_fx_validator_integration(tmp_path): "setup": "You can't", "configure_model": "You can't", "configure_optimizers": "You can't", + "on_before_optimizer_setup": "You can't", "on_fit_start": "You can't", "train_dataloader": "You can't", "val_dataloader": "You can't", From e56ea5c34236fb280199505798bd92601176e0c2 Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 7 Oct 2025 00:18:53 +0530 Subject: [PATCH 6/7] update --- src/lightning/pytorch/core/hooks.py | 34 ++++++++++++------------ src/lightning/pytorch/trainer/trainer.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 7621990f5fc95..a979232d7ac72 100644 --- a/src/lightning/pytorch/core/hooks.py +++ 
b/src/lightning/pytorch/core/hooks.py @@ -295,28 +295,28 @@ def on_after_backward(self) -> None: """ - # def on_before_optimizer_setup(self) -> None: - # """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before - # :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`. + def on_before_optimizer_setup(self) -> None: + """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before + :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`. - # This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created. - # It’s particularly useful for callbacks such as - # :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen - # prior to optimizer setup. + This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created. + It’s particularly useful for callbacks such as + :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen + prior to optimizer setup. - # This hook runs once in fit stage, after the model - # has been fully instantiated by ``configure_model``, but before optimizers are created by - # ``configure_optimizers``. + This hook runs once in fit stage, after the model + has been fully instantiated by ``configure_model``, but before optimizers are created by + ``configure_optimizers``. - # Example:: + Example:: - # class MyFinetuneCallback(Callback): - # def on_before_optimizer_setup(self, trainer, pl_module): - # # freeze the backbone before optimizers are created - # for param in pl_module.backbone.parameters(): - # param.requires_grad = False + class MyFinetuneCallback(Callback): + def on_before_optimizer_setup(self, trainer, pl_module): + # freeze the backbone before optimizers are created + for param in pl_module.backbone.parameters(): + param.requires_grad = False - # """ + """ def on_before_optimizer_step(self, optimizer: Optimizer) -> None: """Called before ``optimizer.step()``. 
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index fd96e488cbc42..c0d6f18166652 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -992,7 +992,7 @@ def _run( # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_before_optimizer_setup") - # call._call_lightning_module_hook(self, "on_before_optimizer_setup") + call._call_lightning_module_hook(self, "on_before_optimizer_setup") # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: From b80e34a2f65d8ce0a72a39cf4cc7e3d86b489bcc Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Tue, 7 Oct 2025 00:44:39 +0530 Subject: [PATCH 7/7] update --- tests/tests_pytorch/models/test_hooks.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index e943d0533cab5..77f8bbb642b2e 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -477,6 +477,8 @@ def training_step(self, batch, batch_idx): # DeepSpeed needs the batch size to figure out throughput logging *([{"name": "train_dataloader"}] if using_deepspeed else []), {"name": "configure_model"}, + {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)}, + {"name": "on_before_optimizer_setup"}, {"name": "configure_optimizers"}, {"name": "Callback.on_fit_start", "args": (trainer, model)}, {"name": "on_fit_start"}, @@ -574,6 +576,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}}, {"name": "setup", "kwargs": {"stage": "fit"}}, {"name": "configure_model"}, + {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)}, + {"name": "on_before_optimizer_setup"}, {"name": "on_load_checkpoint", "args": (loaded_ckpt,)}, {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)}, {"name": "Callback.load_state_dict", "args": ({"foo": True},)}, @@ -654,6 +658,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}}, {"name": "setup", "kwargs": {"stage": "fit"}}, {"name": "configure_model"}, + {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)}, + {"name": "on_before_optimizer_setup"}, {"name": "on_load_checkpoint", "args": (loaded_ckpt,)}, {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)}, {"name": "Callback.load_state_dict", "args": ({"foo": True},)},