Skip to content

Commit 9fa40f6

Browse files
authored
Default device backendintf (#3695)
* Implement default device type and ordinal for backend interface * torch_pin * lazy init defailt device * init xlaBackend in test * Delete .torch_pin
1 parent fd4151b commit 9fa40f6

File tree

2 files changed

+51
-17
lines changed

2 files changed

+51
-17
lines changed

test/cpp/torch_xla_test.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,13 @@
88
#include "torch_xla/csrc/device.h"
99
#include "torch_xla/csrc/helpers.h"
1010
#include "torch_xla/csrc/tensor.h"
11-
12-
namespace at {
13-
// This function is defined in the codegenerated RegisterDispatchKey.cpp file.
14-
extern TORCH_API void RegisterXLAXLANativeFunctions();
15-
extern TORCH_API void RegisterXLAAutogradXLANativeFunctions();
16-
} // namespace at
11+
#include "torch_xla/csrc/xla_backend_impl.h"
1712

1813
namespace torch_xla {
1914
namespace cpp_test {
2015

2116
void XlaTest::SetUp() {
22-
at::RegisterXLAXLANativeFunctions();
23-
at::RegisterXLAAutogradXLANativeFunctions();
17+
InitXlaBackend();
2418
at::manual_seed(42);
2519
XLATensor::SetRngSeed(GetCurrentDevice(), 42);
2620
start_msnap_ = absl::make_unique<MetricsSnapshot>();

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@ namespace torch_xla {
2222
class XlaBackendImpl : public torch::lazy::BackendImplInterface {
2323
public:
2424
XlaBackendImpl() {}
25+
26+
bool InitDefaultDeviceType() {
27+
if (!default_device_type_inited_) {
28+
// GetDefaultDevice will trigger the runtime device init, should
29+
// not do it during class init time.
30+
torch::lazy::BackendDevice default_device = *GetDefaultDevice();
31+
default_device_type_ = std::make_shared<DeviceType>(
32+
static_cast<XlaDeviceType>(default_device.type()));
33+
default_device_type_inited_ = true;
34+
}
35+
return true;
36+
}
37+
38+
bool InitDefaultDeviceOrdinal() {
39+
if (!default_device_ordinal_inited_) {
40+
// GetDefaultDevice will trigger the runtime device init, should
41+
// not do it during class init time.
42+
torch::lazy::BackendDevice default_device = *GetDefaultDevice();
43+
default_device_ordinal_ = default_device.ordinal();
44+
default_device_ordinal_inited_ = true;
45+
}
46+
return true;
47+
}
48+
2549
void PrepareToExit() const override { XLA_ERROR() << "Not implemented yet"; }
2650

2751
void SetRngSeed(size_t seed) const override {
@@ -150,18 +174,31 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
150174

151175
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType()
152176
const override {
153-
// want to reuse the getDefualtDeviceTypelogic
154-
torch::lazy::BackendDevice default_device = *GetDefaultDevice();
155-
return std::make_shared<DeviceType>(
156-
static_cast<XlaDeviceType>(default_device.type()));
177+
// lazily init default device type, we only need to init once.
178+
static bool init =
179+
const_cast<XlaBackendImpl*>(this)->InitDefaultDeviceType();
180+
return default_device_type_;
157181
}
158182

159-
at::DeviceType EagerFallbackDeviceType() const override {
160-
return at::DeviceType::CPU;
183+
void SetDefaultDeviceType(int8_t type) override {
184+
default_device_type_ =
185+
std::make_shared<DeviceType>(static_cast<XlaDeviceType>(type));
186+
default_device_type_inited_ = true;
161187
}
162188

163-
void SetDefaultDeviceType(std::string type) override {
164-
default_device_type_ = XlaDeviceType(c10::Device(type).type());
189+
int64_t GetDefaultDeviceOrdinal() const override {
190+
// lazily init default device ordinal, we only need to init once.
191+
static bool init =
192+
const_cast<XlaBackendImpl*>(this)->InitDefaultDeviceOrdinal();
193+
return default_device_ordinal_;
194+
}
195+
void SetDefaultDeviceOrdinal(int64_t ordinal) override {
196+
default_device_ordinal_ = ordinal;
197+
default_device_ordinal_inited_ = true;
198+
}
199+
200+
at::DeviceType EagerFallbackDeviceType() const override {
201+
return at::DeviceType::CPU;
165202
}
166203

167204
std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override {
@@ -180,7 +217,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
180217
}
181218

182219
private:
183-
DeviceType default_device_type_;
220+
bool default_device_type_inited_ = false;
221+
bool default_device_ordinal_inited_ = false;
222+
std::shared_ptr<torch::lazy::BackendDeviceType> default_device_type_;
223+
int64_t default_device_ordinal_;
184224
};
185225

186226
torch::lazy::BackendImplInterface* GetXlaBackendImpl() {

0 commit comments

Comments
 (0)