You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: intermediate_source/FSDP_tutorial.rst
+1-1Lines changed: 1 addition & 1 deletion
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -73,7 +73,7 @@ Model Initialization
73
73
# )
74
74
75
75
We can inspect the nested wrapping with ``print(model)``. ``FSDPTransformer`` is a joint class of `Transformer <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L100>`_ and `FSDPModule
76
-
<https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
76
+
<https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_. The same thing happens to `FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_. All FSDP2 public APIs are exposed through ``FSDPModule``. For example, users can call ``model.unshard()`` to manually control all-gather schedules. See "explicit prefetching" below for details.
77
77
78
78
**model.parameters() as DTensor**: ``fully_shard`` shards parameters across ranks, and convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor to represent sharded parameters. FSDP2 shards on dim-0 by default so DTensor placements are `Shard(dim=0)`. Say we have N ranks and a parameter with N rows before sharding. After sharding, each rank will have 1 row of the parameter. We can inspect sharded parameters using ``param.to_local()``.
0 commit comments