Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 82 additions & 141 deletions crates/core_arch/src/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,65 +412,48 @@ pub unsafe fn _mm_subs_epu16(a: __m128i, b: __m128i) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pslldq, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_slli_si128(a: __m128i, imm8: i32) -> __m128i {
_mm_slli_si128_impl(a, imm8)
pub unsafe fn _mm_slli_si128<const imm8: i32>(a: __m128i) -> __m128i {
static_assert_imm8!(imm8);
_mm_slli_si128_impl::<imm8>(a)
}

/// Implementation detail: converts the immediate argument of the
/// `_mm_slli_si128` intrinsic into a compile-time constant.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn _mm_slli_si128_impl(a: __m128i, imm8: i32) -> __m128i {
let (zero, imm8) = (_mm_set1_epi8(0).as_i8x16(), imm8 as u32);
let a = a.as_i8x16();
macro_rules! shuffle {
($shift:expr) => {
simd_shuffle16::<i8x16, i8x16>(
zero,
a,
[
16 - $shift,
17 - $shift,
18 - $shift,
19 - $shift,
20 - $shift,
21 - $shift,
22 - $shift,
23 - $shift,
24 - $shift,
25 - $shift,
26 - $shift,
27 - $shift,
28 - $shift,
29 - $shift,
30 - $shift,
31 - $shift,
],
)
};
unsafe fn _mm_slli_si128_impl<const imm8: i32>(a: __m128i) -> __m128i {
const fn mask(shift: i32, i: u32) -> u32 {
if (shift as u32) > 15 {
i
} else {
16 - (shift as u32) + i
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does using a const fn work here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, const fn works here.

let x = match imm8 {
0 => shuffle!(0),
1 => shuffle!(1),
2 => shuffle!(2),
3 => shuffle!(3),
4 => shuffle!(4),
5 => shuffle!(5),
6 => shuffle!(6),
7 => shuffle!(7),
8 => shuffle!(8),
9 => shuffle!(9),
10 => shuffle!(10),
11 => shuffle!(11),
12 => shuffle!(12),
13 => shuffle!(13),
14 => shuffle!(14),
15 => shuffle!(15),
_ => shuffle!(16),
};
transmute(x)
let zero = _mm_set1_epi8(0).as_i8x16();
transmute(simd_shuffle16::<i8x16, i8x16>(
zero,
a.as_i8x16(),
[
mask(imm8, 0),
mask(imm8, 1),
mask(imm8, 2),
mask(imm8, 3),
mask(imm8, 4),
mask(imm8, 5),
mask(imm8, 6),
mask(imm8, 7),
mask(imm8, 8),
mask(imm8, 9),
mask(imm8, 10),
mask(imm8, 11),
mask(imm8, 12),
mask(imm8, 13),
mask(imm8, 14),
mask(imm8, 15),
],
))
}

/// Shifts `a` left by `imm8` bytes while shifting in zeros.
Expand All @@ -479,10 +462,11 @@ unsafe fn _mm_slli_si128_impl(a: __m128i, imm8: i32) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pslldq, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_bslli_si128(a: __m128i, imm8: i32) -> __m128i {
_mm_slli_si128_impl(a, imm8)
pub unsafe fn _mm_bslli_si128<const imm8: i32>(a: __m128i) -> __m128i {
static_assert_imm8!(imm8);
_mm_slli_si128_impl::<imm8>(a)
}

/// Shifts `a` right by `imm8` bytes while shifting in zeros.
Expand All @@ -491,10 +475,11 @@ pub unsafe fn _mm_bslli_si128(a: __m128i, imm8: i32) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(psrldq, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_bsrli_si128(a: __m128i, imm8: i32) -> __m128i {
_mm_srli_si128_impl(a, imm8)
pub unsafe fn _mm_bsrli_si128<const imm8: i32>(a: __m128i) -> __m128i {
static_assert_imm8!(imm8);
_mm_srli_si128_impl::<imm8>(a)
}

/// Shifts packed 16-bit integers in `a` left by `imm8` while shifting in zeros.
Expand Down Expand Up @@ -630,64 +615,48 @@ pub unsafe fn _mm_sra_epi32(a: __m128i, count: __m128i) -> __m128i {
#[inline]
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(psrldq, imm8 = 1))]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srli_si128(a: __m128i, imm8: i32) -> __m128i {
_mm_srli_si128_impl(a, imm8)
pub unsafe fn _mm_srli_si128<const imm8: i32>(a: __m128i) -> __m128i {
static_assert_imm8!(imm8);
_mm_srli_si128_impl::<imm8>(a)
}

/// Implementation detail: converts the immediate argument of the
/// `_mm_srli_si128` intrinsic into a compile-time constant.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn _mm_srli_si128_impl(a: __m128i, imm8: i32) -> __m128i {
let (zero, imm8) = (_mm_set1_epi8(0).as_i8x16(), imm8 as u32);
let a = a.as_i8x16();
macro_rules! shuffle {
($shift:expr) => {
simd_shuffle16(
a,
zero,
[
0 + $shift,
1 + $shift,
2 + $shift,
3 + $shift,
4 + $shift,
5 + $shift,
6 + $shift,
7 + $shift,
8 + $shift,
9 + $shift,
10 + $shift,
11 + $shift,
12 + $shift,
13 + $shift,
14 + $shift,
15 + $shift,
],
)
};
unsafe fn _mm_srli_si128_impl<const imm8: i32>(a: __m128i) -> __m128i {
const fn mask(shift: i32, i: u32) -> u32 {
if (shift as u32) > 15 {
i + 16
} else {
i + (shift as u32)
}
}
let x: i8x16 = match imm8 {
0 => shuffle!(0),
1 => shuffle!(1),
2 => shuffle!(2),
3 => shuffle!(3),
4 => shuffle!(4),
5 => shuffle!(5),
6 => shuffle!(6),
7 => shuffle!(7),
8 => shuffle!(8),
9 => shuffle!(9),
10 => shuffle!(10),
11 => shuffle!(11),
12 => shuffle!(12),
13 => shuffle!(13),
14 => shuffle!(14),
15 => shuffle!(15),
_ => shuffle!(16),
};
let zero = _mm_set1_epi8(0).as_i8x16();
let x: i8x16 = simd_shuffle16(
a.as_i8x16(),
zero,
[
mask(imm8, 0),
mask(imm8, 1),
mask(imm8, 2),
mask(imm8, 3),
mask(imm8, 4),
mask(imm8, 5),
mask(imm8, 6),
mask(imm8, 7),
mask(imm8, 8),
mask(imm8, 9),
mask(imm8, 10),
mask(imm8, 11),
mask(imm8, 12),
mask(imm8, 13),
mask(imm8, 14),
mask(imm8, 15),
],
);
transmute(x)
}

Expand Down Expand Up @@ -3375,37 +3344,23 @@ mod tests {
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_slli_si128(a, 1);
let r = _mm_slli_si128::<1>(a);
let e = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
assert_eq_m128i(r, e);

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_slli_si128(a, 15);
let r = _mm_slli_si128::<15>(a);
let e = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1);
assert_eq_m128i(r, e);

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_slli_si128(a, 16);
assert_eq_m128i(r, _mm_set1_epi8(0));

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_slli_si128(a, -1);
assert_eq_m128i(_mm_set1_epi8(0), r);

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_slli_si128(a, -0x80000000);
let r = _mm_slli_si128::<16>(a);
assert_eq_m128i(r, _mm_set1_epi8(0));
}

Expand Down Expand Up @@ -3496,7 +3451,7 @@ mod tests {
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_srli_si128(a, 1);
let r = _mm_srli_si128::<1>(a);
#[rustfmt::skip]
let e = _mm_setr_epi8(
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0,
Expand All @@ -3507,29 +3462,15 @@ mod tests {
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_srli_si128(a, 15);
let r = _mm_srli_si128::<15>(a);
let e = _mm_setr_epi8(16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
assert_eq_m128i(r, e);

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_srli_si128(a, 16);
assert_eq_m128i(r, _mm_set1_epi8(0));

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_srli_si128(a, -1);
assert_eq_m128i(r, _mm_set1_epi8(0));

#[rustfmt::skip]
let a = _mm_setr_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
);
let r = _mm_srli_si128(a, -0x80000000);
let r = _mm_srli_si128::<16>(a);
assert_eq_m128i(r, _mm_set1_epi8(0));
}

Expand Down