
Commit 72b8d51

use ObjectStore for dataframe writes (#6987)
* use ObjectStore for dataframe writes
* handle temp files compatible with mac and windows
* cargo fmt
* check test
* try to fix json tests on windows
* fmt and clippy
* use AsyncArrowWriter
* implement multipart streaming writes
* unfill bucket_name and region
* revert back to moving buffer around
1 parent d2d506a commit 72b8d51

File tree: 4 files changed, +222 / -80 lines changed
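Before the per-file diffs, here is a minimal sketch of the write pattern this commit moves to, for orientation. It is not code from the commit: the in-memory store, the path, and the string chunks are stand-ins, and the put_multipart signature shown (a multipart id plus an async writer) is the one the diffs below rely on; later object_store releases changed this API. The new plan_to_csv and plan_to_json bodies follow the same shape: open a multipart upload on a registered ObjectStore, stream serialized batches through a reusable buffer, then shutdown() to complete the upload.

    use object_store::memory::InMemory;
    use object_store::{path::Path, ObjectStore};
    use tokio::io::AsyncWriteExt;

    #[tokio::main]
    async fn main() -> object_store::Result<()> {
        // any registered ObjectStore works; InMemory keeps the sketch self-contained
        let store = InMemory::new();
        let location = Path::from("out/part-0.csv");

        // start a multipart upload and stream chunks into it through one reusable buffer
        let (_id, mut writer) = store.put_multipart(&location).await?;
        let mut buffer = Vec::with_capacity(1024);
        for chunk in ["a,b\n", "1,2\n", "3,4\n"] {
            buffer.extend_from_slice(chunk.as_bytes());
            writer.write_all(&buffer).await?;
            buffer.clear();
        }
        // completing the upload is what makes the object visible in the store
        writer.shutdown().await?;
        Ok(())
    }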
New file (a DataFusion example: query Parquet data on Amazon S3 and write the query results back to S3 as Parquet, JSON, and CSV)

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::datasource::file_format::file_type::{FileType, GetExt};
+use datafusion::datasource::file_format::parquet::ParquetFormat;
+use datafusion::datasource::listing::ListingOptions;
+use datafusion::error::Result;
+use datafusion::prelude::*;
+
+use object_store::aws::AmazonS3Builder;
+use std::env;
+use std::sync::Arc;
+use url::Url;
+
+/// This example demonstrates querying data from AmazonS3 and writing
+/// the result of a query back to AmazonS3
+#[tokio::main]
+async fn main() -> Result<()> {
+    // create local execution context
+    let ctx = SessionContext::new();
+
+    //enter region and bucket to which your credentials have GET and PUT access
+    let region = "<bucket-region-here>";
+    let bucket_name = "<bucket-name-here>";
+
+    let s3 = AmazonS3Builder::new()
+        .with_bucket_name(bucket_name)
+        .with_region(region)
+        .with_access_key_id(env::var("AWS_ACCESS_KEY_ID").unwrap())
+        .with_secret_access_key(env::var("AWS_SECRET_ACCESS_KEY").unwrap())
+        .build()?;
+
+    let path = format!("s3://{bucket_name}");
+    let s3_url = Url::parse(&path).unwrap();
+    let arc_s3 = Arc::new(s3);
+    ctx.runtime_env()
+        .register_object_store(&s3_url, arc_s3.clone());
+
+    let path = format!("s3://{bucket_name}/test_data/");
+    let file_format = ParquetFormat::default().with_enable_pruning(Some(true));
+    let listing_options = ListingOptions::new(Arc::new(file_format))
+        .with_file_extension(FileType::PARQUET.get_ext());
+    ctx.register_listing_table("test", &path, listing_options, None, None)
+        .await?;
+
+    // execute the query
+    let df = ctx.sql("SELECT * from test").await?;
+
+    let out_path = format!("s3://{bucket_name}/test_write/");
+    df.clone().write_parquet(&out_path, None).await?;
+
+    //write as JSON to s3
+    let json_out = format!("s3://{bucket_name}/json_out");
+    df.clone().write_json(&json_out).await?;
+
+    //write as csv to s3
+    let csv_out = format!("s3://{bucket_name}/csv_out");
+    df.write_csv(&csv_out).await?;
+
+    let file_format = ParquetFormat::default().with_enable_pruning(Some(true));
+    let listing_options = ListingOptions::new(Arc::new(file_format))
+        .with_file_extension(FileType::PARQUET.get_ext());
+    ctx.register_listing_table("test2", &out_path, listing_options, None, None)
+        .await?;
+
+    let df = ctx
+        .sql(
+            "SELECT * \
+        FROM test2 \
+        ",
+        )
+        .await?;
+
+    df.show_limit(20).await?;
+
+    Ok(())
+}
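A usage note, not part of the commit: the only S3-specific piece above is the object store registration, so the same DataFrame writes can be exercised locally by registering a LocalFileSystem instead, which is exactly what the updated csv.rs and json.rs tests below do. The file://local authority, the temp directory, and the input CSV path in this sketch are illustrative.

    use std::sync::Arc;

    use datafusion::error::Result;
    use datafusion::prelude::*;
    use object_store::local::LocalFileSystem;
    use tempfile::TempDir;
    use url::Url;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // route "file://local/..." URLs to a LocalFileSystem rooted at a temp directory
        let tmp_dir = TempDir::new()?;
        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
        let local_url = Url::parse("file://local").unwrap();
        ctx.runtime_env().register_object_store(&local_url, local);

        // DataFrame writes now go through the registered store, one file per partition
        let df = ctx
            .read_csv("tests/data/example.csv", CsvReadOptions::new())
            .await?;
        df.write_csv("file://local/csv_out").await?;
        Ok(())
    }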

datafusion/core/src/datasource/physical_plan/csv.rs

Lines changed: 46 additions & 26 deletions
@@ -18,7 +18,7 @@
 //! Execution plan for reading CSV files

 use crate::datasource::file_format::file_type::FileCompressionType;
-use crate::datasource::listing::FileRange;
+use crate::datasource::listing::{FileRange, ListingTableUrl};
 use crate::datasource::physical_plan::file_stream::{
     FileOpenFuture, FileOpener, FileStream,
 };
@@ -34,6 +34,7 @@ use arrow::csv
 use arrow::datatypes::SchemaRef;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties};
+use tokio::io::AsyncWriteExt;

 use super::FileScanConfig;

@@ -43,10 +44,8 @@ use futures::{StreamExt, TryStreamExt};
 use object_store::local::LocalFileSystem;
 use object_store::{GetOptions, GetResult, ObjectStore};
 use std::any::Any;
-use std::fs;
 use std::io::Cursor;
 use std::ops::Range;
-use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
 use tokio::task::JoinSet;
@@ -566,30 +565,37 @@ pub async fn plan_to_csv(
     path: impl AsRef<str>,
 ) -> Result<()> {
     let path = path.as_ref();
-    // create directory to contain the CSV files (one per partition)
-    let fs_path = Path::new(path);
-    if let Err(e) = fs::create_dir(fs_path) {
-        return Err(DataFusionError::Execution(format!(
-            "Could not create directory {path}: {e:?}"
-        )));
-    }
-
+    let parsed = ListingTableUrl::parse(path)?;
+    let object_store_url = parsed.object_store();
+    let store = task_ctx.runtime_env().object_store(&object_store_url)?;
     let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
-        let plan = plan.clone();
-        let filename = format!("part-{i}.csv");
-        let path = fs_path.join(filename);
-        let file = fs::File::create(path)?;
-        let mut writer = csv::Writer::new(file);
-        let stream = plan.execute(i, task_ctx.clone())?;
+        let storeref = store.clone();
+        let plan: Arc<dyn ExecutionPlan> = plan.clone();
+        let filename = format!("{}/part-{i}.csv", parsed.prefix());
+        let file = object_store::path::Path::parse(filename)?;

+        let mut stream = plan.execute(i, task_ctx.clone())?;
         join_set.spawn(async move {
-            let result: Result<()> = stream
-                .map(|batch| writer.write(&batch?))
-                .try_collect()
+            let (_, mut multipart_writer) = storeref.put_multipart(&file).await?;
+            let mut buffer = Vec::with_capacity(1024);
+            //only write headers on first iteration
+            let mut write_headers = true;
+            while let Some(batch) = stream.next().await.transpose()? {
+                let mut writer = csv::WriterBuilder::new()
+                    .has_headers(write_headers)
+                    .build(buffer);
+                writer.write(&batch)?;
+                buffer = writer.into_inner();
+                multipart_writer.write_all(&buffer).await?;
+                buffer.clear();
+                //prevent writing headers more than once
+                write_headers = false;
+            }
+            multipart_writer
+                .shutdown()
                 .await
-                .map_err(DataFusionError::from);
-            result
+                .map_err(DataFusionError::from)
         });
     }

@@ -1033,14 +1039,20 @@
     #[tokio::test]
     async fn write_csv_results_error_handling() -> Result<()> {
         let ctx = SessionContext::new();
+
+        // register a local file system object store
+        let tmp_dir = TempDir::new()?;
+        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
+        let local_url = Url::parse("file://local").unwrap();
+        ctx.runtime_env().register_object_store(&local_url, local);
         let options = CsvReadOptions::default()
             .schema_infer_max_records(2)
             .has_header(true);
         let df = ctx.read_csv("tests/data/corrupt.csv", options).await?;
-        let tmp_dir = TempDir::new()?;
-        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
+
+        let out_dir_url = "file://local/out";
         let e = df
-            .write_csv(&out_dir)
+            .write_csv(out_dir_url)
             .await
             .expect_err("should fail because input file does not match inferred schema");
         assert_eq!("Arrow error: Parser error: Error while parsing value d for column 0 at line 4", format!("{e}"));
@@ -1064,10 +1076,18 @@
         )
         .await?;

+        // register a local file system object store
+        let tmp_dir = TempDir::new()?;
+        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
+        let local_url = Url::parse("file://local").unwrap();
+
+        ctx.runtime_env().register_object_store(&local_url, local);
+
         // execute a simple query and write the results to CSV
         let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
+        let out_dir_url = "file://local/out";
         let df = ctx.sql("SELECT c1, c2 FROM test").await?;
-        df.write_csv(&out_dir).await?;
+        df.write_csv(out_dir_url).await?;

         // create a new context and verify that the results were saved to a partitioned csv file
         let ctx = SessionContext::new();
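Beyond the storage change, the CSV path has one extra wrinkle: the arrow CSV writer owns its output while writing, so the new code hands one reusable Vec<u8> to a short-lived writer per batch and takes it back with into_inner(), emitting headers only for the first batch. Below is a standalone sketch of just that buffer/header dance, not code from this commit: the batches and the print are made up, and has_headers reflects the arrow-rs builder method the diff above uses.

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::csv;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use arrow::record_batch::RecordBatch;

    fn main() -> Result<(), ArrowError> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batches = vec![
            RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
            )?,
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![3])) as ArrayRef])?,
        ];

        // one buffer is moved into each writer and recovered with into_inner(),
        // so only the first batch gets a header row
        let mut buffer = Vec::with_capacity(1024);
        let mut write_headers = true;
        for batch in &batches {
            let mut writer = csv::WriterBuilder::new()
                .has_headers(write_headers)
                .build(buffer);
            writer.write(batch)?;
            buffer = writer.into_inner();
            // plan_to_csv streams `buffer` into the multipart writer here; we just print it
            print!("{}", String::from_utf8_lossy(&buffer));
            buffer.clear();
            write_headers = false;
        }
        Ok(())
    }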

datafusion/core/src/datasource/physical_plan/json.rs

Lines changed: 40 additions & 26 deletions
@@ -17,6 +17,7 @@

 //! Execution plan for reading line-delimited JSON files
 use crate::datasource::file_format::file_type::FileCompressionType;
+use crate::datasource::listing::ListingTableUrl;
 use crate::datasource::physical_plan::file_stream::{
     FileOpenFuture, FileOpener, FileStream,
 };
@@ -36,13 +37,13 @@ use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties};

 use bytes::{Buf, Bytes};
 use futures::{ready, stream, StreamExt, TryStreamExt};
+use object_store;
 use object_store::{GetResult, ObjectStore};
 use std::any::Any;
-use std::fs;
 use std::io::BufReader;
-use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
+use tokio::io::AsyncWriteExt;
 use tokio::task::JoinSet;

 use super::FileScanConfig;
@@ -259,29 +260,31 @@ pub async fn plan_to_json(
     path: impl AsRef<str>,
 ) -> Result<()> {
     let path = path.as_ref();
-    // create directory to contain the CSV files (one per partition)
-    let fs_path = Path::new(path);
-    if let Err(e) = fs::create_dir(fs_path) {
-        return Err(DataFusionError::Execution(format!(
-            "Could not create directory {path}: {e:?}"
-        )));
-    }
-
+    let parsed = ListingTableUrl::parse(path)?;
+    let object_store_url = parsed.object_store();
+    let store = task_ctx.runtime_env().object_store(&object_store_url)?;
     let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
-        let plan = plan.clone();
-        let filename = format!("part-{i}.json");
-        let path = fs_path.join(filename);
-        let file = fs::File::create(path)?;
-        let mut writer = json::LineDelimitedWriter::new(file);
-        let stream = plan.execute(i, task_ctx.clone())?;
+        let storeref = store.clone();
+        let plan: Arc<dyn ExecutionPlan> = plan.clone();
+        let filename = format!("{}/part-{i}.json", parsed.prefix());
+        let file = object_store::path::Path::parse(filename)?;
+
+        let mut stream = plan.execute(i, task_ctx.clone())?;
         join_set.spawn(async move {
-            let result: Result<()> = stream
-                .map(|batch| writer.write(&batch?))
-                .try_collect()
+            let (_, mut multipart_writer) = storeref.put_multipart(&file).await?;
+            let mut buffer = Vec::with_capacity(1024);
+            while let Some(batch) = stream.next().await.transpose()? {
+                let mut writer = json::LineDelimitedWriter::new(buffer);
+                writer.write(&batch)?;
+                buffer = writer.into_inner();
+                multipart_writer.write_all(&buffer).await?;
+                buffer.clear();
+            }
+            multipart_writer
+                .shutdown()
                 .await
-                .map_err(DataFusionError::from);
-            result
+                .map_err(DataFusionError::from)
         });
     }

@@ -320,6 +323,7 @@ mod tests {
     use crate::test::partitioned_file_groups;
     use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array};
     use rstest::*;
+    use std::path::Path;
     use tempfile::TempDir;
     use url::Url;

@@ -649,7 +653,6 @@ mod tests {
     #[tokio::test]
     async fn write_json_results() -> Result<()> {
         // create partitioned input file and context
-        let tmp_dir = TempDir::new()?;
         let ctx =
             SessionContext::with_config(SessionConfig::new().with_target_partitions(8));

@@ -659,10 +662,17 @@
         ctx.register_json("test", path.as_str(), NdJsonReadOptions::default())
             .await?;

+        // register a local file system object store for /tmp directory
+        let tmp_dir = TempDir::new()?;
+        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
+        let local_url = Url::parse("file://local").unwrap();
+        ctx.runtime_env().register_object_store(&local_url, local);
+
         // execute a simple query and write the results to CSV
         let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
+        let out_dir_url = "file://local/out";
         let df = ctx.sql("SELECT a, b FROM test").await?;
-        df.write_json(&out_dir).await?;
+        df.write_json(out_dir_url).await?;

         // create a new context and verify that the results were saved to a partitioned csv file
         let ctx = SessionContext::new();
@@ -720,14 +730,18 @@
     #[tokio::test]
     async fn write_json_results_error_handling() -> Result<()> {
         let ctx = SessionContext::new();
+        // register a local file system object store for /tmp directory
+        let tmp_dir = TempDir::new()?;
+        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
+        let local_url = Url::parse("file://local").unwrap();
+        ctx.runtime_env().register_object_store(&local_url, local);
         let options = CsvReadOptions::default()
             .schema_infer_max_records(2)
             .has_header(true);
         let df = ctx.read_csv("tests/data/corrupt.csv", options).await?;
-        let tmp_dir = TempDir::new()?;
-        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
+        let out_dir_url = "file://local/out";
         let e = df
-            .write_json(&out_dir)
+            .write_json(out_dir_url)
             .await
             .expect_err("should fail because input file does not match inferred schema");
         assert_eq!("Arrow error: Parser error: Error while parsing value d for column 0 at line 4", format!("{e}"));

0 commit comments
