diff --git a/tests/test_sharding.py b/tests/test_sharding.py index 1782b83..cc72d3e 100644 --- a/tests/test_sharding.py +++ b/tests/test_sharding.py @@ -1,5 +1,6 @@ import glob import sys, os +import re # Ring attention only works efficiently with the latency-hiding scheduler. os.environ["XLA_FLAGS"] = '--xla_gpu_enable_latency_hiding_scheduler=true' @@ -45,7 +46,7 @@ def check1(ref_out, jax_out, out): def test_flash_fwd_sharded_hlo(seqlen, h, d, m, causal, local, dtype): window_size = (3,3) if local else (-1,-1) - devices = jax.local_devices()[:4] + devices = jax.local_devices()[:2] n = len(devices) @jax.jit @@ -81,7 +82,12 @@ def with_sharding(q_sharding, kv_sharding=None) -> str: assert 'dynamic-slice' not in hlo assert 'collective-permute' in hlo # Should always run concurrently, meaning custom-call is always between start and done. - assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), hlo + # Forward pass should have all rotations overlapped (no final rotation needed). + decoded = decode_hlo(hlo) + overlapped_pairs, adjacent_pairs = count_overlapped_permutes(decoded) + assert adjacent_pairs == 0, f"Found non-overlapped rotations: {adjacent_pairs} (expected 0). Decoded: {decoded}" + # (N-1) overlapped rotations to see all N blocks. + assert overlapped_pairs == n-1, f"Wrong number of overlapped rotations: {overlapped_pairs} (expected exactly {n-1}). 
Decoded: {decoded}" @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device') @@ -95,7 +101,7 @@ def with_sharding(q_sharding, kv_sharding=None) -> str: def test_flash_bwd_sharded_hlo(seqlen, h, d, m, causal, local, dtype): window_size = (3,3) if local else (-1,-1) - devices = jax.local_devices()[:4] + devices = jax.local_devices()[:2] mesh = Mesh(np.array(devices), axis_names=('x',)) n = len(devices) @@ -129,11 +135,13 @@ def with_sharding(sharding) -> str: assert 'dynamic-slice' not in hlo assert 'collective-permute' in hlo # Should always run concurrently, meaning custom-call is always between start and done. - # import re - # collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo)) - # assert 'collective-permute-start collective-permute-done' not in collectives, hlo - print(hlo) - assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), decode_hlo(hlo) + # In backward pass, there's one final rotation after the scan loop that + # cannot overlap (returns gradients). + decoded = decode_hlo(hlo) + overlapped_pairs, adjacent_pairs = count_overlapped_permutes(decoded) + # Backward pass: N overlapped rotations in scan + 1 final non-overlapped + assert overlapped_pairs == n, f"Wrong number of overlapped rotations: {overlapped_pairs} (expected exactly {n}). Decoded: {decoded}" + assert adjacent_pairs == 1, f"Wrong number of non-overlapped rotations: {adjacent_pairs} (expected exactly 1). 
def count_overlapped_permutes(decoded_ops):
    """Classify each collective-permute start by the op that follows it.

    ``decoded_ops`` is a whitespace-separated string of op names. Every
    ``collective-permute-start`` token that has a successor is counted once:

    * as *adjacent* when the very next token is ``collective-permute-done``
      (nothing was scheduled between start and done), or
    * as *overlapped* when any other token follows it.

    A ``collective-permute-start`` that is the final token has no successor
    and is not counted in either bucket.

    Returns:
        A tuple ``(overlapped_pairs, adjacent_pairs)``.
    """
    start_op = 'collective-permute-start'
    done_op = 'collective-permute-done'

    tokens = decoded_ops.split()
    overlapped = 0
    adjacent = 0
    # Walk consecutive (token, successor) pairs. The token after an adjacent
    # start is always a "done", which can never itself match start_op, so no
    # explicit skip is required to avoid double counting.
    for current, following in zip(tokens, tokens[1:]):
        if current != start_op:
            continue
        if following == done_op:
            adjacent += 1
        else:
            overlapped += 1
    return overlapped, adjacent
re.finditer(r'(?:calls|body|condition)=(%[^,\s\)]+)', line): + target = match.group(1) + if target in computations: + targets.append(target) + for target in targets: + for item in visit(target): + yield item + # Handle branch_computations={...} (for conditional/switch operations) + elif 'branch_computations=' in line: + # Extract branch names from branch_computations={%branch1, %branch2, ...} + match = re.search(r'branch_computations=\{([^}]+)\}', line) + if match: + branches = [b.strip() for b in match.group(1).split(',')] + for branch in branches: + if branch in computations: + for item in visit(branch): + yield item elif 'collective-permute-start(' in line: yield 'collective-permute-start' elif 'collective-permute-done(' in line: yield 'collective-permute-done' - return ' '.join(visit('ENTRY')) if __name__ == '__main__':