Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ message = "Bump [MSRV](https://github.com/awslabs/aws-sdk-rust#supported-rust-ve
references = ["smithy-rs#1318"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "Velfi"

[[smithy-rs]]
message = "Add new trait for HTTP body callbacks. This is the first step to enabling us to implement optional checksum verification of requests and responses."
references = ["smithy-rs#1307"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "Velfi"
23 changes: 20 additions & 3 deletions rust-runtime/aws-smithy-http/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[package]
name = "aws-smithy-http"
version = "0.0.0-smithy-rs-head"
authors = ["AWS Rust SDK Team <[email protected]>", "Russell Cohen <[email protected]>"]
authors = [
"AWS Rust SDK Team <[email protected]>",
"Russell Cohen <[email protected]>",
]
description = "Smithy HTTP logic for smithy-rs."
edition = "2021"
license = "Apache-2.0"
Expand Down Expand Up @@ -29,16 +32,30 @@ hyper = "0.14"
# ByteStream internals
futures-core = "0.3.14"
tokio = { version = "1.6", optional = true }
tokio-util = { version = "0.6", optional = true}
tokio-util = { version = "0.6", optional = true }

# Checksum algos
crc32fast = "1.3"
crc32c = "0.6"
sha1 = "0.10"
sha2 = "0.10"

[dev-dependencies]
async-stream = "0.3"
futures-util = "0.3"
hyper = { version = "0.14", features = ["stream"] }
pretty_assertions = "1.2"
proptest = "1"
tokio = {version = "1.6", features = ["macros", "rt", "rt-multi-thread", "fs", "io-util"]}
tokio = { version = "1.6", features = [
"macros",
"rt",
"rt-multi-thread",
"fs",
"io-util",
] }
tokio-stream = "0.1.5"
tempfile = "3.2.0"
tracing-test = "0.2.1"

[package.metadata.docs.rs]
all-features = true
Expand Down
75 changes: 70 additions & 5 deletions rust-runtime/aws-smithy-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::callback::BodyCallback;
use crate::header::append_merge_header_maps;

pub type Error = Box<dyn StdError + Send + Sync>;

/// SdkBody type
Expand All @@ -32,6 +35,9 @@ pub struct SdkBody {
/// In the event of retry, this function will be called to generate a new body. See
/// [`try_clone()`](SdkBody::try_clone)
rebuild: Option<Arc<dyn (Fn() -> Inner) + Send + Sync>>,
/// A list of callbacks that will be called at various points of this `SdkBody`'s lifecycle
#[pin]
callbacks: Vec<Box<dyn BodyCallback>>,
}

impl Debug for SdkBody {
Expand Down Expand Up @@ -74,6 +80,7 @@ impl SdkBody {
Self {
inner: Inner::Dyn(body),
rebuild: None,
callbacks: Vec::new(),
}
}

Expand All @@ -90,28 +97,32 @@ impl SdkBody {
SdkBody {
inner: initial.inner,
rebuild: Some(Arc::new(move || f().inner)),
callbacks: Vec::new(),
}
}

pub fn taken() -> Self {
Self {
inner: Inner::Taken,
rebuild: None,
callbacks: Vec::new(),
}
}

pub fn empty() -> Self {
Self {
inner: Inner::Once(None),
rebuild: Some(Arc::new(|| Inner::Once(None))),
callbacks: Vec::new(),
}
}

fn poll_inner(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Error>>> {
match self.project().inner.project() {
let mut this = self.project();
let polling_result = match this.inner.project() {
InnerProj::Once(ref mut opt) => {
let data = opt.take();
match data {
Expand All @@ -125,7 +136,29 @@ impl SdkBody {
InnerProj::Taken => {
Poll::Ready(Some(Err("A `Taken` body should never be polled".into())))
}
};

match &polling_result {
// When we get some bytes back from polling, pass those bytes to each callback in turn
Poll::Ready(Some(Ok(bytes))) => {
for callback in this.callbacks.iter_mut() {
// Callbacks can run into errors when reading bytes. They'll be surfaced here
callback.update(bytes)?;
}
}
// When we're done polling for bytes, run each callback's `trailers()` method. If any calls to
// `trailers()` return an error, propagate that error up. Otherwise, continue.
Poll::Ready(None) => {
for callback_result in this.callbacks.iter().map(BodyCallback::trailers) {
if let Err(e) = callback_result {
return Poll::Ready(Some(Err(e)));
}
}
}
_ => (),
}

polling_result
}

/// If possible, return a reference to this body as `&[u8]`
Expand All @@ -143,16 +176,24 @@ impl SdkBody {
pub fn try_clone(&self) -> Option<Self> {
self.rebuild.as_ref().map(|rebuild| {
let next = rebuild();
SdkBody {
let callbacks = self.callbacks.iter().map(BodyCallback::make_new).collect();

Self {
inner: next,
rebuild: self.rebuild.clone(),
callbacks,
}
})
}

pub fn content_length(&self) -> Option<u64> {
self.size_hint().exact()
}

pub fn with_callback(&mut self, callback: Box<dyn BodyCallback>) -> &mut Self {
self.callbacks.push(callback);
self
}
}

impl From<&str> for SdkBody {
Expand All @@ -166,6 +207,7 @@ impl From<Bytes> for SdkBody {
SdkBody {
inner: Inner::Once(Some(bytes.clone())),
rebuild: Some(Arc::new(move || Inner::Once(Some(bytes.clone())))),
callbacks: Vec::new(),
}
}
}
Expand All @@ -175,6 +217,7 @@ impl From<hyper::Body> for SdkBody {
SdkBody {
inner: Inner::Streaming(body),
rebuild: None,
callbacks: Vec::new(),
}
}
}
Expand Down Expand Up @@ -212,7 +255,30 @@ impl http_body::Body for SdkBody {
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
Poll::Ready(Ok(None))
let mut callback_errors = Vec::new();
let header_map = self
.callbacks
.iter()
.filter_map(|callback| {
match callback.trailers() {
Ok(optional_header_map) => optional_header_map,
// early return if a callback encountered an error
Err(e) => {
callback_errors.push(e);

None
}
}
})
// Merge any `HeaderMap`s from the last step together, one by one.
.reduce(append_merge_header_maps);

if callback_errors.is_empty() {
Poll::Ready(Ok(header_map))
} else {
// TODO What's the most useful way to surface multiple errors?
Poll::Ready(Err(callback_errors.pop().unwrap()))
}
}

fn is_end_stream(&self) -> bool {
Expand Down Expand Up @@ -296,10 +362,9 @@ mod test {
let _ = format!("{:?}", body);
}

fn is_send<T: Send + Sync>() {}

#[test]
fn sdk_body_is_send() {
fn is_send<T: Send>() {}
is_send::<SdkBody>()
}
}
21 changes: 19 additions & 2 deletions rust-runtime/aws-smithy-http/src/byte_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
//! ```

use crate::body::SdkBody;
use crate::callback::BodyCallback;
use bytes::Buf;
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
Expand Down Expand Up @@ -293,6 +294,14 @@ impl ByteStream {
));
Ok(ByteStream::new(body))
}

/// Set a callback on this `ByteStream`. The callback's methods will be called at various points
/// throughout this `ByteStream`'s life cycle. See the [`BodyCallback`](BodyCallback) trait for
/// more information.
pub fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
self.0.with_body_callback(body_callback);
self
}
}

impl Default for ByteStream {
Expand Down Expand Up @@ -416,10 +425,11 @@ struct Inner<B> {
}

impl<B> Inner<B> {
pub fn new(body: B) -> Self {
fn new(body: B) -> Self {
Self { body }
}
pub async fn collect(self) -> Result<AggregatedBytes, B::Error>

async fn collect(self) -> Result<AggregatedBytes, B::Error>
where
B: http_body::Body<Data = Bytes>,
{
Expand All @@ -433,6 +443,13 @@ impl<B> Inner<B> {
}
}

impl Inner<SdkBody> {
fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
self.body.with_callback(body_callback);
self
}
}

impl<B> futures_core::stream::Stream for Inner<B>
where
B: http_body::Body,
Expand Down
Loading