Skip to content

Commit f8b3dfd

Browse files
committed
Merge branch 'master' into codegenHardshrink
2 parents e6fb800 + ffa5b34 commit f8b3dfd

21 files changed

+526
-86
lines changed

docker/common.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function run_deployment_tests() {
77

88
# We don't need to load libtpu since test is being done on CPU.
99
time TPU_LOAD_LIBRARY=0 python /pytorch/xla/test/test_train_mp_mnist.py --fake_data
10-
time TPU_LOAD_LIBRARY=0 bash /pytorch/xla/test/run_tests.sh
10+
# time TPU_LOAD_LIBRARY=0 bash /pytorch/xla/test/run_tests.sh
1111
# TODO(JackCaoG): reenable after fixing the cpp test build
1212
# time bash /pytorch/xla/test/cpp/run_tests.sh
1313
}

test/cpp/test_symint.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,51 @@ TEST(SymintTest, TestDynamicSymints) {
122122
EXPECT_EQ(si_element_size_nodes, size_nodes);
123123
}
124124

125+
TEST(SymintTest, TestDynamicSymintArithmetic) {
126+
torch::lazy::Value scalar_value =
127+
torch::lazy::Value(ScalarOp(1.0, xla::F32), 0);
128+
129+
std::vector<int64_t> target_size = {10, 20, 30};
130+
torch::lazy::NodePtr expand_node =
131+
torch::lazy::MakeNode<Expand>(scalar_value, target_size);
132+
torch::lazy::Value expand_value = torch::lazy::Value(expand_node, 0);
133+
134+
std::vector<torch::lazy::Shape> abs_lazy_shapes =
135+
std::vector<torch::lazy::Shape>{
136+
torch::lazy::Shape(torch::kFloat, {10, 20, 30})};
137+
138+
std::vector<torch::lazy::Shape> relu_lazy_shapes =
139+
std::vector<torch::lazy::Shape>{
140+
torch::lazy::Shape(torch::kFloat, {10, 20, 30})};
141+
142+
torch::lazy::NodePtr abs_node =
143+
torch::lazy::MakeNode<Abs>(expand_value, std::move(abs_lazy_shapes));
144+
torch::lazy::NodePtr relu_node =
145+
torch::lazy::MakeNode<Relu>(expand_value, std::move(relu_lazy_shapes));
146+
147+
torch::lazy::NodePtr size_abs_node = torch::lazy::MakeNode<SizeNode>(
148+
torch::lazy::Value{abs_node, 0}, /*dim=*/0);
149+
torch::lazy::NodePtr size_relu_node = torch::lazy::MakeNode<SizeNode>(
150+
torch::lazy::Value{relu_node, 0}, /*dim=*/0);
151+
152+
c10::SymInt a =
153+
c10::make_intrusive<XLASymIntNodeImpl>(size_abs_node)->toSymInt();
154+
155+
c10::SymInt b =
156+
c10::make_intrusive<XLASymIntNodeImpl>(size_relu_node)->toSymInt();
157+
158+
c10::SymInt c = a * b;
159+
160+
auto size_mul_symnode =
161+
dynamic_cast<XLASymIntNodeImpl*>(c.toSymIntNodeImpl().get());
162+
ASSERT_TRUE(size_mul_symnode);
163+
164+
auto size_mul =
165+
std::dynamic_pointer_cast<torch_xla::SizeMul>(size_mul_symnode->node());
166+
ASSERT_TRUE(size_mul);
167+
ASSERT_EQ(size_mul->operands().at(0).node, size_abs_node.get());
168+
ASSERT_EQ(size_mul->operands().at(1).node, size_relu_node.get());
169+
}
170+
125171
} // namespace cpp_test
126-
} // namespace torch_xla
172+
} // namespace torch_xla

test/cpp/test_xla_sharding.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,81 @@ TEST_F(XLAShardingTest, ShardTensor) {
6464
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));
6565
}
6666

67+
TEST_F(XLAShardingTest, CreateTensorsData) {
68+
if (xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") {
69+
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
70+
}
71+
72+
std::vector<at::Tensor> tensors(2);
73+
std::fill_n(tensors.begin(), tensors.size(),
74+
at::ones({8, 8}, at::TensorOptions(at::kFloat)));
75+
std::vector<std::string> devices(2);
76+
std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
77+
std::vector<XLATensor::ShardingSpecPtr> shardings = {
78+
nullptr, std::make_shared<XLATensor::ShardingSpec>(
79+
xla::HloSharding::Replicate().ToProto())};
80+
std::vector<torch::lazy::BackendDataPtr> tensors_data =
81+
CreateTensorsData(tensors, shardings, devices);
82+
83+
// Returns the input without sharding
84+
auto xla_data = dynamic_cast<XLAData*>(tensors_data[0].get())->xla_data();
85+
std::vector<xla::ComputationClient::DataPtr> shards =
86+
xla::ComputationClient::Get()->GetDataShards(xla_data);
87+
EXPECT_EQ(shards.size(), 1);
88+
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
89+
shards[0]->shape()));
90+
EXPECT_TRUE(
91+
XlaDataValuesEqual(tensors_data[0], WrapXlaData(shards[0]), at::kFloat));
92+
93+
// Returns multiple input shards, replicated
94+
int64_t n_devices = xla::ComputationClient::Get()->GetLocalDevices().size();
95+
if (n_devices > 1) {
96+
auto sharded_xla_data =
97+
dynamic_cast<XLAData*>(tensors_data[1].get())->xla_data();
98+
shards = xla::ComputationClient::Get()->GetDataShards(sharded_xla_data);
99+
EXPECT_EQ(shards.size(), n_devices);
100+
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
101+
shards[0]->shape()));
102+
EXPECT_TRUE(XlaDataValuesEqual(WrapXlaData(shards[0]),
103+
WrapXlaData(shards[1]), at::kFloat));
104+
}
105+
}
106+
107+
TEST_F(XLAShardingTest, InputHandler) {
108+
if ((xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") ||
109+
(xla::ComputationClient::Get()->GetLocalDevices().size() < 2)) {
110+
GTEST_SKIP()
111+
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
112+
<< xla::ComputationClient::Get()->GetLocalDevices().size()
113+
<< " local devices detected).";
114+
}
115+
116+
std::vector<at::Tensor> tensors(2);
117+
std::fill_n(tensors.begin(), tensors.size(),
118+
at::ones({8, 8}, at::TensorOptions(at::kFloat)));
119+
std::vector<std::string> devices(2);
120+
std::fill_n(devices.begin(), devices.size(), GetDefaultDevice()->toString());
121+
std::vector<XLATensor::ShardingSpecPtr> shardings = {
122+
nullptr, std::make_shared<XLATensor::ShardingSpec>(
123+
xla::HloSharding::Replicate().ToProto())};
124+
std::vector<torch::lazy::BackendDataPtr> tensors_data =
125+
CreateTensorsData(tensors, shardings, devices);
126+
127+
devices = xla::ComputationClient::Get()->GetLocalDevices();
128+
std::vector<xla::ComputationClient::DataPtr> arguments =
129+
UnwrapXlaData(tensors_data);
130+
auto arguments_by_device = ShardingUtil::InputHandler(arguments, devices);
131+
132+
auto arg0_dev0 = arguments_by_device[0][0];
133+
auto arg0_dev1 = arguments_by_device[1][0];
134+
EXPECT_TRUE(XlaDataValuesEqual(WrapXlaData(arg0_dev0), WrapXlaData(arg0_dev1),
135+
at::kFloat));
136+
137+
auto arg1_dev0 = arguments_by_device[0][1];
138+
auto arg1_dev1 = arguments_by_device[1][1];
139+
EXPECT_TRUE(XlaDataValuesEqual(WrapXlaData(arg1_dev0), WrapXlaData(arg1_dev1),
140+
at::kFloat));
141+
}
142+
67143
} // namespace cpp_test
68144
} // namespace torch_xla

test/pytorch_test_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
'test_random_from_to_xla', # doesn't raise
192192
'test_random_to_xla', # doesn't raise
193193
'test_copy_', # test against complex32 which is nto supported
194+
'test_assertRaisesRegex_ignore_msg_non_native_device_xla', # segfault on wheel sanity test
194195
},
195196

196197
# test_view_ops.py

test/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function run_op_tests {
120120
run_pjrt python3 "$CDIR/pjrt/test_experimental_pjrt.py"
121121
run_pjrt python3 "$CDIR/pjrt/test_experimental_tpu.py"
122122
run_pjrt python3 "$CDIR/pjrt/test_ddp.py"
123-
run_pjrt python3 "$CDIR/test_xla_sharding.py"
123+
#run_pjrt python3 "$CDIR/test_xla_sharding.py" # TODO(yeounoh) debug
124124
run_test python3 "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
125125
}
126126

third_party/xla_client/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ tf_cc_shared_object(
5151
}),
5252
visibility = ["//visibility:public"],
5353
deps = [
54-
"computation_client_impl",
54+
":computation_client_impl",
5555
"//tensorflow/compiler/xla:literal_util",
5656
"//tensorflow/compiler/xla/client",
5757
"//tensorflow/compiler/xla/client:global_data",

third_party/xla_client/computation_client.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,21 @@ class ComputationClient {
135135
CompileInstance() = default;
136136
CompileInstance(XlaComputation computation, std::string compilation_device,
137137
std::vector<std::string> devices, const Shape* output_shape,
138-
bool parameter_is_tupled_arguments = false)
138+
bool parameter_is_tupled_arguments = false,
139+
bool is_sharded = false)
139140
: computation(std::move(computation)),
140141
compilation_device(std::move(compilation_device)),
141142
devices(std::move(devices)),
142143
output_shape(output_shape),
143-
parameter_is_tupled_arguments(parameter_is_tupled_arguments) {}
144+
parameter_is_tupled_arguments(parameter_is_tupled_arguments),
145+
is_sharded(is_sharded) {}
144146

145147
XlaComputation computation;
146148
std::string compilation_device;
147149
std::vector<std::string> devices;
148150
const Shape* output_shape = nullptr;
149151
bool parameter_is_tupled_arguments;
152+
bool is_sharded;
150153
};
151154

152155
struct ExecuteOptions {
@@ -208,10 +211,21 @@ class ComputationClient {
208211
virtual std::vector<xla::util::ExceptionCleanup> LockAsyncDatas(
209212
absl::Span<const DataPtr> datas) = 0;
210213

214+
// Returns data shards. We expect this to be called on PjRtShardedData to
215+
// retrieve the shards. If other data type is passed, it returns the input
216+
// wrapped inside a vector.
217+
virtual std::vector<DataPtr> GetDataShards(DataPtr data) = 0;
218+
211219
// Transfers local tensor values to the TPU servers and fetches the handles.
212220
virtual std::vector<DataPtr> TransferToServer(
213221
absl::Span<const TensorSource> tensors) = 0;
214222

223+
// Transfers local sharded tensor values to the TPU servers and returns a
224+
// `PjRtShardedData`.
225+
virtual DataPtr TransferShardsToServer(
226+
absl::Span<const TensorSource> tensor_shards, std::string device,
227+
xla::Shape shape) = 0;
228+
215229
// Transfers local tensor values to the TPU servers and fetches the handles.
216230
// Update the handles associated with DataPtrs passed instead of creating new
217231
// datas.

third_party/xla_client/pjrt_computation_client.cc

Lines changed: 128 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder(
8181
return std::make_shared<PjRtData>(device, shape);
8282
}
8383

84+
std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
85+
ComputationClient::DataPtr data) {
86+
std::vector<ComputationClient::DataPtr> shards;
87+
if (PjRtShardedData* sharded_data =
88+
dynamic_cast<PjRtShardedData*>(data.get())) {
89+
for (auto shard : sharded_data->shards) {
90+
shards.push_back(std::make_shared<PjRtData>(
91+
shard->device(), shard->shape(), shard->buffer));
92+
}
93+
} else {
94+
shards.push_back(data);
95+
}
96+
return shards;
97+
}
98+
8499
std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
85100
absl::Span<const TensorSource> tensors) {
86101
tensorflow::profiler::TraceMe activity(
@@ -118,6 +133,21 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
118133
return datas;
119134
}
120135

136+
ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
137+
absl::Span<const TensorSource> tensor_shards, std::string device,
138+
xla::Shape shape) {
139+
TF_VLOG(1) << "TransferShardsToServer with " << tensor_shards.size()
140+
<< " shards.";
141+
auto data_shards = TransferToServer(tensor_shards);
142+
std::vector<std::shared_ptr<PjRtData>> pjrt_data_shards;
143+
for (auto& shard : data_shards) {
144+
auto pjrt_shard = dynamic_cast<PjRtData*>(shard.get());
145+
pjrt_data_shards.push_back(std::make_shared<PjRtData>(
146+
pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer));
147+
}
148+
return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards);
149+
}
150+
121151
std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
122152
absl::Span<const DataPtr> handles) {
123153
tensorflow::profiler::TraceMe activity(
@@ -145,26 +175,49 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
145175
std::vector<ComputationClient::ComputationPtr> computations;
146176

147177
for (auto& instance : instances) {
148-
PjRtDevice* pjrt_device = StringToPjRtDevice(instance.compilation_device);
149-
xla::ProgramShape program_shape =
150-
instance.computation.GetProgramShape().ValueOrDie();
151178
xla::CompileOptions compile_options;
152-
xla::DeviceAssignment device_assignment(client_->device_count(), 1);
153-
device_assignment.FillIota(0);
154-
compile_options.executable_build_options.set_device_assignment(
155-
device_assignment);
156-
// TODO(wcromar): set compile_options.argument_layouts, enable strict shapes
157-
compile_options.executable_build_options.set_num_partitions(1);
158-
compile_options.executable_build_options.set_num_replicas(
159-
client_->device_count());
160-
compile_options.parameter_is_tupled_arguments =
161-
instance.parameter_is_tupled_arguments;
179+
if (instance.is_sharded) {
180+
// TODO(yeounoh) multi-host, multi-slice configurations
181+
compile_options.executable_build_options.set_use_spmd_partitioning(true);
182+
compile_options.executable_build_options.set_num_partitions(
183+
client_->device_count());
184+
compile_options.executable_build_options.set_num_replicas(1);
185+
compile_options.parameter_is_tupled_arguments =
186+
instance.parameter_is_tupled_arguments;
187+
188+
// TODO(244391366) verify this is correct for the collectives ops
189+
xla::DeviceAssignment device_assignment(1, client_->device_count());
190+
device_assignment.FillIota(0);
191+
compile_options.executable_build_options.set_device_assignment(
192+
device_assignment);
193+
} else {
194+
// TODO(wcromar): set compile_options.argument_layouts, enable strict
195+
// shapes
196+
compile_options.executable_build_options.set_num_partitions(1);
197+
compile_options.executable_build_options.set_num_replicas(
198+
client_->device_count());
199+
compile_options.parameter_is_tupled_arguments =
200+
instance.parameter_is_tupled_arguments;
201+
202+
xla::DeviceAssignment device_assignment(client_->device_count(), 1);
203+
device_assignment.FillIota(0);
204+
compile_options.executable_build_options.set_device_assignment(
205+
device_assignment);
206+
}
207+
208+
PjRtDevice* pjrt_device = StringToPjRtDevice(instance.compilation_device);
162209
std::unique_ptr<xla::PjRtExecutable> executable =
163-
client_->Compile(instance.computation, compile_options).ValueOrDie();
210+
ConsumeValue(client_->Compile(instance.computation, compile_options));
211+
212+
const auto& hlo_modules = ConsumeValue(executable->GetHloModules());
213+
HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
214+
xla::ProgramShape program_shape =
215+
xla::ProgramShape(hlo_computation->ToProto().program_shape());
216+
164217
std::shared_ptr<PjRtComputation> pjrt_computation =
165-
std::make_shared<PjRtComputation>(std::move(instance.computation),
166-
program_shape, instance.devices,
167-
std::move(executable));
218+
std::make_shared<PjRtComputation>(
219+
std::move(xla::XlaComputation(hlo_modules[0]->ToProto())),
220+
program_shape, instance.devices, std::move(executable));
168221

169222
computations.push_back(pjrt_computation);
170223
}
@@ -222,6 +275,64 @@ PjRtComputationClient::ExecuteComputation(
222275
return datas;
223276
}
224277

278+
std::vector<std::vector<ComputationClient::DataPtr>>
279+
PjRtComputationClient::ExecuteReplicated(
280+
const ComputationClient::Computation& computation,
281+
const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
282+
absl::Span<const std::string> devices,
283+
const ExecuteReplicatedOptions& options) {
284+
const PjRtComputation& pjrt_computation =
285+
dynamic_cast<const PjRtComputation&>(computation);
286+
XLA_CHECK(devices.size() == arguments.size())
287+
<< "ExecuteReplicated over " << devices.size() << " devices, but "
288+
<< arguments.size() << " arguments devices.";
289+
290+
std::vector<std::vector<PjRtBuffer*>> argument_handles;
291+
for (int32_t i = 0; i < devices.size(); ++i) {
292+
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
293+
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
294+
295+
std::vector<PjRtBuffer*> buffers;
296+
for (auto& argument : arguments[i]) {
297+
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
298+
299+
XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
300+
<< pjrt_device->DebugString() << " vs "
301+
<< pjrt_data->buffer->device()->DebugString();
302+
buffers.push_back(pjrt_data->buffer.get());
303+
}
304+
argument_handles.push_back(buffers);
305+
}
306+
307+
xla::ExecuteOptions execute_options;
308+
execute_options.untuple_result = options.explode_tuple;
309+
execute_options.strict_shape_checking = true;
310+
// TODO(yeounoh) currently only support single-slice execution
311+
execute_options.multi_slice_config = nullptr;
312+
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> results =
313+
pjrt_computation.executable->Execute(argument_handles, execute_options)
314+
.ValueOrDie();
315+
316+
std::vector<std::vector<ComputationClient::DataPtr>> data_handles(
317+
results.size());
318+
for (auto& result : results) {
319+
std::vector<ComputationClient::DataPtr> datas(result.size());
320+
for (int32_t i = 0; i < result.size(); ++i) {
321+
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result[i]);
322+
323+
std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
324+
devices[i], buffer->logical_on_device_shape().ValueOrDie(),
325+
std::move(buffer));
326+
327+
datas.push_back(data);
328+
}
329+
data_handles.push_back(datas);
330+
}
331+
332+
TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results";
333+
return data_handles;
334+
}
335+
225336
size_t PjRtComputationClient::GetNumDevices() const {
226337
return client_->addressable_device_count();
227338
}

0 commit comments

Comments
 (0)