Skip to content

Remove anyhow and use pure thiserror impl #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ keywords = ["minecraft", "mc", "serverlistping"]
categories = ["asynchronous", "api-bindings"]

[dependencies]
anyhow = "1.0"
async-trait = "0.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand Down
86 changes: 56 additions & 30 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,33 @@

use std::io::Cursor;

use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

#[derive(Error, Debug)]
pub enum ProtocolError {
#[error("error reading or writing data")]
IoError,
#[error(transparent)]
Generic(#[from] ProtocolErrorKind),
#[error("{}: {}",context, source)]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you run rustfmt on this?

WithContext {
context: &'static str,
#[source]
source: ProtocolErrorKind
}
}

impl ProtocolError {
fn with_context<T: Into<ProtocolErrorKind>>(e: T, context: &'static str) -> Self {
let v: ProtocolErrorKind = e.into();
v.context(context)
}
}

#[derive(Error, Debug)]
pub enum ProtocolErrorKind {
#[error("reading or writing data")]
IoError(#[from] std::io::Error),

#[error("invalid varint data")]
InvalidVarInt,
Expand All @@ -25,12 +43,20 @@ pub enum ProtocolError {
InvalidResponseBody,
}

impl From<std::io::Error> for ProtocolError {
fn from(_err: std::io::Error) -> Self {
ProtocolError::IoError
impl ProtocolErrorKind {
/// Wrap ProtocolErrorKind into ProtocolError with context
fn context(self, context: &'static str) -> ProtocolError {
ProtocolError::WithContext {
source: self,
context,
}
}
}

type Result<T> = std::result::Result<T,ProtocolError>;
// used to simulate an additional context by returning only a detailed Kind without Context to the caller
type ResultInner<T> = std::result::Result<T,ProtocolErrorKind>;

/// State represents the desired next state of the
/// exchange.
///
Expand Down Expand Up @@ -74,13 +100,13 @@ impl RawPacket {
/// string support to things that implement AsyncRead.
#[async_trait]
pub trait AsyncWireReadExt {
async fn read_varint(&mut self) -> Result<usize>;
async fn read_string(&mut self) -> Result<String>;
async fn read_varint(&mut self) -> ResultInner<usize>;
async fn read_string(&mut self) -> ResultInner<String>;
}

#[async_trait]
impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
async fn read_varint(&mut self) -> Result<usize> {
async fn read_varint(&mut self) -> ResultInner<usize> {
let mut read = 0;
let mut result = 0;
loop {
Expand All @@ -89,35 +115,35 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
result |= (value as usize) << (7 * read);
read += 1;
if read > 5 {
bail!(ProtocolError::InvalidVarInt);
return Err(ProtocolErrorKind::InvalidVarInt);
}
if (read_value & 0b1000_0000) == 0 {
return Ok(result);
}
}
}

async fn read_string(&mut self) -> Result<String> {
async fn read_string(&mut self) -> ResultInner<String> {
let length = self.read_varint().await?;

let mut buffer = vec![0; length];
self.read_exact(&mut buffer).await?;

Ok(String::from_utf8(buffer).map_err(|_| ProtocolError::InvalidResponseBody)?)
Ok(String::from_utf8(buffer).map_err(|_| ProtocolErrorKind::InvalidResponseBody)?)
}
}

/// AsyncWireWriteExt adds varint and varint-backed
/// string support to things that implement AsyncWrite.
#[async_trait]
pub trait AsyncWireWriteExt {
async fn write_varint(&mut self, int: usize) -> Result<()>;
async fn write_string(&mut self, string: &str) -> Result<()>;
async fn write_varint(&mut self, int: usize) -> ResultInner<()>;
async fn write_string(&mut self, string: &str) -> ResultInner<()>;
}

#[async_trait]
impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWireWriteExt for W {
async fn write_varint(&mut self, int: usize) -> Result<()> {
async fn write_varint(&mut self, int: usize) -> ResultInner<()> {
let mut int = (int as u64) & 0xFFFF_FFFF;
let mut written = 0;
let mut buffer = [0; 5];
Expand All @@ -139,7 +165,7 @@ impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWireWriteExt for W {
Ok(())
}

async fn write_string(&mut self, string: &str) -> Result<()> {
async fn write_string(&mut self, string: &str) -> ResultInner<()> {
self.write_varint(string.len()).await?;
self.write_all(string.as_bytes()).await?;

Expand Down Expand Up @@ -194,25 +220,25 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
let length = self
.read_varint()
.await
.context("failed to read packet length")?;
.map_err(|e| e.context("failed to read packet length"))?;
let packet_id = self
.read_varint()
.await
.context("failed to read packet ID")?;
.map_err(|e| e.context("failed to read packet ID"))?;

let expected_packet_id = T::get_expected_packet_id();

if packet_id != expected_packet_id {
bail!(ProtocolError::InvalidPacketId {
return Err(ProtocolErrorKind::InvalidPacketId {
expected: expected_packet_id,
actual: packet_id,
});
}.into());
}

let mut buffer = vec![0; length - 1];
self.read_exact(&mut buffer)
.await
.context("failed to read packet body")?;
.map_err(|e|ProtocolError::with_context(e,"failed to read packet body"))?;

T::read_from_buffer(buffer).await
}
Expand Down Expand Up @@ -245,19 +271,19 @@ impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWriteRawPacket for W {
buffer
.write_varint(raw_packet.id)
.await
.context("failed to write packet ID")?;
.map_err(|e|e.context("failed to write packet ID"))?;
buffer
.write_all(&raw_packet.data)
.await
.context("failed to write packet data")?;
.map_err(|e|ProtocolError::with_context(e,"failed to write packet data"))?;

let inner = buffer.into_inner();
self.write_varint(inner.len())
.await
.context("failed to write packet length")?;
.map_err(|e|e.context("failed to write packet length"))?;
self.write(&inner)
.await
.context("failed to write constructed packet buffer")?;
.map_err(|e|ProtocolError::with_context(e,"failed to write constructed packet buffer"))?;
Ok(())
}
}
Expand Down Expand Up @@ -293,19 +319,19 @@ impl AsyncWriteToBuffer for HandshakePacket {
buffer
.write_varint(self.protocol_version)
.await
.context("failed to write protocol version")?;
.map_err(|e|e.context("failed to write protocol version"))?;
buffer
.write_string(&self.server_address)
.await
.context("failed to write server address")?;
.map_err(|e|e.context("failed to write server address"))?;
buffer
.write_u16(self.server_port)
.await
.context("failed to write server port")?;
.map_err(|e|ProtocolError::with_context(e,"failed to write server port"))?;
buffer
.write_varint(self.next_state.into())
.await
.context("failed to write next state")?;
.map_err(|e|e.context("failed to write next state"))?;

Ok(buffer.into_inner())
}
Expand Down Expand Up @@ -365,7 +391,7 @@ impl AsyncReadFromBuffer for ResponsePacket {
let body = reader
.read_string()
.await
.context("failed to read response body")?;
.map_err(|e|e.context("failed to read response body"))?;

Ok(ResponsePacket { packet_id: 0, body })
}
Expand Down
50 changes: 37 additions & 13 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
//! This module defines a wrapper around Minecraft's
//! [ServerListPing](https://wiki.vg/Server_List_Ping)

use anyhow::{Context, Result};
use serde::Deserialize;
use thiserror::Error;
use tokio::net::TcpStream;

use crate::protocol::{self, AsyncReadRawPacket, AsyncWriteRawPacket};

pub type Result<T> = std::result::Result<T, ServerError>;

#[derive(Error, Debug)]
pub enum ServerError {
#[error(transparent)]
Generic(#[from] ServerErrorKind),
#[error("{}: {}",context, source)]
WithContext {
context: &'static str,
#[source]
source: ServerErrorKind
}
}

impl ServerError {
fn with_context<T: Into<ServerErrorKind>>(e: T, context: &'static str) -> Self {
let v: ServerErrorKind = e.into();
v.context(context)
}
}

impl ServerErrorKind {
/// Wrap ProtocolErrorKind into ProtocolError with context
fn context(self, context: &'static str) -> ServerError {
ServerError::WithContext {
source: self,
context,
}
}
}

#[derive(Error, Debug)]
pub enum ServerErrorKind {
#[error("error reading or writing data")]
ProtocolError,
ProtocolError(#[from] protocol::ProtocolError),

#[error("failed to connect to server")]
FailedToConnect,
Expand All @@ -20,12 +50,6 @@ pub enum ServerError {
InvalidJson(String),
}

impl From<protocol::ProtocolError> for ServerError {
fn from(_err: protocol::ProtocolError) -> Self {
ServerError::ProtocolError
}
}

/// Contains information about the server version.
#[derive(Debug, Deserialize)]
pub struct ServerVersion {
Expand Down Expand Up @@ -131,7 +155,7 @@ impl ConnectionConfig {
pub async fn connect(self) -> Result<StatusConnection> {
let stream = TcpStream::connect(format!("{}:{}", self.address, self.port))
.await
.map_err(|_| ServerError::FailedToConnect)?;
.map_err(|_| ServerErrorKind::FailedToConnect)?;

Ok(StatusConnection {
stream,
Expand Down Expand Up @@ -170,20 +194,20 @@ impl StatusConnection {
self.stream
.write_packet(handshake)
.await
.context("failed to write handshake packet")?;
.map_err(|e|ServerError::with_context(e, "failed to write handshake packet"))?;

self.stream
.write_packet(protocol::RequestPacket::new())
.await
.context("failed to write request packet")?;
.map_err(|e|ServerError::with_context(e, "failed to write request packet"))?;

let response: protocol::ResponsePacket = self
.stream
.read_packet()
.await
.context("failed to read response packet")?;
.map_err(|e|ServerError::with_context(e, "failed to read response packet"))?;

Ok(serde_json::from_str(&response.body)
.map_err(|_| ServerError::InvalidJson(response.body))?)
.map_err(|_| ServerErrorKind::InvalidJson(response.body))?)
}
}