Skip to content

Commit 83d4b02

Browse files
committed
Replace XlaValue with upstream lazy::Value
1 parent 041ebf9 commit 83d4b02

File tree

278 files changed

+1877
-1565
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

278 files changed

+1877
-1565
lines changed

scripts/gen_lazy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,5 @@ def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
7979
create_tensor="XLATensor::Create",
8080
create_aten_from_ltc_tensor="torch_xla::bridge::AtenFromXlaTensor",
8181
tuple_aten_from_ltc_tensors="torch::lazy::TupleAtenFromLtcTensors",
82-
lazy_value_class="torch_xla::XlaValue",
8382
lazy_tensor_ptr="torch_xla::XLATensorPtr",
8483
get_device_fn="torch_xla::bridge::GetXlaDevice")

torch_xla/csrc/debug_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
5555
const std::vector<size_t>* indices,
5656
GraphFormat format) {
5757
std::vector<const torch::lazy::Node*> root_nodes;
58-
std::vector<XlaValue> root_values;
58+
std::vector<torch::lazy::Value> root_values;
5959
std::vector<torch::lazy::hash_t> root_hashes;
6060
xla::util::Unique<torch::lazy::BackendDevice> unique_device;
6161
if (indices != nullptr) {
6262
for (auto index : *indices) {
6363
const XLATensor& tensor = tensors[index];
64-
XlaValue ir_value = tensor.CurrentIrValue();
64+
torch::lazy::Value ir_value = tensor.CurrentIrValue();
6565
if (ir_value) {
6666
root_nodes.push_back(ir_value.node.get());
6767
root_hashes.push_back(ir_value.hash());
@@ -71,7 +71,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
7171
}
7272
} else {
7373
for (auto& tensor : tensors) {
74-
XlaValue ir_value = tensor.CurrentIrValue();
74+
torch::lazy::Value ir_value = tensor.CurrentIrValue();
7575
if (ir_value) {
7676
root_nodes.push_back(ir_value.node.get());
7777
root_hashes.push_back(ir_value.hash());

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 78 additions & 73 deletions
Large diffs are not rendered by default.

torch_xla/csrc/ir.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ torch::lazy::hash_t GetOperandHashes(const OpList& operands,
8484

8585
} // namespace
8686

87-
const xla::Shape& XlaValue::xla_shape() const {
88-
XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
89-
return casted->xla_shape(index);
90-
}
87+
// const xla::Shape& XlaValue::xla_shape() const {
88+
// XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
89+
// return casted->xla_shape(index);
90+
// }
9191

92-
const xla::Shape& XlaValue::xla_node_shape() const {
93-
XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
94-
return casted->xla_shape();
95-
}
92+
// const xla::Shape& XlaValue::xla_node_shape() const {
93+
// XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
94+
// return casted->xla_shape();
95+
// }
9696

9797
XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
9898
std::vector<torch::lazy::Shape>&& shapes, xla::Shape xla_shape,
@@ -102,7 +102,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
102102
node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
103103
dag_hash_(GetOperandHashes(operands, node_hash_)) {
104104
// We have to call AddOperand here since upstream OpList is
105-
// an array of torch::lazy::Value while we uses XlaValue.
105+
// an array of torch::lazy::Value while we uses torch::lazy::Value.
106106
for (auto& operand : operands) {
107107
AddOperand(operand.node, operand.index);
108108
}
@@ -116,7 +116,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
116116
node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
117117
dag_hash_(GetOperandHashes(operands, node_hash_)) {
118118
// We have to call AddOperand here since upstream OpList is
119-
// an array of torch::lazy::Value while we uses XlaValue.
119+
// an array of torch::lazy::Value while we uses torch::lazy::Value.
120120
for (auto& operand : operands) {
121121
AddOperand(operand.node, operand.index);
122122
}
@@ -131,7 +131,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
131131
node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
132132
dag_hash_(GetOperandHashes(operands, node_hash_)) {
133133
// We have to call AddOperand here since upstream OpList is
134-
// an array of torch::lazy::Value while we uses XlaValue.
134+
// an array of torch::lazy::Value while we uses torch::lazy::Value.
135135
for (auto& operand : operands) {
136136
AddOperand(operand.node, operand.index);
137137
}
@@ -233,4 +233,9 @@ ScopePusher::~ScopePusher() { PopScope(); }
233233

234234
void ScopePusher::ResetScopes() { ResetScopeContext(); }
235235

236+
const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {
237+
XlaNode* casted = dynamic_cast<XlaNode*>(value.node.get());
238+
return casted->xla_shape();
239+
}
240+
236241
} // namespace torch_xla

torch_xla/csrc/ir.h

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,24 @@ using OutputMap =
3535
std::unordered_map<torch::lazy::Output, T, torch::lazy::Output::Hasher>;
3636

3737
// Represents an input/operand for a XlaNode object.
38-
struct XlaValue : public torch::lazy::Value {
39-
XlaValue() = default;
40-
XlaValue(torch::lazy::NodePtr node, size_t index = 0)
41-
: torch::lazy::Value(std::dynamic_pointer_cast<torch::lazy::Node>(node),
42-
index) {}
43-
44-
// Retrieves the shape of this value. If the IR XlaNode generating the value
45-
// is a multi-output node, the shape returned by this API will not be the full
46-
// tuple shape, but only the shape at index referred by this value.
47-
// To retrieve the full tuple shape in that case, use the node_shape() API.
48-
const xla::Shape& xla_shape() const;
49-
const xla::Shape& xla_node_shape() const;
50-
};
51-
52-
using OpList = absl::Span<const XlaValue>;
38+
// struct XlaValue : public torch::lazy::Value {
39+
// torch::lazy::Value() = default;
40+
// torch::lazy::Value(torch::lazy::NodePtr node, size_t index = 0)
41+
// :
42+
// torch::lazy::Value(std::dynamic_pointer_cast<torch::lazy::Node>(node),
43+
// index) {}
44+
45+
// // Retrieves the shape of this value. If the IR XlaNode generating the
46+
// value
47+
// // is a multi-output node, the shape returned by this API will not be the
48+
// full
49+
// // tuple shape, but only the shape at index referred by this value.
50+
// // To retrieve the full tuple shape in that case, use the node_shape() API.
51+
// // const xla::Shape& xla_shape() const;
52+
// // const xla::Shape& xla_node_shape() const;
53+
// };
54+
55+
using OpList = absl::Span<const torch::lazy::Value>;
5356

5457
// A node in the graph. Nodes for operations which requires extra data to be
5558
// stored for lowering, should inherit from this class and add operation
@@ -159,6 +162,8 @@ inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
159162
return stream;
160163
}
161164

165+
const xla::Shape& GetXlaShape(const torch::lazy::Value& value);
166+
162167
template <typename T>
163168
T* NodeCast(const torch::lazy::Node* node, torch::lazy::OpKind op) {
164169
if (op != node->op()) {

torch_xla/csrc/ir_dump_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ std::string DumpUtil::PostOrderToText(
244244
return ss.str();
245245
}
246246

247-
std::string DumpUtil::ToHlo(absl::Span<const XlaValue> values,
247+
std::string DumpUtil::ToHlo(absl::Span<const torch::lazy::Value> values,
248248
const torch::lazy::BackendDevice& device) {
249249
LoweringContext lowering_ctx("IrToHlo", device);
250250
for (auto& ir_value : values) {

torch_xla/csrc/ir_dump_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class DumpUtil {
2222
absl::Span<const torch::lazy::Node* const> post_order,
2323
absl::Span<const torch::lazy::Node* const> roots);
2424

25-
static std::string ToHlo(absl::Span<const XlaValue> values,
25+
static std::string ToHlo(absl::Span<const torch::lazy::Value> values,
2626
const torch::lazy::BackendDevice& device);
2727
};
2828

torch_xla/csrc/ir_util.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ std::vector<const torch::lazy::Node*> Util::ComputePostOrder(
5858
return ComputePostOrder(nodes, &emap);
5959
}
6060

61-
std::vector<XlaValue> Util::Clone(
62-
absl::Span<const XlaValue> values,
61+
std::vector<torch::lazy::Value> Util::Clone(
62+
absl::Span<const torch::lazy::Value> values,
6363
absl::Span<const torch::lazy::Node* const> post_order) {
6464
std::unordered_map<const torch::lazy::Node*, torch::lazy::NodePtr> clone_map;
6565
for (auto node : post_order) {
6666
if (clone_map.count(node) > 0) {
6767
continue;
6868
}
69-
std::vector<XlaValue> inputs;
69+
std::vector<torch::lazy::Value> inputs;
7070
for (auto& output : node->operands()) {
7171
auto it = clone_map.find(output.node);
7272
XLA_CHECK(it != clone_map.end())
@@ -77,7 +77,7 @@ std::vector<XlaValue> Util::Clone(
7777
clone_map[node] = casted->Clone(inputs);
7878
}
7979

80-
std::vector<XlaValue> cloned;
80+
std::vector<torch::lazy::Value> cloned;
8181
for (auto& value : values) {
8282
auto it = clone_map.find(value.node.get());
8383
XLA_CHECK(it != clone_map.end()) << "Bad post-order: " << value->ToString();
@@ -86,7 +86,8 @@ std::vector<XlaValue> Util::Clone(
8686
return cloned;
8787
}
8888

89-
std::vector<XlaValue> Util::Clone(absl::Span<const XlaValue> values) {
89+
std::vector<torch::lazy::Value> Util::Clone(
90+
absl::Span<const torch::lazy::Value> values) {
9091
std::vector<const torch::lazy::Node*> nodes;
9192
for (auto& value : values) {
9293
nodes.push_back(value.node.get());

torch_xla/csrc/ir_util.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ class Util {
3131
absl::Span<const torch::lazy::Node* const> nodes);
3232

3333
// Clones the IR graph whose roots are passed in the values parameter.
34-
static std::vector<XlaValue> Clone(absl::Span<const XlaValue> values);
34+
static std::vector<torch::lazy::Value> Clone(
35+
absl::Span<const torch::lazy::Value> values);
3536

3637
// Same as the above, but the post-order is passed as parameter.
37-
static std::vector<XlaValue> Clone(
38-
absl::Span<const XlaValue> values,
38+
static std::vector<torch::lazy::Value> Clone(
39+
absl::Span<const torch::lazy::Value> values,
3940
absl::Span<const torch::lazy::Node* const> post_order);
4041

4142
// Retrieves the number of nodes within the graph whose sink are passed in the

torch_xla/csrc/nms_op.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ NmsResult BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
275275
xla::XlaOp ones_included =
276276
xla::Select(included, xla::Broadcast(one_s32, {num_boxes}),
277277
xla::Broadcast(zero_s32, {num_boxes}));
278-
// num_valid is scalar. XlaValue should be bound by output_size.
278+
// num_valid is scalar. torch::lazy::Value should be bound by output_size.
279279
xla::XlaOp num_valid_total = xla::Reduce(
280280
ones_included,
281281
/*init_value=*/zero_s32,

0 commit comments

Comments (0)