Skip to content

Commit 5368e3e

Browse files
committed
[NFC][SYCL] fix LIT tests
- Remove sub_group/common_ocl.cpp because it duplicates sub_group/common.cpp and directly use OpenCL API that causes instability on some configurations. - Fix sub_group/shuffle*.cpp tests to align with shuffle_xor restrictions mentioned in spec: "If the result of the XOR is greater than max_sub_group_size then it is considered out-of-range"
1 parent a3c3425 commit 5368e3e

File tree

2 files changed

+30
-118
lines changed

2 files changed

+30
-118
lines changed

sycl/test/sub_group/common_ocl.cpp

Lines changed: 0 additions & 106 deletions
This file was deleted.

sycl/test/sub_group/shuffle.hpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
#include "helper.hpp"
1010
#include <CL/sycl.hpp>
11-
template <typename T, int N>
12-
class sycl_subgr;
11+
template <typename T, int N> class sycl_subgr;
1312

1413
using namespace cl::sycl;
1514

@@ -66,8 +65,9 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
6665
acc_up[NdItem.get_global_id()] = SG.shuffle_up(vwggid, sgid);
6766
/* Save GID+SGID */
6867
acc_down[NdItem.get_global_id()] = SG.shuffle_down(vwggid, sgid);
69-
/* Save GID XOR SGID */
70-
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor(vwggid, sgid);
68+
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
69+
acc_xor[NdItem.get_global_id()] =
70+
SG.shuffle_xor(vwggid, sgid % SG.get_max_local_range()[0]);
7171
});
7272
});
7373
auto acc = buf.template get_access<access::mode::read_write>();
@@ -81,12 +81,18 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
8181

8282
size_t sg_size = sgsizeacc[0];
8383
int SGid = 0;
84+
int SGLid = 0;
85+
int SGBeginGid = 0;
8486
for (int j = 0; j < G; j++) {
8587
if (j % L % sg_size == 0) {
8688
SGid++;
89+
SGLid = 0;
90+
SGBeginGid = j;
8791
}
8892
if (j % L == 0) {
8993
SGid = 0;
94+
SGLid = 0;
95+
SGBeginGid = j;
9096
}
9197
/*GID of middle element in every subgroup*/
9298
exit_if_not_equal_vec<T, N>(
@@ -115,17 +121,19 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
115121
exit_if_not_equal_vec(acc2_up[j], vec<T, N>(j - SGid + sg_size),
116122
"shuffle2_up");
117123
}
118-
/* GID XOR SGID */
119-
exit_if_not_equal_vec(acc_xor[j], vec<T, N>(j ^ SGid), "shuffle_xor");
124+
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
125+
exit_if_not_equal_vec(acc_xor[j],
126+
vec<T, N>(SGBeginGid + (SGLid ^ (SGid % sg_size))),
127+
"shuffle_xor");
128+
SGLid++;
120129
}
121130
} catch (exception e) {
122131
std::cout << "SYCL exception caught: " << e.what();
123132
exit(1);
124133
}
125134
}
126135

127-
template <typename T>
128-
void check(queue &Queue, size_t G = 240, size_t L = 60) {
136+
template <typename T> void check(queue &Queue, size_t G = 240, size_t L = 60) {
129137
try {
130138
nd_range<1> NdRange(G, L);
131139
buffer<T> buf2(G);
@@ -171,8 +179,9 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
171179
acc_up[NdItem.get_global_id()] = SG.shuffle_up<T>(wggid, sgid);
172180
/* Save GID+SGID */
173181
acc_down[NdItem.get_global_id()] = SG.shuffle_down<T>(wggid, sgid);
174-
/* Save GID XOR SGID */
175-
acc_xor[NdItem.get_global_id()] = SG.shuffle_xor<T>(wggid, sgid);
182+
/* Save GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
183+
acc_xor[NdItem.get_global_id()] =
184+
SG.shuffle_xor<T>(wggid, sgid % SG.get_max_local_range()[0]);
176185
});
177186
});
178187
auto acc = buf.template get_access<access::mode::read_write>();
@@ -186,13 +195,20 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
186195

187196
size_t sg_size = sgsizeacc[0];
188197
int SGid = 0;
198+
int SGLid = 0;
199+
int SGBeginGid = 0;
189200
for (int j = 0; j < G; j++) {
190201
if (j % L % sg_size == 0) {
191202
SGid++;
203+
SGLid = 0;
204+
SGBeginGid = j;
192205
}
193206
if (j % L == 0) {
194207
SGid = 0;
208+
SGLid = 0;
209+
SGBeginGid = j;
195210
}
211+
196212
/*GID of middle element in every subgroup*/
197213
exit_if_not_equal<T>(acc[j], j / L * L + SGid * sg_size + sg_size / 2,
198214
"shuffle");
@@ -215,8 +231,10 @@ void check(queue &Queue, size_t G = 240, size_t L = 60) {
215231
if (j % L - SGid + sg_size < L) /* Do not go out LG*/
216232
exit_if_not_equal<T>(acc2_up[j], j - SGid + sg_size, "shuffle2_up");
217233
}
218-
/* GID XOR SGID */
219-
exit_if_not_equal<T>(acc_xor[j], j ^ SGid, "shuffle_xor");
234+
/* Value GID with SGLID = ( SGLID XOR SGID ) % SGMaxSize */
235+
exit_if_not_equal<T>(acc_xor[j], SGBeginGid + (SGLid ^ (SGid % sg_size)),
236+
"shuffle_xor");
237+
SGLid++;
220238
}
221239
} catch (exception e) {
222240
std::cout << "SYCL exception caught: " << e.what();

0 commit comments

Comments
 (0)