
[BUG] ROCm 7.0 compatibility issue with Attention2d modules #2613

@EmilienM

Description


Describe the bug

The timm.layers.Attention2d and timm.layers.MultiQueryAttention2d modules fail with HIP error: invalid argument during the forward pass on AMD ROCm 7.0. The failure occurs specifically when F.conv2d() is called for the 1×1 convolutions inside these attention modules.
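
A condensed, standalone repro (a sketch distilled from the failing test in Error Details below; it assumes timm and a ROCm 7.0 build of PyTorch with an AMD GPU exposed as cuda:0):

    # Minimal repro sketch; shapes and hyperparameters mirror the failing test below.
    import torch
    from timm.layers import Attention2d

    device = torch.device("cuda")  # ROCm builds of PyTorch expose AMD GPUs via the "cuda" device type
    x = torch.randn(1, 128, 32, 48, device=device)
    attn = Attention2d(128, 128, num_heads=4, bias=True).to(device)

    out = attn(x)  # on ROCm 7.0 this raises torch.AcceleratorError: HIP error: invalid argument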

Error Details

   def test_attn2d_parametrized(self, bias, expand_first, head_first, attn_mask):
        """Test Attention2d with various parameters."""
        test_name = "test_attn2d_parametrized"
    
        if not (CUDA_AVAILABLE or ROCM_AVAILABLE):
            log_test_result(test_name, "skipped", "CUDA/ROCm not available")
            pytest.skip("CUDA/ROCm not available")
    
        try:
            import torch
            from timm.layers import Attention2d
    
            device = get_device()
            x = torch.randn(1, 128, 32, 48, device=device)
            attn = Attention2d(
                128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
            ).to(device)
    
            if attn_mask:
                mask = torch.randint(
                    0, 2, size=(32 * 48, 32 * 48), dtype=torch.float32, device=device
                )
            else:
                mask = None
    
>           o1 = attn(x, mask)
                 ^^^^^^^^^^^^^

/home/hotaisle/probe-tests/probe-timm/timm_probe_test.py:598: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/timm/layers/attention2d.py:378: in forward
    x = self.proj(x)
        ^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/conv.py:548: in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
input = tensor([[[[-7.8606e-01,  5.1798e-02, -7.5152e-02,  ...,  7.3231e-01,
            4.1317e-01, -1.0511e+00],
          [...1,  ...,  2.4378e-01,
            3.6133e-01, -1.5465e-01]]]], device='cuda:0',
       grad_fn=<ReshapeAliasBackward0>)
weight = Parameter containing:
tensor([[[[ 0.0392]],         [[-0.0073]],         [[-0.0370]],         ...,         [[-...       ...,         [[-0.0267]],         [[ 0.0238]],         [[ 0.0581]]]], device='cuda:0', requires_grad=True)
bias = None

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
>       return F.conv2d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
E       torch.AcceleratorError: HIP error: invalid argument
E       Search for `hipErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__HIPRT__TYPES.html for more information.
E       HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E       For debugging consider passing AMD_SERIALIZE_KERNEL=3
E       Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.

.venv/lib64/python3.12/site-packages/torch/nn/modules/conv.py:543: AcceleratorError

Environment

  • PyTorch: 2.9+
  • ROCm: 7.0.2
  • Hardware: AMD GPUs
  • Affected modules: Attention2d, MultiQueryAttention2d

Root Cause Analysis

ROCm 7.0 introduced a significant architectural change: 1×1 convolutions are now rewritten as GEMMs to leverage the performance of hipBLASLt and specialized tensor cores for modern data types (FP8, BF16). This change imposes stricter requirements on input tensor memory layouts (a sketch of the resulting failure mode follows the list):

  1. Strict Contiguity: the GEMM solvers require contiguous memory layouts
  2. Stride Alignment: non-contiguous tensors produced by reshape/permute/slice operations violate the BLAS stride constraints
  3. Enforcement: previous ROCm versions fell back to "Direct" convolution solvers with flexible stride handling; ROCm 7.0 enforces the GEMM constraints strictly
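
As an illustration of points 1 and 2, the sketch below hands a non-contiguous NCHW view (produced by permute, similar to the reshape/transpose sequences inside the attention modules) to a plain 1×1 F.conv2d, and contrasts it with an explicitly contiguous copy. The shapes are illustrative assumptions:

    # Sketch of the suspected failure mode; shapes are illustrative.
    import torch
    import torch.nn.functional as F

    device = torch.device("cuda")
    weight = torch.randn(128, 128, 1, 1, device=device)  # 1x1 conv kernel, like attn.proj

    # permute() yields a non-contiguous NCHW view of an NHWC tensor.
    x = torch.randn(1, 32, 48, 128, device=device).permute(0, 3, 1, 2)
    assert not x.is_contiguous()

    y_bad = F.conv2d(x, weight)              # reportedly fails on ROCm 7.0: HIP error: invalid argument
    y_ok = F.conv2d(x.contiguous(), weight)  # an explicit contiguous copy is expected to go through

Whether a given reshape/permute chain actually triggers the error likely depends on the exact strides; in the traceback above, the failing input came out of a reshape (grad_fn=ReshapeAliasBackward0).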

Technical Details

  • Previous behavior: ROCm 5.x/6.x used heuristic solver selection across Direct, Winograd, FFT, and GEMM algorithms
  • New behavior: ROCm 7.0 forces the GEMM path for all 1×1 convolutions
  • Impact: non-contiguous tensors, which are common in vision architectures, trigger hipErrorInvalidValue (a possible user-side mitigation is sketched below)
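
Until this is addressed in timm, PyTorch, or ROCm, one possible user-side mitigation is to force contiguity right before the affected 1×1 convolutions via forward pre-hooks. This is an unverified sketch, not an official fix, and the helper name is made up here:

    # Workaround sketch: copy non-contiguous inputs of every 1x1 Conv2d to a
    # contiguous layout before F.conv2d runs. Unverified mitigation, not an official fix.
    import torch
    import torch.nn as nn

    def force_contiguous_1x1_convs(model: nn.Module) -> None:
        def make_contiguous(module, args):
            # Returning a tuple replaces the positional arguments passed to forward().
            return tuple(a.contiguous() if torch.is_tensor(a) else a for a in args)

        for m in model.modules():
            if isinstance(m, nn.Conv2d) and m.kernel_size == (1, 1):
                m.register_forward_pre_hook(make_contiguous)

Applied right after model creation (e.g. force_contiguous_1x1_convs(model)), the hook only adds a copy when an input is actually non-contiguous, since Tensor.contiguous() returns the tensor itself otherwise.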
