Skip to content

Commit ec8a744

Browse files
authored
Replace XlaValue with upstream lazy::Value (#3572)
* Replace XlaValue with upstream lazy::Value * Update tests * Remove comments * Add index to shape getter
1 parent 82331cf commit ec8a744

File tree

282 files changed

+1887
-1596
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

282 files changed

+1887
-1596
lines changed

scripts/gen_lazy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,5 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
7979
create_tensor="XLATensor::Create",
8080
create_aten_from_ltc_tensor="torch_xla::bridge::AtenFromXlaTensor",
8181
tuple_aten_from_ltc_tensors="torch::lazy::TupleAtenFromLtcTensors",
82-
lazy_value_class="torch_xla::XlaValue",
8382
lazy_tensor_ptr="torch_xla::XLATensorPtr",
8483
get_device_fn="torch_xla::bridge::GetXlaDevice")

test/cpp/cpp_test_util.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,14 @@ std::string GetTensorHloGraph(at::Tensor tensor) {
250250
return DumpUtil::ToHlo({xtensor.GetIrValue()}, xtensor.GetDevice());
251251
}
252252

253-
XlaValue GetTensorIrValue(const at::Tensor& tensor,
254-
const torch::lazy::BackendDevice& device) {
253+
torch::lazy::Value GetTensorIrValue(const at::Tensor& tensor,
254+
const torch::lazy::BackendDevice& device) {
255255
xla::ComputationClient::DataPtr data = TensorToXlaData(tensor, device);
256256
return torch::lazy::MakeNode<DeviceData>(std::move(data));
257257
}
258258

259259
std::vector<xla::ComputationClient::DataPtr> Execute(
260-
absl::Span<const XlaValue> roots,
260+
absl::Span<const torch::lazy::Value> roots,
261261
const torch::lazy::BackendDevice& device) {
262262
LoweringContext lowering_ctx("Execute", device);
263263
for (auto node : roots) {
@@ -300,7 +300,7 @@ std::vector<at::Tensor> Fetch(
300300
}
301301

302302
std::vector<at::Tensor> ExecuteAndFetch(
303-
absl::Span<const XlaValue> roots,
303+
absl::Span<const torch::lazy::Value> roots,
304304
const torch::lazy::BackendDevice& device) {
305305
auto results = Execute(roots, device);
306306
return Fetch(results);

test/cpp/cpp_test_util.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,19 @@ std::string GetTensorDotGraph(at::Tensor tensor);
8585

8686
std::string GetTensorHloGraph(at::Tensor tensor);
8787

88-
XlaValue GetTensorIrValue(const at::Tensor& tensor,
89-
const torch::lazy::BackendDevice& device);
88+
torch::lazy::Value GetTensorIrValue(const at::Tensor& tensor,
89+
const torch::lazy::BackendDevice& device);
9090

9191
std::vector<xla::ComputationClient::DataPtr> Execute(
92-
absl::Span<const XlaValue> roots, const torch::lazy::BackendDevice& device);
92+
absl::Span<const torch::lazy::Value> roots,
93+
const torch::lazy::BackendDevice& device);
9394

9495
std::vector<at::Tensor> Fetch(
9596
absl::Span<const xla::ComputationClient::DataPtr> device_data);
9697

9798
std::vector<at::Tensor> ExecuteAndFetch(
98-
absl::Span<const XlaValue> roots, const torch::lazy::BackendDevice& device);
99+
absl::Span<const torch::lazy::Value> roots,
100+
const torch::lazy::BackendDevice& device);
99101

100102
void AssertBackward(const torch::Tensor& xla_output,
101103
const std::vector<torch::Tensor>& xla_inputs,

test/cpp/test_ir.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,24 @@ TEST(IrTest, TestScalarCreate) {
2020
TEST(IrTest, TestHash) {
2121
torch::lazy::NodePtr scalar1 = ScalarOp(1.0, xla::F32);
2222
torch::lazy::NodePtr scalar2 = ScalarOp(2.0, xla::F32);
23-
XlaValue add1 = XlaValue(scalar1, 0) + XlaValue(scalar2, 0);
23+
torch::lazy::Value add1 =
24+
torch::lazy::Value(scalar1, 0) + torch::lazy::Value(scalar2, 0);
2425

2526
torch::lazy::NodePtr scalar3 = ScalarOp(1.0, xla::F32);
2627
torch::lazy::NodePtr scalar4 = ScalarOp(2.0, xla::F32);
27-
XlaValue add2 = XlaValue(scalar3, 0) + XlaValue(scalar4, 0);
28+
torch::lazy::Value add2 =
29+
torch::lazy::Value(scalar3, 0) + torch::lazy::Value(scalar4, 0);
2830

2931
torch::lazy::NodePtr scalar5 = ScalarOp(11.0, xla::F32);
3032
torch::lazy::NodePtr scalar6 = ScalarOp(22.0, xla::F32);
31-
XlaValue add3 = XlaValue(scalar5, 0) + XlaValue(scalar6, 0);
33+
torch::lazy::Value add3 =
34+
torch::lazy::Value(scalar5, 0) + torch::lazy::Value(scalar6, 0);
3235

3336
EXPECT_EQ(add1->hash(), add2->hash());
3437
EXPECT_NE(add1->hash(), add3->hash());
3538

36-
XlaValue sub = XlaValue(scalar1, 0) - XlaValue(scalar2, 0);
39+
torch::lazy::Value sub =
40+
torch::lazy::Value(scalar1, 0) - torch::lazy::Value(scalar2, 0);
3741

3842
EXPECT_NE(add1->hash(), sub->hash());
3943
}
@@ -43,9 +47,10 @@ TEST(IrTest, TestSelectUnselect) {
4347
at::Tensor a =
4448
at::rand({4, 16, 3}, at::TensorOptions(at::kFloat)).abs() + 1.0;
4549

46-
XlaValue v_a = GetTensorIrValue(a, device);
47-
XlaValue v_s = torch::lazy::MakeNode<Select>(v_a, /*dim=*/1, /*start=*/3,
48-
/*end=*/14, /*stride=*/3);
50+
torch::lazy::Value v_a = GetTensorIrValue(a, device);
51+
torch::lazy::Value v_s =
52+
torch::lazy::MakeNode<Select>(v_a, /*dim=*/1, /*start=*/3,
53+
/*end=*/14, /*stride=*/3);
4954

5055
auto results = ExecuteAndFetch({v_s}, device);
5156
at::Tensor b =
@@ -54,8 +59,8 @@ TEST(IrTest, TestSelectUnselect) {
5459

5560
// Paste zeros back into the selected view.
5661
at::Tensor z = at::zeros_like(b);
57-
XlaValue v_z = GetTensorIrValue(z, device);
58-
XlaValue v_u =
62+
torch::lazy::Value v_z = GetTensorIrValue(z, device);
63+
torch::lazy::Value v_u =
5964
torch::lazy::MakeNode<Unselect>(v_a, v_z, /*dim=*/1, /*start=*/3,
6065
/*end=*/14, /*stride=*/3);
6166
results = ExecuteAndFetch({v_u}, device);

test/cpp/test_op_by_op_executor.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ TEST(OpByOpExecutorTest, TestSimpleAdd) {
1717
at::Tensor b = at::rand({4, 16, 3}, at::TensorOptions(at::kFloat));
1818
at::Tensor c = a + b;
1919

20-
XlaValue v_a = GetTensorIrValue(a, device);
21-
XlaValue v_b = GetTensorIrValue(b, device);
22-
XlaValue v_c = v_a + v_b;
20+
torch::lazy::Value v_a = GetTensorIrValue(a, device);
21+
torch::lazy::Value v_b = GetTensorIrValue(b, device);
22+
torch::lazy::Value v_c = v_a + v_b;
2323

2424
auto results_data =
2525
OpByOpExecutor::Get()->Execute({v_c}, device.toString(), {});
@@ -35,10 +35,10 @@ TEST(OpByOpExecutorTest, TestStack) {
3535
at::Tensor b = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat));
3636
at::Tensor c = at::stack({a, b}, 1);
3737

38-
XlaValue v_a = GetTensorIrValue(a, device);
39-
XlaValue v_b = GetTensorIrValue(b, device);
40-
XlaValue v_c =
41-
torch::lazy::MakeNode<Stack>(std::vector<XlaValue>({v_a, v_b}), 1);
38+
torch::lazy::Value v_a = GetTensorIrValue(a, device);
39+
torch::lazy::Value v_b = GetTensorIrValue(b, device);
40+
torch::lazy::Value v_c = torch::lazy::MakeNode<Stack>(
41+
std::vector<torch::lazy::Value>({v_a, v_b}), 1);
4242

4343
auto results_data =
4444
OpByOpExecutor::Get()->Execute({v_c}, device.toString(), {});
@@ -54,10 +54,10 @@ TEST(OpByOpExecutorTest, TestAsyncStack) {
5454
at::Tensor b = at::rand({4, 8, 3}, at::TensorOptions(at::kFloat));
5555
at::Tensor c = at::stack({a, b}, 1);
5656

57-
XlaValue v_a = GetTensorIrValue(a, device);
58-
XlaValue v_b = GetTensorIrValue(b, device);
59-
XlaValue v_c =
60-
torch::lazy::MakeNode<Stack>(std::vector<XlaValue>({v_a, v_b}), 1);
57+
torch::lazy::Value v_a = GetTensorIrValue(a, device);
58+
torch::lazy::Value v_b = GetTensorIrValue(b, device);
59+
torch::lazy::Value v_c = torch::lazy::MakeNode<Stack>(
60+
std::vector<torch::lazy::Value>({v_a, v_b}), 1);
6161

6262
auto async =
6363
OpByOpExecutor::Get()->ExecuteAsync({v_c}, device.toString(), {});

torch_xla/csrc/debug_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
5555
const std::vector<size_t>* indices,
5656
GraphFormat format) {
5757
std::vector<const torch::lazy::Node*> root_nodes;
58-
std::vector<XlaValue> root_values;
58+
std::vector<torch::lazy::Value> root_values;
5959
std::vector<torch::lazy::hash_t> root_hashes;
6060
xla::util::Unique<torch::lazy::BackendDevice> unique_device;
6161
if (indices != nullptr) {
6262
for (auto index : *indices) {
6363
const XLATensor& tensor = tensors[index];
64-
XlaValue ir_value = tensor.CurrentIrValue();
64+
torch::lazy::Value ir_value = tensor.CurrentIrValue();
6565
if (ir_value) {
6666
root_nodes.push_back(ir_value.node.get());
6767
root_hashes.push_back(ir_value.hash());
@@ -71,7 +71,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
7171
}
7272
} else {
7373
for (auto& tensor : tensors) {
74-
XlaValue ir_value = tensor.CurrentIrValue();
74+
torch::lazy::Value ir_value = tensor.CurrentIrValue();
7575
if (ir_value) {
7676
root_nodes.push_back(ir_value.node.get());
7777
root_hashes.push_back(ir_value.hash());

0 commit comments

Comments
 (0)