Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 59 additions & 15 deletions tests/test_sharding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import glob
import sys, os
import re

# Ring attention only works efficiently with the latency-hiding scheduler.
os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true'
Expand Down Expand Up @@ -45,7 +46,7 @@ def check1(ref_out, jax_out, out):
def test_flash_fwd_sharded_hlo(seqlen, h, d, m, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()[:4]
devices = jax.local_devices()[:2]
n = len(devices)

@jax.jit
Expand Down Expand Up @@ -81,7 +82,12 @@ def with_sharding(q_sharding, kv_sharding=None) -> str:
assert 'dynamic-slice' not in hlo
assert 'collective-permute' in hlo
# Should always run concurrently, meaning custom-call is always between start and done.
assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), hlo
# Forward pass should have all rotations overlapped (no final rotation needed).
decoded = decode_hlo(hlo)
overlapped_pairs, adjacent_pairs = count_overlapped_permutes(decoded)
assert adjacent_pairs == 0, f"Found non-overlapped rotations: {adjacent_pairs} (expected 0). Decoded: {decoded}"
# (N-1) overlapped rotations to see all N blocks.
assert overlapped_pairs == n-1, f"Wrong number of overlapped rotations: {overlapped_pairs} (expected exactly {n-1}). Decoded: {decoded}"


@pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
Expand All @@ -95,7 +101,7 @@ def with_sharding(q_sharding, kv_sharding=None) -> str:
def test_flash_bwd_sharded_hlo(seqlen, h, d, m, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()[:4]
devices = jax.local_devices()[:2]
mesh = Mesh(np.array(devices), axis_names=('x',))
n = len(devices)

Expand Down Expand Up @@ -129,11 +135,13 @@ def with_sharding(sharding) -> str:
assert 'dynamic-slice' not in hlo
assert 'collective-permute' in hlo
# Should always run concurrently, meaning custom-call is always between start and done.
# import re
# collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
# assert 'collective-permute-start collective-permute-done' not in collectives, hlo
print(hlo)
assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), decode_hlo(hlo)
# In backward pass, there's one final rotation after the scan loop that
# cannot overlap (returns gradients).
decoded = decode_hlo(hlo)
overlapped_pairs, adjacent_pairs = count_overlapped_permutes(decoded)
# Backward pass: N overlapped rotations in scan + 1 final non-overlapped
assert overlapped_pairs == n, f"Wrong number of overlapped rotations: {overlapped_pairs} (expected exactly {n}). Decoded: {decoded}"
assert adjacent_pairs == 1, f"Wrong number of non-overlapped rotations: {adjacent_pairs} (expected exactly 1). Decoded: {decoded}"

@pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
Expand All @@ -146,7 +154,7 @@ def with_sharding(sharding) -> str:
def test_flash_fwd_sharded(seqlen, h, d, m, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()
devices = jax.local_devices()[:2]
mesh = Mesh(np.array(devices), axis_names=('x',))
n = len(devices)

Expand Down Expand Up @@ -193,7 +201,7 @@ def check_sharding(sharding,q,k,v):
def test_flash_bwd_sharded(seqlen, h, d, m, causal, local, dtype):
window_size = (3,3) if local else (-1,-1)

devices = jax.local_devices()
devices = jax.local_devices()[:2]
mesh = Mesh(np.array(devices), axis_names=('x',))
n = len(devices)

Expand Down Expand Up @@ -235,7 +243,25 @@ def check_sharding(sharding):
sharding = NamedSharding(mesh, P(None,'x',None,None))
check_sharding(sharding)

def count_overlapped_permutes(decoded_ops):
ops = decoded_ops.split()
adjacent_pairs = 0
overlapped_pairs = 0
i = 0
while i < len(ops) - 1:
if ops[i] == 'collective-permute-start':
if ops[i+1] == 'collective-permute-done':
adjacent_pairs += 1
i += 2
else:
overlapped_pairs += 1
i += 1
else:
i += 1
return overlapped_pairs, adjacent_pairs

def decode_hlo(hlo):
import re
computations = {}
current_name = None
current_lines = []
Expand All @@ -251,18 +277,36 @@ def decode_hlo(hlo):
computations[current_name] = current_lines

def visit(name):
if name not in computations:
return
for line in computations[name]:
if 'custom-call(' in line:
yield 'custom-call'
elif any('calls='+target in line for target in computations.keys()):
target = [target for target in computations.keys() if 'calls='+target in line][0]
for item in visit(target):
yield item
# Handle calls=, body=, condition= (for regular calls and while loops)
elif 'calls=' in line or 'body=' in line or 'condition=' in line:
# Extract all referenced computation names
targets = []
for match in re.finditer(r'(?:calls|body|condition)=(%[^,\s\)]+)', line):
target = match.group(1)
if target in computations:
targets.append(target)
for target in targets:
for item in visit(target):
yield item
# Handle branch_computations={...} (for conditional/switch operations)
elif 'branch_computations=' in line:
# Extract branch names from branch_computations={%branch1, %branch2, ...}
match = re.search(r'branch_computations=\{([^}]+)\}', line)
if match:
branches = [b.strip() for b in match.group(1).split(',')]
for branch in branches:
if branch in computations:
for item in visit(branch):
yield item
elif 'collective-permute-start(' in line:
yield 'collective-permute-start'
elif 'collective-permute-done(' in line:
yield 'collective-permute-done'

return ' '.join(visit('ENTRY'))

if __name__ == '__main__':
Expand Down