Skip to content

Commit 75ac08b

Browse files
authored
Add XLABackendIntf test (#3697)
* Add tensor transfer test * Add scalar transfer test * Add e2e test for XLA backend interface * Call initXlaBackend in cppTest
1 parent b2f9d01 commit 75ac08b

File tree

4 files changed

+95
-6
lines changed

4 files changed

+95
-6
lines changed

test/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ set(TORCH_XLA_TEST_SOURCES
5858
test_tensor.cpp
5959
test_xla_util_cache.cpp
6060
torch_xla_test.cpp
61+
test_xla_backend_intf.cpp
6162
)
6263

6364
add_executable(test_ptxla ${TORCH_XLA_TEST_SOURCES})

test/cpp/test_xla_backend_intf.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <vector>
2+
3+
#include "cpp_test_util.h"
4+
#include "torch_xla/csrc/tensor_util.h"
5+
6+
namespace torch_xla {
7+
namespace cpp_test {
8+
9+
// Round-trips a random uint8 tensor through the backend transfer interface
// and checks that the tensor read back matches the input on every device.
TEST(XLABackendTest, TestTensorTransfer) {
  torch::lazy::BackendImplInterface* backend = GetXlaBackendImpl();
  at::Tensor source = at::randint(std::numeric_limits<uint8_t>::min(),
                                  std::numeric_limits<uint8_t>::max(), {2, 2},
                                  at::TensorOptions(at::kByte));
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::Shape lazy_shape(source.scalar_type(), source.sizes());
    torch::lazy::BackendDataPtr handle =
        backend->MakeComputationDataFromTensor(source, lazy_shape, device);
    at::Tensor round_tripped =
        backend->MakeTensorFromComputationData(handle, at::kByte);
    AllClose(source, round_tripped);
  });
}
21+
22+
// Transfers an integer scalar to the device and reads it back as a
// zero-dimensional byte tensor whose value must equal the scalar.
TEST(XLABackendTest, TestScalarTransfer) {
  torch::lazy::BackendImplInterface* backend = GetXlaBackendImpl();
  at::Scalar value(int64_t(1));
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::BackendDataPtr handle =
        backend->MakeComputationDataFromScalar(value, device);
    at::Tensor round_tripped =
        backend->MakeTensorFromComputationData(handle, at::kByte);
    AllClose(at::ones({}, at::TensorOptions(at::kByte)), round_tripped);
  });
}
32+
33+
// CreateDataPlaceholder should yield a handle that carries the requested
// device and the device-layout version of the requested shape.
TEST(XLABackendTest, TestPlaceholder) {
  torch::lazy::BackendImplInterface* backend = GetXlaBackendImpl();
  torch::lazy::Shape shape(at::kFloat, {10, 10});
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    torch::lazy::BackendDataPtr placeholder =
        backend->CreateDataPlaceholder(device, shape);
    xla::ComputationClient::DataPtr xla_data = UnwrapXlaData(placeholder);
    EXPECT_EQ(xla_data->device(), device.toString());
    EXPECT_EQ(xla_data->shape(), MakeXlaShapeFromLazyShape(shape, device));
  });
}
45+
46+
// Builds a two-parameter XLA computation that returns the elementwise sum
// of its inputs, each of the given shape.
xla::XlaComputation CreateAddComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("AddComputation");
  xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
  xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y");
  xla::XlaOp sum = xla::Add(x, y);
  // Pass the root op explicitly rather than relying on Build()'s implicit
  // "last op created" root; this also avoids `sum` being an unused variable.
  return ConsumeValue(builder.Build(sum));
}
53+
54+
// End-to-end flow: compile an add computation, upload two all-ones
// tensors, execute, and verify the single output equals their sum.
TEST(XLABackendTest, TestE2E) {
  torch::lazy::BackendImplInterface* backend = GetXlaBackendImpl();
  xla::Shape input_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {8, 8});
  at::Tensor one = at::ones({8, 8}, at::TensorOptions(at::kFloat));
  std::vector<at::Tensor> inputs = {one, one};

  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    // TODO(JackCaoG): device is missing in instance, use CurrentDevice for
    // now (mirrors the backend Compile implementation).
    torch::lazy::ComputationPtr computation =
        std::make_shared<torch_xla::Computation>(
            "test", CreateAddComputation(input_shape), device);
    std::vector<torch::lazy::ComputationPtr> compiled =
        backend->Compile({computation});
    EXPECT_EQ(compiled.size(), 1);

    // Upload both operands to the device.
    std::vector<torch::lazy::BackendDataPtr> arguments;
    for (const at::Tensor& tensor : inputs) {
      arguments.push_back(backend->MakeComputationDataFromTensor(
          tensor, torch::lazy::Shape(tensor.scalar_type(), tensor.sizes()),
          device));
    }
    std::vector<torch::lazy::BackendDataPtr> outputs =
        backend->ExecuteComputation(compiled[0], arguments, device);
    EXPECT_EQ(outputs.size(), 1);
    at::Tensor result =
        backend->MakeTensorFromComputationData(outputs[0], at::kFloat);
    AllClose(one + one, result);
  });
}
84+
85+
} // namespace cpp_test
86+
} // namespace torch_xla

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
105105
// c10::ArrayRef<torch::lazy::Node*> instead of
106106
// c10::ArrayRef<const torch::lazy::Node*> since c10::ArrayRef already
107107
// provided const for its member.
108+
XLA_ERROR() << "Need to handle post_order";
108109
return std::make_unique<LoweringContext>(name, device);
109110
}
110111

@@ -131,15 +132,16 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
131132
std::vector<torch::lazy::ComputationPtr> res;
132133
std::vector<xla::ComputationClient::CompileInstance> compile_instances;
133134
torch::lazy::BackendDevice current_device = GetCurrentDevice();
135+
std::vector<xla::Shape> output_shapes;
134136

135137
for (const torch::lazy::ComputationPtr instance : instances) {
136138
// TODO(JackCaoG): device is missing in instance, use CurrentDevice for
137139
// now
138140
const Computation* torch_xla_computation =
139141
dynamic_cast<Computation*>(instance.get());
140-
xla::Shape shape = MakeShapeWithDeviceLayout(
142+
output_shapes.push_back(MakeShapeWithDeviceLayout(
141143
torch_xla_computation->program_shape().result(),
142-
static_cast<XlaDeviceType>(current_device.type()));
144+
static_cast<XlaDeviceType>(current_device.type())));
143145

144146
// Call GetCompilationDevices and passes all device here if needed.
145147
// Currently on TPU we always have 1 replica per device and one process
@@ -152,7 +154,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
152154
compile_instances.push_back(xla::ComputationClient::CompileInstance(
153155
torch_xla_computation->move_computation(),
154156
torch_xla_computation->get_device_string(),
155-
{current_device.toString()}, &shape));
157+
{current_device.toString()}, &output_shapes.back()));
156158
}
157159
std::vector<std::shared_ptr<xla::ComputationClient::Computation>>
158160
client_computations = xla::ComputationClient::Get()->Compile(
@@ -238,4 +240,4 @@ void InitXlaBackend() {
238240
std::make_unique<torch::lazy::BackendRegistrar>(GetXlaBackendImpl());
239241
};
240242

241-
} // namespace torch_xla
243+
} // namespace torch_xla

torch_xla/csrc/xla_backend_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ class XLAData : public torch::lazy::BackendData {
3939
xla::ComputationClient::DataPtr xla_data_;
4040
};
4141

42-
// torch::lazy::BackendImplInterface* GetXlaBackendImpl();
42+
torch::lazy::BackendImplInterface* GetXlaBackendImpl();
4343

4444
void InitXlaBackend();
4545

46-
} // namespace torch_xla
46+
} // namespace torch_xla

0 commit comments

Comments
 (0)