@@ -1,18 +1,21 @@
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "torch/csrc/lazy/core/ir.h"
 #include "torch/csrc/lazy/core/ir_builder.h"
+#include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ops/as_strided.h"
 #include "torch_xla/csrc/ops/cast.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/diagonal.h"
 #include "torch_xla/csrc/ops/expand.h"
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/ops.h"
+#include "torch_xla/csrc/tensor_util.h"
 
 namespace torch_xla {
 
-struct XLAIrBuilder : IrBuilder {
+struct XLAIrBuilder : torch::lazy::IrBuilder {
   torch::lazy::NodePtr MakeDeviceData(
-      const std::shared_ptr<BackendData>& data) const override {
+      const std::shared_ptr<torch::lazy::BackendData>& data) const override {
     return torch::lazy::MakeNode<DeviceData>(data);
   }
 
@@ -30,27 +33,31 @@ struct XLAIrBuilder : IrBuilder {
   torch::lazy::NodePtr MakeView(
       const torch::lazy::Value& input0,
       const std::vector<int64_t>& output_size) const override {
-    return torch::lazy::MakeNode<ViewOp>(input0, output_size);
+    // TODO(JackCaoG): use functionalization pass instead
+    return nullptr;
   }
   torch::lazy::NodePtr MakeCast(const torch::lazy::Value& input0,
                                 const at::ScalarType& dtype,
                                 const c10::optional<at::ScalarType>& stype =
                                     c10::nullopt) const override {
     return torch::lazy::MakeNode<Cast>(input0, dtype, stype);
   }
-  torch::lazy::NodePtr MakeTensorList(const OpList& inputs) const override {
+  torch::lazy::NodePtr MakeTensorList(
+      const torch::lazy::OpList& inputs) const override {
     // TODO(JackCaoG): implement tensorList IR. This is used by codegen.
     XLA_ERROR() << "Need to implement";
     return nullptr;
   }
   // Generic needs cleanup
   torch::lazy::NodePtr MakeGeneric(
-      const OpKind& op, const OpList& operands, const Shape& shape,
-      const size_t& num_outputs = 1,
-      const hash_t& hash_seed =
+      const torch::lazy::OpKind& op, const torch::lazy::OpList& operands,
+      const torch::lazy::Shape& shape, const size_t& num_outputs = 1,
+      const torch::lazy::hash_t& hash_seed =
           static_cast<uint32_t>(0x5a2d296e9)) const override {
-    return torch::lazy::MakeNode<Generic>(op, operands, shape, num_outputs,
-                                          hash_seed);
+    // TODO(JackCaoG): ltc generic op does not take lowering function
+    // return torch::lazy::MakeNode<Generic>(
+    //     op, operands, MakeXlaShapeFromLazyShape(shape, *GetDefaultDevice()),
+    //     num_outputs, hash_seed);
   }
 
   // We should use functionalization pass for view ops when migrating to the LTC.
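
For context, XLAIrBuilder exists so that device-agnostic lazy-tensor code can construct XLA-specific IR nodes through the torch::lazy::IrBuilder interface. A minimal usage sketch follows; it is not part of this commit, and the torch_xla header path and the free-standing helper are assumptions for illustration:

// Sketch only: build a Cast IR node through the XLA builder.
// Assumes "torch_xla/csrc/ir_builder.h" declares XLAIrBuilder (hypothetical
// path) and that `input` is a live torch::lazy::Value taken from an existing
// lazy tensor's IR.
#include "torch/csrc/lazy/core/ir.h"
#include "torch_xla/csrc/ir_builder.h"

// Returns a node that casts `input` to float; per the override above, this
// forwards to torch::lazy::MakeNode<Cast>(input, dtype, stype).
torch::lazy::NodePtr CastToFloat(const torch::lazy::Value& input) {
  torch_xla::XLAIrBuilder builder;
  return builder.MakeCast(input, at::ScalarType::Float);
}

In practice callers do not instantiate the builder directly: the backend exposes it through torch::lazy::BackendImplInterface::GetIrBuilder(), which is how shared lazy-tensor code (including codegen'd ops) creates backend-specific nodes without depending on XLA.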