
Commit d5077b5

fix: make write_csv/json/parquet cancel-safe (#5196)
* fix: add cancellation token to `write_csv/json/parquet`
* refactor: use abort-on-drop instead of cancellation token
1 parent: f63b972 · commit: d5077b5
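
The second bullet is the heart of the fix: instead of threading a cancellation token through the writers, each spawned writer task is wrapped in an `AbortOnDropSingle` guard, so dropping the `plan_to_*` future (e.g. when a query is cancelled) aborts the task instead of leaking it. The guard's definition lives in `physical_plan::common` and is not part of this diff; the following is a minimal sketch of the pattern, not the exact DataFusion source:

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::task::{JoinError, JoinHandle};

/// Guard that aborts its task when dropped (sketch; the real type is
/// `datafusion::physical_plan::common::AbortOnDropSingle`).
pub struct AbortOnDropSingle<T>(JoinHandle<T>);

impl<T> AbortOnDropSingle<T> {
    pub fn new(handle: JoinHandle<T>) -> Self {
        Self(handle)
    }
}

impl<T> Drop for AbortOnDropSingle<T> {
    fn drop(&mut self) {
        // If the caller's future is dropped (query cancelled), the
        // guards in `tasks` are dropped too, aborting the writers.
        self.0.abort();
    }
}

impl<T> Future for AbortOnDropSingle<T> {
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Awaiting the guard delegates to the inner JoinHandle, which
        // is what `futures::future::join_all(tasks)` below relies on.
        Pin::new(&mut self.get_mut().0).poll(cx)
    }
}

Compared with a cancellation token, the abort-on-drop guard needs no cooperation from the writer tasks: cancellation propagates automatically through Rust's drop semantics, which is why the PR moved from the token approach to this one.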

3 files changed: +100 −97 lines


datafusion/core/src/physical_plan/file_format/csv.rs

Lines changed: 32 additions & 30 deletions
@@ -34,6 +34,7 @@ use arrow::datatypes::SchemaRef;
 
 use bytes::Buf;
 
+use crate::physical_plan::common::AbortOnDropSingle;
 use bytes::Bytes;
 use futures::ready;
 use futures::{StreamExt, TryStreamExt};
@@ -286,38 +287,39 @@ pub async fn plan_to_csv(
     let path = path.as_ref();
     // create directory to contain the CSV files (one per partition)
     let fs_path = Path::new(path);
-    match fs::create_dir(fs_path) {
-        Ok(()) => {
-            let mut tasks = vec![];
-            for i in 0..plan.output_partitioning().partition_count() {
-                let plan = plan.clone();
-                let filename = format!("part-{i}.csv");
-                let path = fs_path.join(filename);
-                let file = fs::File::create(path)?;
-                let mut writer = csv::Writer::new(file);
-                let task_ctx = Arc::new(TaskContext::from(state));
-                let stream = plan.execute(i, task_ctx)?;
-                let handle: JoinHandle<Result<()>> = task::spawn(async move {
-                    stream
-                        .map(|batch| writer.write(&batch?))
-                        .try_collect()
-                        .await
-                        .map_err(DataFusionError::from)
-                });
-                tasks.push(handle);
-            }
-            futures::future::join_all(tasks)
-                .await
-                .into_iter()
-                .try_for_each(|result| {
-                    result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-                })?;
-            Ok(())
-        }
-        Err(e) => Err(DataFusionError::Execution(format!(
+    if let Err(e) = fs::create_dir(fs_path) {
+        return Err(DataFusionError::Execution(format!(
             "Could not create directory {path}: {e:?}"
-        ))),
+        )));
+    }
+
+    let mut tasks = vec![];
+    for i in 0..plan.output_partitioning().partition_count() {
+        let plan = plan.clone();
+        let filename = format!("part-{i}.csv");
+        let path = fs_path.join(filename);
+        let file = fs::File::create(path)?;
+        let mut writer = csv::Writer::new(file);
+        let task_ctx = Arc::new(TaskContext::from(state));
+        let stream = plan.execute(i, task_ctx)?;
+
+        let handle: JoinHandle<Result<()>> = task::spawn(async move {
+            stream
+                .map(|batch| writer.write(&batch?))
+                .try_collect()
+                .await
+                .map_err(DataFusionError::from)
+        });
+        tasks.push(AbortOnDropSingle::new(handle));
     }
+
+    futures::future::join_all(tasks)
+        .await
+        .into_iter()
+        .try_for_each(|result| {
+            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
+        })?;
+    Ok(())
 }
 
 #[cfg(test)]

datafusion/core/src/physical_plan/file_format/json.rs

Lines changed: 31 additions & 30 deletions
@@ -33,6 +33,7 @@ use arrow::{datatypes::SchemaRef, json};
 
 use bytes::{Buf, Bytes};
 
+use crate::physical_plan::common::AbortOnDropSingle;
 use arrow::json::RawReaderBuilder;
 use futures::{ready, stream, StreamExt, TryStreamExt};
 use object_store::{GetResult, ObjectStore};
@@ -230,38 +231,38 @@ pub async fn plan_to_json(
     let path = path.as_ref();
     // create directory to contain the CSV files (one per partition)
     let fs_path = Path::new(path);
-    match fs::create_dir(fs_path) {
-        Ok(()) => {
-            let mut tasks = vec![];
-            for i in 0..plan.output_partitioning().partition_count() {
-                let plan = plan.clone();
-                let filename = format!("part-{i}.json");
-                let path = fs_path.join(filename);
-                let file = fs::File::create(path)?;
-                let mut writer = json::LineDelimitedWriter::new(file);
-                let task_ctx = Arc::new(TaskContext::from(state));
-                let stream = plan.execute(i, task_ctx)?;
-                let handle: JoinHandle<Result<()>> = task::spawn(async move {
-                    stream
-                        .map(|batch| writer.write(batch?))
-                        .try_collect()
-                        .await
-                        .map_err(DataFusionError::from)
-                });
-                tasks.push(handle);
-            }
-            futures::future::join_all(tasks)
-                .await
-                .into_iter()
-                .try_for_each(|result| {
-                    result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-                })?;
-            Ok(())
-        }
-        Err(e) => Err(DataFusionError::Execution(format!(
+    if let Err(e) = fs::create_dir(fs_path) {
+        return Err(DataFusionError::Execution(format!(
             "Could not create directory {path}: {e:?}"
-        ))),
+        )));
+    }
+
+    let mut tasks = vec![];
+    for i in 0..plan.output_partitioning().partition_count() {
+        let plan = plan.clone();
+        let filename = format!("part-{i}.json");
+        let path = fs_path.join(filename);
+        let file = fs::File::create(path)?;
+        let mut writer = json::LineDelimitedWriter::new(file);
+        let task_ctx = Arc::new(TaskContext::from(state));
+        let stream = plan.execute(i, task_ctx)?;
+        let handle: JoinHandle<Result<()>> = task::spawn(async move {
+            stream
+                .map(|batch| writer.write(batch?))
+                .try_collect()
+                .await
+                .map_err(DataFusionError::from)
+        });
+        tasks.push(AbortOnDropSingle::new(handle));
     }
+
+    futures::future::join_all(tasks)
+        .await
+        .into_iter()
+        .try_for_each(|result| {
+            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
+        })?;
+    Ok(())
 }
 
 #[cfg(test)]

datafusion/core/src/physical_plan/file_format/parquet.rs

Lines changed: 37 additions & 37 deletions
@@ -66,6 +66,7 @@ mod page_filter;
 mod row_filter;
 mod row_groups;
 
+use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::file_format::parquet::page_filter::PagePruningPredicate;
 pub use metrics::ParquetFileMetrics;
 
@@ -706,45 +707,44 @@ pub async fn plan_to_parquet(
     let path = path.as_ref();
     // create directory to contain the Parquet files (one per partition)
     let fs_path = std::path::Path::new(path);
-    match fs::create_dir(fs_path) {
-        Ok(()) => {
-            let mut tasks = vec![];
-            for i in 0..plan.output_partitioning().partition_count() {
-                let plan = plan.clone();
-                let filename = format!("part-{i}.parquet");
-                let path = fs_path.join(filename);
-                let file = fs::File::create(path)?;
-                let mut writer =
-                    ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
-                let task_ctx = Arc::new(TaskContext::from(state));
-                let stream = plan.execute(i, task_ctx)?;
-                let handle: tokio::task::JoinHandle<Result<()>> =
-                    tokio::task::spawn(async move {
-                        stream
-                            .map(|batch| {
-                                writer
-                                    .write(&batch?)
-                                    .map_err(DataFusionError::ParquetError)
-                            })
-                            .try_collect()
-                            .await
-                            .map_err(DataFusionError::from)?;
-                        writer.close().map_err(DataFusionError::from).map(|_| ())
-                    });
-                tasks.push(handle);
-            }
-            futures::future::join_all(tasks)
-                .await
-                .into_iter()
-                .try_for_each(|result| {
-                    result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-                })?;
-            Ok(())
-        }
-        Err(e) => Err(DataFusionError::Execution(format!(
+    if let Err(e) = fs::create_dir(fs_path) {
+        return Err(DataFusionError::Execution(format!(
             "Could not create directory {path}: {e:?}"
-        ))),
+        )));
+    }
+
+    let mut tasks = vec![];
+    for i in 0..plan.output_partitioning().partition_count() {
+        let plan = plan.clone();
+        let filename = format!("part-{i}.parquet");
+        let path = fs_path.join(filename);
+        let file = fs::File::create(path)?;
+        let mut writer =
+            ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
+        let task_ctx = Arc::new(TaskContext::from(state));
+        let stream = plan.execute(i, task_ctx)?;
+        let handle: tokio::task::JoinHandle<Result<()>> =
+            tokio::task::spawn(async move {
+                stream
+                    .map(|batch| {
+                        writer.write(&batch?).map_err(DataFusionError::ParquetError)
+                    })
+                    .try_collect()
+                    .await
+                    .map_err(DataFusionError::from)?;
+
+                writer.close().map_err(DataFusionError::from).map(|_| ())
+            });
+        tasks.push(AbortOnDropSingle::new(handle));
     }
+
+    futures::future::join_all(tasks)
+        .await
+        .into_iter()
+        .try_for_each(|result| {
+            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
+        })?;
+    Ok(())
 }
 
 // Copy from the arrow-rs
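
To see the cancel-safety in action, consider a caller that abandons one of these writes under a deadline. The following is a hypothetical usage sketch: the function name, import paths, and output path are illustrative and not part of this commit.

use std::sync::Arc;
use std::time::Duration;

use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::physical_plan::ExecutionPlan;
// Assumed import path for the function changed in this diff.
use datafusion::physical_plan::file_format::plan_to_csv;

async fn write_csv_with_deadline(
    state: &SessionState,
    plan: Arc<dyn ExecutionPlan>,
) -> Result<()> {
    let write = plan_to_csv(state, plan, "/tmp/query_out");
    match tokio::time::timeout(Duration::from_secs(30), write).await {
        Ok(result) => result,
        // On timeout the `plan_to_csv` future is dropped; the
        // `AbortOnDropSingle` guards it holds abort the per-partition
        // writer tasks instead of leaving them running detached,
        // which is exactly the leak this commit fixes.
        Err(_elapsed) => Err(DataFusionError::Execution(
            "CSV write timed out".to_string(),
        )),
    }
}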
