@@ -81,6 +81,21 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder(
   return std::make_shared<PjRtData>(device, shape);
 }
 
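+// Returns the per-device shards backing a PjRtShardedData handle; for
+// unsharded data, the input handle is returned as the only element.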
+std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
+    ComputationClient::DataPtr data) {
+  std::vector<ComputationClient::DataPtr> shards;
+  if (PjRtShardedData* sharded_data =
+          dynamic_cast<PjRtShardedData*>(data.get())) {
+    for (auto shard : sharded_data->shards) {
+      shards.push_back(std::make_shared<PjRtData>(
+          shard->device(), shard->shape(), shard->buffer));
+    }
+  } else {
+    shards.push_back(data);
+  }
+  return shards;
+}
+
 std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
     absl::Span<const TensorSource> tensors) {
   tensorflow::profiler::TraceMe activity(
@@ -118,6 +133,21 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
   return datas;
 }
 
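+// Transfers each tensor shard via TransferToServer and groups the resulting
+// per-shard PjRtData handles into a single PjRtShardedData with the given
+// (virtual) device and shape.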
+ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
+    absl::Span<const TensorSource> tensor_shards, std::string device,
+    xla::Shape shape) {
+  TF_VLOG(1) << "TransferShardsToServer with " << tensor_shards.size()
+             << " shards.";
+  auto data_shards = TransferToServer(tensor_shards);
+  std::vector<std::shared_ptr<PjRtData>> pjrt_data_shards;
+  for (auto& shard : data_shards) {
+    auto pjrt_shard = dynamic_cast<PjRtData*>(shard.get());
+    pjrt_data_shards.push_back(std::make_shared<PjRtData>(
+        pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer));
+  }
+  return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards);
+}
+
 std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
     absl::Span<const DataPtr> handles) {
   tensorflow::profiler::TraceMe activity(
@@ -145,26 +175,49 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
   std::vector<ComputationClient::ComputationPtr> computations;
 
   for (auto& instance : instances) {
-    PjRtDevice* pjrt_device = StringToPjRtDevice(instance.compilation_device);
-    xla::ProgramShape program_shape =
-        instance.computation.GetProgramShape().ValueOrDie();
     xla::CompileOptions compile_options;
-    xla::DeviceAssignment device_assignment(client_->device_count(), 1);
-    device_assignment.FillIota(0);
-    compile_options.executable_build_options.set_device_assignment(
-        device_assignment);
-    // TODO(wcromar): set compile_options.argument_layouts, enable strict shapes
-    compile_options.executable_build_options.set_num_partitions(1);
-    compile_options.executable_build_options.set_num_replicas(
-        client_->device_count());
-    compile_options.parameter_is_tupled_arguments =
-        instance.parameter_is_tupled_arguments;
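+    // Sharded computations are compiled for SPMD partitioning across all
+    // local devices (num_partitions = device_count, num_replicas = 1);
+    // unsharded computations keep the replicated setup.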
+    if (instance.is_sharded) {
+      // TODO(yeounoh) multi-host, multi-slice configurations
+      compile_options.executable_build_options.set_use_spmd_partitioning(true);
+      compile_options.executable_build_options.set_num_partitions(
+          client_->device_count());
+      compile_options.executable_build_options.set_num_replicas(1);
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+
+      // TODO(244391366) verify this is correct for the collectives ops
+      xla::DeviceAssignment device_assignment(1, client_->device_count());
+      device_assignment.FillIota(0);
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    } else {
+      // TODO(wcromar): set compile_options.argument_layouts, enable strict
+      // shapes
+      compile_options.executable_build_options.set_num_partitions(1);
+      compile_options.executable_build_options.set_num_replicas(
+          client_->device_count());
+      compile_options.parameter_is_tupled_arguments =
+          instance.parameter_is_tupled_arguments;
+
+      xla::DeviceAssignment device_assignment(client_->device_count(), 1);
+      device_assignment.FillIota(0);
+      compile_options.executable_build_options.set_device_assignment(
+          device_assignment);
+    }
+
+    PjRtDevice* pjrt_device = StringToPjRtDevice(instance.compilation_device);
     std::unique_ptr<xla::PjRtExecutable> executable =
-        client_->Compile(instance.computation, compile_options).ValueOrDie();
+        ConsumeValue(client_->Compile(instance.computation, compile_options));
+
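+    // Derive the program shape from the compiled executable's HLO entry
+    // computation rather than from the input XlaComputation.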
+    const auto& hlo_modules = ConsumeValue(executable->GetHloModules());
+    HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
+    xla::ProgramShape program_shape =
+        xla::ProgramShape(hlo_computation->ToProto().program_shape());
+
     std::shared_ptr<PjRtComputation> pjrt_computation =
-        std::make_shared<PjRtComputation>(std::move(instance.computation),
-                                          program_shape, instance.devices,
-                                          std::move(executable));
+        std::make_shared<PjRtComputation>(
+            std::move(xla::XlaComputation(hlo_modules[0]->ToProto())),
+            program_shape, instance.devices, std::move(executable));
 
     computations.push_back(pjrt_computation);
   }
@@ -222,6 +275,64 @@ PjRtComputationClient::ExecuteComputation(
   return datas;
 }
 
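+// Executes the compiled computation across the given devices, taking one
+// argument vector per device and returning one result vector per device.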
+std::vector<std::vector<ComputationClient::DataPtr>>
+PjRtComputationClient::ExecuteReplicated(
+    const ComputationClient::Computation& computation,
+    const std::vector<std::vector<ComputationClient::DataPtr>>& arguments,
+    absl::Span<const std::string> devices,
+    const ExecuteReplicatedOptions& options) {
+  const PjRtComputation& pjrt_computation =
+      dynamic_cast<const PjRtComputation&>(computation);
+  XLA_CHECK(devices.size() == arguments.size())
+      << "ExecuteReplicated over " << devices.size() << " devices, but "
+      << arguments.size() << " argument vectors.";
+
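+  // Collect the raw PjRtBuffer pointers for each device's arguments, checking
+  // that every buffer already lives on its assigned device.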
+  std::vector<std::vector<PjRtBuffer*>> argument_handles;
+  for (int32_t i = 0; i < devices.size(); ++i) {
+    xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
+    XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();
+
+    std::vector<PjRtBuffer*> buffers;
+    for (auto& argument : arguments[i]) {
+      const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());
+
+      XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
+          << pjrt_device->DebugString() << " vs "
+          << pjrt_data->buffer->device()->DebugString();
+      buffers.push_back(pjrt_data->buffer.get());
+    }
+    argument_handles.push_back(buffers);
+  }
+
+  xla::ExecuteOptions execute_options;
+  execute_options.untuple_result = options.explode_tuple;
+  execute_options.strict_shape_checking = true;
+  // TODO(yeounoh) currently only support single-slice execution
+  execute_options.multi_slice_config = nullptr;
+  std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> results =
+      pjrt_computation.executable->Execute(argument_handles, execute_options)
+          .ValueOrDie();
+
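+  // Wrap each device's output buffers in PjRtData handles tagged with the
+  // device that produced them.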
+  std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
+  data_handles.reserve(results.size());
+  for (int32_t i = 0; i < results.size(); ++i) {
+    std::vector<ComputationClient::DataPtr> datas;
+    datas.reserve(results[i].size());
+    for (int32_t j = 0; j < results[i].size(); ++j) {
+      std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
+
+      std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
+          devices[i], buffer->logical_on_device_shape().ValueOrDie(),
+          std::move(buffer));
+
+      datas.push_back(data);
+    }
+    data_handles.push_back(datas);
+  }
+
+  TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results";
+  return data_handles;
+}
+
 size_t PjRtComputationClient::GetNumDevices() const {
   return client_->addressable_device_count();
 }