Skip to content

Commit fd4151b

Browse files
authored
Implement ir builder (#3696)
* Implement ir builder * typo * add ir_builder header * Fix IrBuilder
1 parent 2cd5f1e commit fd4151b

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

torch_xla/csrc/ir_builder.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
22
#include "torch/csrc/lazy/core/ir.h"
33
#include "torch/csrc/lazy/core/ir_builder.h"
4+
#include "torch_xla/csrc/device.h"
45
#include "torch_xla/csrc/ops/as_strided.h"
56
#include "torch_xla/csrc/ops/cast.h"
7+
#include "torch_xla/csrc/ops/device_data.h"
68
#include "torch_xla/csrc/ops/diagonal.h"
79
#include "torch_xla/csrc/ops/expand.h"
810
#include "torch_xla/csrc/ops/generic.h"
911
#include "torch_xla/csrc/ops/ops.h"
12+
#include "torch_xla/csrc/tensor_util.h"
1013

1114
namespace torch_xla {
1215

13-
struct XLAIrBuilder : IrBuilder {
16+
struct XLAIrBuilder : torch::lazy::IrBuilder {
1417
torch::lazy::NodePtr MakeDeviceData(
15-
const std::shared_ptr<BackendData>& data) const override {
18+
const std::shared_ptr<torch::lazy::BackendData>& data) const override {
1619
return torch::lazy::MakeNode<DeviceData>(data);
1720
}
1821

@@ -30,27 +33,31 @@ struct XLAIrBuilder : IrBuilder {
3033
torch::lazy::NodePtr MakeView(
3134
const torch::lazy::Value& input0,
3235
const std::vector<int64_t>& output_size) const override {
33-
return torch::lazy::MakeNode<ViewOp>(input0, output_size);
36+
// TODO(JackCAoG): use functionization pass instead
37+
return nullptr;
3438
}
3539
torch::lazy::NodePtr MakeCast(const torch::lazy::Value& input0,
3640
const at::ScalarType& dtype,
3741
const c10::optional<at::ScalarType>& stype =
3842
c10::nullopt) const override {
3943
return torch::lazy::MakeNode<Cast>(input0, dtype, stype);
4044
}
41-
torch::lazy::NodePtr MakeTensorList(const OpList& inputs) const override {
45+
torch::lazy::NodePtr MakeTensorList(
46+
const torch::lazy::OpList& inputs) const override {
4247
// TODO(JackCaoG): implement tensorList IR. This is used by codegen.
4348
XLA_ERROR() << "Need to implement";
4449
return nullptr;
4550
}
4651
// Generic needs cleanup
4752
torch::lazy::NodePtr MakeGeneric(
48-
const OpKind& op, const OpList& operands, const Shape& shape,
49-
const size_t& num_outputs = 1,
50-
const hash_t& hash_seed =
53+
const torch::lazy::OpKind& op, const torch::lazy::OpList& operands,
54+
const torch::lazy::Shape& shape, const size_t& num_outputs = 1,
55+
const torch::lazy::hash_t& hash_seed =
5156
static_cast<uint32_t>(0x5a2d296e9)) const override {
52-
return torch::lazy::MakeNode<Generic>(op, operands, shape, num_outputs,
53-
hash_seed);
57+
// TODO(JackCaoG): ltc generic op does not take lowering function
58+
// return torch::lazy::MakeNode<Generic>(
59+
// op, operands, MakeXlaShapeFromLazyShape(shape, *GetDefaultDevice()),
60+
// num_outputs, hash_seed);
5461
}
5562

5663
// We should use functionization pass for view ops when migrating to the LTC.

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "torch_xla/csrc/aten_xla_bridge.h"
77
#include "torch_xla/csrc/computation.h"
88
#include "torch_xla/csrc/device.h"
9+
#include "torch_xla/csrc/ir_builder.h"
910
#include "torch_xla/csrc/lowering_context.h"
1011
#include "torch_xla/csrc/ops/device_data.h"
1112
#include "torch_xla/csrc/tensor.h"
@@ -31,8 +32,8 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
3132
}
3233

3334
const torch::lazy::IrBuilder* GetIrBuilder() const override {
34-
XLA_ERROR() << "Not implemented yet";
35-
return 0;
35+
static const torch::lazy::IrBuilder* builder = new XLAIrBuilder();
36+
return builder;
3637
}
3738

3839
torch::lazy::BackendDataPtr MakeComputationDataFromTensor(

0 commit comments

Comments
 (0)