@@ -84,15 +84,15 @@ torch::lazy::hash_t GetOperandHashes(const OpList& operands,
8484
8585} // namespace
8686
87- const xla::Shape& XlaValue::xla_shape () const {
88- XlaNode* casted = dynamic_cast <XlaNode*>(node.get ());
89- return casted->xla_shape (index);
90- }
87+ // const xla::Shape& XlaValue::xla_shape() const {
88+ // XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
89+ // return casted->xla_shape(index);
90+ // }
9191
92- const xla::Shape& XlaValue::xla_node_shape () const {
93- XlaNode* casted = dynamic_cast <XlaNode*>(node.get ());
94- return casted->xla_shape ();
95- }
92+ // const xla::Shape& XlaValue::xla_node_shape() const {
93+ // XlaNode* casted = dynamic_cast<XlaNode*>(node.get());
94+ // return casted->xla_shape();
95+ // }
9696
9797XlaNode::XlaNode (torch::lazy::OpKind op, OpList operands,
9898 std::vector<torch::lazy::Shape>&& shapes, xla::Shape xla_shape,
@@ -102,7 +102,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
102102 node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
103103 dag_hash_(GetOperandHashes(operands, node_hash_)) {
104104 // We have to call AddOperand here since upstream OpList is
105- // an array of torch::lazy::Value while we uses XlaValue .
105+ // an array of torch::lazy::Value while we uses torch::lazy::Value .
106106 for (auto & operand : operands) {
107107 AddOperand (operand.node , operand.index );
108108 }
@@ -116,7 +116,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
116116 node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
117117 dag_hash_(GetOperandHashes(operands, node_hash_)) {
118118 // We have to call AddOperand here since upstream OpList is
119- // an array of torch::lazy::Value while we uses XlaValue .
119+ // an array of torch::lazy::Value while we uses torch::lazy::Value .
120120 for (auto & operand : operands) {
121121 AddOperand (operand.node , operand.index );
122122 }
@@ -131,7 +131,7 @@ XlaNode::XlaNode(torch::lazy::OpKind op, OpList operands,
131131 node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
132132 dag_hash_(GetOperandHashes(operands, node_hash_)) {
133133 // We have to call AddOperand here since upstream OpList is
134- // an array of torch::lazy::Value while we uses XlaValue .
134+ // an array of torch::lazy::Value while we uses torch::lazy::Value .
135135 for (auto & operand : operands) {
136136 AddOperand (operand.node , operand.index );
137137 }
@@ -233,4 +233,9 @@ ScopePusher::~ScopePusher() { PopScope(); }
233233
234234void ScopePusher::ResetScopes () { ResetScopeContext (); }
235235
236+ const xla::Shape& GetXlaShape (const torch::lazy::Value& value) {
237+ XlaNode* casted = dynamic_cast <XlaNode*>(value.node .get ());
238+ return casted->xla_shape ();
239+ }
240+
236241} // namespace torch_xla
0 commit comments