Description
Describe the bug
The timm.layers.Attention2d and timm.layers.MultiQueryAttention2d modules fail with HIP error: invalid argument during the forward pass on AMD ROCm 7.0. The failure occurs specifically when F.conv2d() is called for the 1×1 convolutions inside the attention modules.
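A condensed reproduction, stripped of the pytest harness shown under Error Details (the fixed argument values and the assumption of a single visible ROCm device are taken from the failing test below):

import torch
from timm.layers import Attention2d

# Under ROCm builds of PyTorch, HIP devices are exposed through the "cuda" backend.
device = torch.device("cuda")

x = torch.randn(1, 128, 32, 48, device=device)
attn = Attention2d(128, 128, num_heads=4, bias=True).to(device)

# Fails on ROCm 7.0 with: torch.AcceleratorError: HIP error: invalid argument
out = attn(x)
print(out.shape)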
Error Details
def test_attn2d_parametrized(self, bias, expand_first, head_first, attn_mask):
    """Test Attention2d with various parameters."""
    test_name = "test_attn2d_parametrized"
    if not (CUDA_AVAILABLE or ROCM_AVAILABLE):
        log_test_result(test_name, "skipped", "CUDA/ROCm not available")
        pytest.skip("CUDA/ROCm not available")
    try:
        import torch
        from timm.layers import Attention2d
        device = get_device()
        x = torch.randn(1, 128, 32, 48, device=device)
        attn = Attention2d(
            128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
        ).to(device)
        if attn_mask:
            mask = torch.randint(
                0, 2, size=(32 * 48, 32 * 48), dtype=torch.float32, device=device
            )
        else:
            mask = None
>       o1 = attn(x, mask)
             ^^^^^^^^^^^^^

/home/hotaisle/probe-tests/probe-timm/timm_probe_test.py:598:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/timm/layers/attention2d.py:378: in forward
x = self.proj(x)
^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib64/python3.12/site-packages/torch/nn/modules/conv.py:548: in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
input = tensor([[[[-7.8606e-01, 5.1798e-02, -7.5152e-02, ..., 7.3231e-01,
4.1317e-01, -1.0511e+00],
[...1, ..., 2.4378e-01,
3.6133e-01, -1.5465e-01]]]], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>)
weight = Parameter containing:
tensor([[[[ 0.0392]], [[-0.0073]], [[-0.0370]], ..., [[-... ..., [[-0.0267]], [[ 0.0238]], [[ 0.0581]]]], device='cuda:0', requires_grad=True)
bias = None

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
>       return F.conv2d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
E torch.AcceleratorError: HIP error: invalid argument
E Search for `hipErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__HIPRT__TYPES.html for more information.
E HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E For debugging consider passing AMD_SERIALIZE_KERNEL=3
E Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.

.venv/lib64/python3.12/site-packages/torch/nn/modules/conv.py:543: AcceleratorError
Environment
- PyTorch: 2.9+
- ROCm: 7.0.2
- Hardware: AMD GPUs
- Affected modules: Attention2d, MultiQueryAttention2d
Root Cause Analysis
ROCm 7.0 introduced a significant architectural change where 1×1 convolutions are now rewritten to GEMMs to leverage the performance of hipBLASLt and specialized tensor cores for modern data types (FP8, BF16). This change imposes stricter requirements on input tensor memory layouts:
- Strict Contiguity: GEMM solvers require contiguous memory layouts
- Stride Alignment: Non-contiguous tensors from reshape/permute/slice operations violate BLAS constraints
- Enforcement: Unlike previous ROCm versions that used "Direct" convolution solvers with flexible stride handling, ROCm 7.0 enforces these constraints
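If this analysis is correct, the failure should be reproducible without timm at all. The sketch below is an assumption-laden isolation: the transpose/reshape pattern is chosen to mimic how the attention module hands a non-contiguous NCHW view to its 1×1 projection (consistent with the ReshapeAliasBackward0 grad_fn in the traceback), and is not a copy of timm internals.

import torch
import torch.nn.functional as F

device = torch.device("cuda")
weight = torch.randn(128, 128, 1, 1, device=device)  # stand-in for the 1x1 proj weight

# Build a non-contiguous (B, C, H, W) tensor: transpose a (B, N, C) tensor and
# reshape it back so the result is a strided view rather than a fresh allocation.
x = torch.randn(1, 32 * 48, 128, device=device)
x = x.transpose(1, 2).reshape(1, 128, 32, 48)
print(x.is_contiguous())  # False

y_bad = F.conv2d(x, weight)              # expected to raise hipErrorInvalidValue on ROCm 7.0
y_ok = F.conv2d(x.contiguous(), weight)  # expected to succeed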
Technical Details
- Previous behavior: ROCm 5.x/6.x used heuristic solver selection across Direct, Winograd, FFT, and GEMM algorithms
- New behavior: ROCm 7.0 forces the GEMM path for all 1×1 convolutions
- Impact: Non-contiguous tensors (common in vision architectures) now trigger hipErrorInvalidValue (a possible workaround is sketched below)
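A possible user-side workaround, pending a fix in timm or ROCm, is to force contiguity before every 1×1 convolution. The helper below is only a sketch; the hook-based approach and the name force_contiguous_1x1_inputs are not part of timm or PyTorch.

import torch.nn as nn

def force_contiguous_1x1_inputs(model: nn.Module) -> nn.Module:
    """Register forward pre-hooks so every 1x1 Conv2d sees a contiguous input."""

    def _pre_hook(module, args):
        # args holds the positional inputs; return a new tuple with a contiguous first input.
        return (args[0].contiguous(), *args[1:])

    for m in model.modules():
        if isinstance(m, nn.Conv2d) and m.kernel_size == (1, 1):
            m.register_forward_pre_hook(_pre_hook)
    return model

# Usage with the repro above:
#   attn = force_contiguous_1x1_inputs(attn)
#   out = attn(x)

This trades an extra copy per 1×1 convolution for compatibility, so it should be treated as a temporary mitigation rather than a fix.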