@@ -22,6 +22,30 @@ namespace torch_xla {
2222class XlaBackendImpl : public torch ::lazy::BackendImplInterface {
2323 public:
2424 XlaBackendImpl () {}
25+
26+ bool InitDefaultDeviceType () {
27+ if (!default_device_type_inited_) {
28+ // GetDefaultDevice will trigger the runtime device init, should
29+ // not do it during class init time.
30+ torch::lazy::BackendDevice default_device = *GetDefaultDevice ();
31+ default_device_type_ = std::make_shared<DeviceType>(
32+ static_cast <XlaDeviceType>(default_device.type ()));
33+ default_device_type_inited_ = true ;
34+ }
35+ return true ;
36+ }
37+
38+ bool InitDefaultDeviceOrdinal () {
39+ if (!default_device_ordinal_inited_) {
40+ // GetDefaultDevice will trigger the runtime device init, should
41+ // not do it during class init time.
42+ torch::lazy::BackendDevice default_device = *GetDefaultDevice ();
43+ default_device_ordinal_ = default_device.ordinal ();
44+ default_device_ordinal_inited_ = true ;
45+ }
46+ return true ;
47+ }
48+
2549 void PrepareToExit () const override { XLA_ERROR () << " Not implemented yet" ; }
2650
2751 void SetRngSeed (size_t seed) const override {
@@ -150,18 +174,31 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
150174
151175 std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType ()
152176 const override {
153- // want to reuse the getDefualtDeviceTypelogic
154- torch::lazy::BackendDevice default_device = * GetDefaultDevice ();
155- return std::make_shared<DeviceType>(
156- static_cast <XlaDeviceType>(default_device. type ())) ;
177+ // lazily init default device type, we only need to init once.
178+ static bool init =
179+ const_cast <XlaBackendImpl*>( this )-> InitDefaultDeviceType ();
180+ return default_device_type_ ;
157181 }
158182
159- at::DeviceType EagerFallbackDeviceType () const override {
160- return at::DeviceType::CPU;
183+ void SetDefaultDeviceType (int8_t type) override {
184+ default_device_type_ =
185+ std::make_shared<DeviceType>(static_cast <XlaDeviceType>(type));
186+ default_device_type_inited_ = true ;
161187 }
162188
163- void SetDefaultDeviceType (std::string type) override {
164- default_device_type_ = XlaDeviceType (c10::Device (type).type ());
189+ int64_t GetDefaultDeviceOrdinal () const override {
190+ // lazily init default device ordinal, we only need to init once.
191+ static bool init =
192+ const_cast <XlaBackendImpl*>(this )->InitDefaultDeviceOrdinal ();
193+ return default_device_ordinal_;
194+ }
195+ void SetDefaultDeviceOrdinal (int64_t ordinal) override {
196+ default_device_ordinal_ = ordinal;
197+ default_device_ordinal_inited_ = true ;
198+ }
199+
200+ at::DeviceType EagerFallbackDeviceType () const override {
201+ return at::DeviceType::CPU;
165202 }
166203
167204 std::vector<torch::lazy::BackendDevice> GetBackendDevices () const override {
@@ -180,7 +217,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
180217 }
181218
182219 private:
183- DeviceType default_device_type_;
220+ bool default_device_type_inited_ = false ;
221+ bool default_device_ordinal_inited_ = false ;
222+ std::shared_ptr<torch::lazy::BackendDeviceType> default_device_type_;
223+ int64_t default_device_ordinal_;
184224};
185225
186226torch::lazy::BackendImplInterface* GetXlaBackendImpl () {
0 commit comments