feat: implement zlib-stream compression

This also changes how gateway messages work.
Now each gateway backend converts its message into an
intermediary RawGatewayMessage, from which we inflate
and parse GatewayMessages.

Thanks to ByteAlex and their zlib-stream-rs crate, which
helped me understand how to parse a compressed websocket stream
This commit is contained in:
kozabrada123 2024-06-16 08:01:02 +02:00
parent 007fa657ba
commit 2887794b0e
8 changed files with 219 additions and 25 deletions

20
Cargo.lock generated
View File

@ -231,6 +231,7 @@ dependencies = [
"crypto_secretbox",
"custom_error",
"discortp",
"flate2",
"futures-util",
"getrandom",
"hostname",
@ -355,6 +356,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5"
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
@ -555,6 +565,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6"
[[package]]
name = "flate2"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "flume"
version = "0.11.0"

View File

@ -16,7 +16,7 @@ default = ["client", "rt-multi-thread"]
backend = ["poem", "sqlx"]
rt-multi-thread = ["tokio/rt-multi-thread"]
rt = ["tokio/rt"]
client = []
client = ["flate2"]
voice = ["voice_udp", "voice_gateway"]
voice_udp = ["dep:discortp", "dep:crypto_secretbox"]
voice_gateway = []
@ -56,6 +56,7 @@ sqlx = { version = "0.7.3", features = [
discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] }
crypto_secretbox = { version = "0.1.1", optional = true }
rand = "0.8.5"
flate2 = { version = "1.0.30", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
rustls = "0.21.10"

View File

@ -12,7 +12,7 @@ use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
};
use crate::gateway::GatewayMessage;
use crate::gateway::{GatewayMessage, RawGatewayMessage};
#[derive(Debug, Clone)]
pub struct TungsteniteBackend;
@ -80,3 +80,22 @@ impl From<tungstenite::Message> for GatewayMessage {
Self(value.to_string())
}
}
impl From<RawGatewayMessage> for tungstenite::Message {
fn from(message: RawGatewayMessage) -> Self {
match message {
RawGatewayMessage::Text(text) => tungstenite::Message::Text(text),
RawGatewayMessage::Bytes(bytes) => tungstenite::Message::Binary(bytes),
}
}
}
impl From<tungstenite::Message> for RawGatewayMessage {
fn from(value: tungstenite::Message) -> Self {
match value {
tungstenite::Message::Binary(bytes) => RawGatewayMessage::Bytes(bytes),
tungstenite::Message::Text(text) => RawGatewayMessage::Text(text),
_ => RawGatewayMessage::Text(value.to_string()),
}
}
}

View File

@ -4,6 +4,7 @@
use std::time::Duration;
use flate2::Decompress;
use futures_util::{SinkExt, StreamExt};
use log::*;
#[cfg(not(target_arch = "wasm32"))]
@ -19,6 +20,9 @@ use crate::types::{
WebSocketEvent,
};
/// Tells us we have received enough of the buffer to decompress it
const ZLIB_SUFFIX: [u8; 4] = [0, 0, 255, 255];
#[derive(Debug)]
pub struct Gateway {
events: Arc<Mutex<Events>>,
@ -28,21 +32,36 @@ pub struct Gateway {
kill_send: tokio::sync::broadcast::Sender<()>,
kill_receive: tokio::sync::broadcast::Receiver<()>,
store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
/// Url which was used to initialize the gateway
url: String,
/// Options which were used to initialize the gateway
options: GatewayOptions,
zlib_inflate: Option<flate2::Decompress>,
zlib_buffer: Option<Vec<u8>>,
}
impl Gateway {
#[allow(clippy::new_ret_no_self)]
pub async fn spawn(websocket_url: String) -> Result<GatewayHandle, GatewayError> {
let (websocket_send, mut websocket_receive) =
match WebSocketBackend::connect(&websocket_url).await {
Ok(streams) => streams,
Err(e) => {
return Err(GatewayError::CannotConnect {
error: format!("{:?}", e),
})
}
};
/// Creates / opens a new gateway connection.
///
/// # Note
/// The websocket url should begin with the prefix wss:// or ws:// (for unsecure connections)
pub async fn spawn(
websocket_url: String,
options: GatewayOptions,
) -> Result<GatewayHandle, GatewayError> {
let url = options.add_to_url(websocket_url);
debug!("GW: Connecting to {}", url);
let (websocket_send, mut websocket_receive) = match WebSocketBackend::connect(&url).await {
Ok(streams) => streams,
Err(e) => {
return Err(GatewayError::CannotConnect {
error: format!("{:?}", e),
});
}
};
let shared_websocket_send = Arc::new(Mutex::new(websocket_send));
@ -52,10 +71,32 @@ impl Gateway {
// Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread
#[cfg(not(target_arch = "wasm32"))]
let msg: GatewayMessage = websocket_receive.next().await.unwrap().unwrap().into();
let received: RawGatewayMessage = websocket_receive.next().await.unwrap().unwrap().into();
#[cfg(target_arch = "wasm32")]
let msg: GatewayMessage = websocket_receive.next().await.unwrap().into();
let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&msg.0).unwrap();
let received: RawGatewayMessage = websocket_receive.next().await.unwrap().into();
let message: GatewayMessage;
let zlib_buffer;
let zlib_inflate;
match options.transport_compression {
GatewayTransportCompression::None => {
zlib_buffer = None;
zlib_inflate = None;
message = GatewayMessage::from_raw_json_message(received).unwrap();
}
GatewayTransportCompression::ZLibStream => {
zlib_buffer = Some(Vec::new());
let mut inflate = Decompress::new(true);
message = GatewayMessage::from_zlib_stream_json_message(received, &mut inflate).unwrap();
zlib_inflate = Some(inflate);
}
}
let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&message.0).unwrap();
if gateway_payload.op_code != GATEWAY_HELLO {
return Err(GatewayError::NonHelloOnInitiate {
@ -85,7 +126,10 @@ impl Gateway {
kill_send: kill_send.clone(),
kill_receive: kill_send.subscribe(),
store: store.clone(),
url: websocket_url.clone(),
url: url.clone(),
options,
zlib_inflate,
zlib_buffer,
};
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello
@ -99,7 +143,7 @@ impl Gateway {
});
Ok(GatewayHandle {
url: websocket_url.clone(),
url: url.clone(),
events: shared_events,
websocket_send: shared_websocket_send.clone(),
kill_send: kill_send.clone(),
@ -108,7 +152,7 @@ impl Gateway {
}
/// The main gateway listener task;
pub async fn gateway_listen_task(&mut self) {
async fn gateway_listen_task(&mut self) {
loop {
let msg;
@ -125,12 +169,12 @@ impl Gateway {
// PRETTYFYME: Remove inline conditional compiling
#[cfg(not(target_arch = "wasm32"))]
if let Some(Ok(message)) = msg {
self.handle_message(message.into()).await;
self.handle_raw_message(message.into()).await;
continue;
}
#[cfg(target_arch = "wasm32")]
if let Some(message) = msg {
self.handle_message(message.into()).await;
self.handle_raw_message(message.into()).await;
continue;
}
@ -163,8 +207,41 @@ impl Gateway {
Ok(())
}
/// Takes a [RawGatewayMessage], converts it to [GatewayMessage] based
/// of connection options and calls handle_message
async fn handle_raw_message(&mut self, raw_message: RawGatewayMessage) {
let message;
match self.options.transport_compression {
GatewayTransportCompression::None => {
message = GatewayMessage::from_raw_json_message(raw_message).unwrap()
}
GatewayTransportCompression::ZLibStream => {
let message_bytes = raw_message.to_bytes();
let can_decompress = message_bytes.len() > 4
&& message_bytes[message_bytes.len() - 4..] == ZLIB_SUFFIX;
let zlib_buffer = self.zlib_buffer.as_mut().unwrap();
zlib_buffer.extend(message_bytes.clone());
if !can_decompress {
return;
}
let zlib_buffer = self.zlib_buffer.as_ref().unwrap();
let inflate = self.zlib_inflate.as_mut().unwrap();
message = GatewayMessage::from_zlib_stream_json_bytes(zlib_buffer, inflate).unwrap();
self.zlib_buffer = Some(Vec::new());
}
};
self.handle_message(message).await;
}
/// This handles a message as a websocket event and updates its events along with the events' observers
pub async fn handle_message(&mut self, msg: GatewayMessage) {
async fn handle_message(&mut self, msg: GatewayMessage) {
if msg.0.is_empty() {
return;
}

View File

@ -2,11 +2,41 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
use std::string::FromUtf8Error;
use crate::types;
use super::*;
/// Represents a message received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// Defines a raw gateway message, being either string json or bytes
///
/// This is used as an intermediary type between types from different websocket implementations
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum RawGatewayMessage {
Text(String),
Bytes(Vec<u8>),
}
impl RawGatewayMessage {
/// Attempt to consume the message into a String, will try to convert binary to utf8
pub fn to_text(self) -> Result<String, FromUtf8Error> {
match self {
RawGatewayMessage::Text(text) => Ok(text),
RawGatewayMessage::Bytes(bytes) => String::from_utf8(bytes),
}
}
/// Consume the message into bytes, will convert text to binary
pub fn to_bytes(self) -> Vec<u8> {
match self {
RawGatewayMessage::Text(text) => text.as_bytes().to_vec(),
RawGatewayMessage::Bytes(bytes) => bytes,
}
}
}
/// Represents a json message received from the gateway.
/// This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages.
#[derive(Clone, Debug)]
pub struct GatewayMessage(pub String);
@ -44,4 +74,37 @@ impl GatewayMessage {
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
serde_json::from_str(&self.0)
}
/// Create self from an uncompressed json [RawGatewayMessage]
pub(crate) fn from_raw_json_message(
message: RawGatewayMessage,
) -> Result<GatewayMessage, FromUtf8Error> {
let text = message.to_text()?;
Ok(GatewayMessage(text))
}
/// Attempt to create self by decompressing zlib-stream bytes
// Thanks to <https://github.com/ByteAlex/zlib-stream-rs>, their
// code helped a lot with the stream implementation
pub(crate) fn from_zlib_stream_json_bytes(
bytes: &[u8],
inflate: &mut flate2::Decompress,
) -> Result<GatewayMessage, std::io::Error> {
let mut output = Vec::with_capacity(bytes.len() * 10);
let _status = inflate.decompress_vec(bytes, &mut output, flate2::FlushDecompress::Sync)?;
output.shrink_to_fit();
let string = String::from_utf8(output).unwrap();
Ok(GatewayMessage(string))
}
/// Attempt to create self by decompressing a zlib-stream bytes raw message
pub(crate) fn from_zlib_stream_json_message(
message: RawGatewayMessage,
inflate: &mut flate2::Decompress,
) -> Result<GatewayMessage, std::io::Error> {
Self::from_zlib_stream_json_bytes(&message.to_bytes(), inflate)
}
}

View File

@ -10,12 +10,14 @@ pub mod gateway;
pub mod handle;
pub mod heartbeat;
pub mod message;
pub mod options;
pub use backends::*;
pub use gateway::*;
pub use handle::*;
use heartbeat::*;
pub use message::*;
pub use options::*;
use crate::errors::GatewayError;
use crate::types::{Snowflake, WebSocketEvent};

View File

@ -13,7 +13,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::errors::ChorusResult;
use crate::gateway::{Gateway, GatewayHandle};
use crate::gateway::{Gateway, GatewayHandle, GatewayOptions};
use crate::ratelimiter::ChorusRequest;
use crate::types::types::subconfigs::limits::rates::RateLimits;
use crate::types::{
@ -31,6 +31,8 @@ pub struct Instance {
pub limits_information: Option<LimitsInformation>,
#[serde(skip)]
pub client: Client,
#[serde(skip)]
pub gateway_options: GatewayOptions,
}
impl PartialEq for Instance {
@ -104,6 +106,7 @@ impl Instance {
instance_info: GeneralConfiguration::default(),
limits_information: limit_information,
client: Client::new(),
gateway_options: GatewayOptions::default(),
};
instance.instance_info = match instance.general_configuration_schema().await {
Ok(schema) => schema,
@ -139,6 +142,13 @@ impl Instance {
Err(_) => Ok(None),
}
}
/// Sets the [`GatewayOptions`] the instance will use when spawning new connections.
///
/// These options are used on the gateways created when logging in and registering.
pub fn set_gateway_options(&mut self, options: GatewayOptions) {
self.gateway_options = options;
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -215,7 +225,9 @@ impl ChorusUser {
let object = Arc::new(RwLock::new(User::default()));
let wss_url = instance.read().unwrap().urls.wss.clone();
// Dummy gateway object
let gateway = Gateway::spawn(wss_url).await.unwrap();
let gateway = Gateway::spawn(wss_url, GatewayOptions::default())
.await
.unwrap();
ChorusUser {
token,
belongs_to: instance.clone(),

View File

@ -44,7 +44,7 @@ impl VoiceGateway {
pub async fn spawn(websocket_url: String) -> Result<VoiceGatewayHandle, VoiceGatewayError> {
// Append the needed things to the websocket url
let processed_url = format!("wss://{}/?v=7", websocket_url);
trace!("Created voice socket url: {}", processed_url.clone());
trace!("VGW: Connecting to {}", processed_url.clone());
let (websocket_send, mut websocket_receive) =
match WebSocketBackend::connect(&processed_url).await {