Commit fc02016

works now with torch nightly

1 parent: b214b54

File tree

1 file changed (+0, −4 lines)


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 0 additions & 4 deletions

@@ -182,10 +182,6 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 print("result:", d_dn(y_colwise))
 print("Distributed works!")
 
-# doesn't work
-# [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
-# 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
-# [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
 c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
