Implement gateway options, zlib-stream compression (#508)

* feat: add GatewayOptions

* feat: implement zlib-stream compression

This also changes how gateway messages work.
Now each gateway backend converts its message into an
intermediary RawGatewayMessage, from which we inflate
and parse GatewayMessages.

Thanks to ByteAlex and their zlib-stream-rs crate, which
helped me understand how to parse a compressed websocket stream
This commit is contained in:
kozabrada123 2024-06-23 17:23:13 +02:00 committed by GitHub
parent b4a8082f29
commit 89333d6353
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 525 additions and 76 deletions

View File

@ -124,7 +124,7 @@ jobs:
cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.92" --force cargo binstall --no-confirm wasm-bindgen-cli --version "0.2.92" --force
GECKODRIVER=$(which geckodriver) cargo test --target wasm32-unknown-unknown --no-default-features --features="client, rt, voice_gateway" GECKODRIVER=$(which geckodriver) cargo test --target wasm32-unknown-unknown --no-default-features --features="client, rt, voice_gateway"
wasm-chrome: wasm-chrome:
runs-on: macos-latest runs-on: ubuntu-latest
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4

31
Cargo.lock generated
View File

@ -231,6 +231,7 @@ dependencies = [
"crypto_secretbox", "crypto_secretbox",
"custom_error", "custom_error",
"discortp", "discortp",
"flate2",
"futures-util", "futures-util",
"getrandom", "getrandom",
"hostname", "hostname",
@ -250,6 +251,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_repr", "serde_repr",
"serde_with", "serde_with",
"simple_logger",
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
@ -353,6 +355,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5"
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "crossbeam-queue" name = "crossbeam-queue"
version = "0.3.11" version = "0.3.11"
@ -553,6 +564,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6"
[[package]]
name = "flate2"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]] [[package]]
name = "flume" name = "flume"
version = "0.11.0" version = "0.11.0"
@ -2124,6 +2145,16 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "simple_logger"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8c5dfa5e08767553704aa0ffd9d9794d527103c736aba9854773851fd7497eb"
dependencies = [
"log",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"

View File

@ -16,7 +16,7 @@ default = ["client", "rt-multi-thread"]
backend = ["poem", "sqlx"] backend = ["poem", "sqlx"]
rt-multi-thread = ["tokio/rt-multi-thread"] rt-multi-thread = ["tokio/rt-multi-thread"]
rt = ["tokio/rt"] rt = ["tokio/rt"]
client = [] client = ["flate2"]
voice = ["voice_udp", "voice_gateway"] voice = ["voice_udp", "voice_gateway"]
voice_udp = ["dep:discortp", "dep:crypto_secretbox"] voice_udp = ["dep:discortp", "dep:crypto_secretbox"]
voice_gateway = [] voice_gateway = []
@ -56,6 +56,7 @@ sqlx = { version = "0.7.3", features = [
discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] } discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] }
crypto_secretbox = { version = "0.1.1", optional = true } crypto_secretbox = { version = "0.1.1", optional = true }
rand = "0.8.5" rand = "0.8.5"
flate2 = { version = "1.0.30", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
rustls = "0.21.10" rustls = "0.21.10"
@ -78,3 +79,4 @@ wasmtimer = "0.2.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
wasm-bindgen-test = "0.3.42" wasm-bindgen-test = "0.3.42"
wasm-bindgen = "0.2.92" wasm-bindgen = "0.2.92"
simple_logger = { version = "5.0.0", default-features=false }

View File

@ -3,6 +3,8 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
// This example showcase how to properly use gateway observers. // This example showcase how to properly use gateway observers.
// (This assumes you have a manually created gateway, if you created
// a ChorusUser by e.g. logging in, you can access the gateway with user.gateway)
// //
// To properly run it, you will need to change the token below. // To properly run it, you will need to change the token below.
@ -12,7 +14,7 @@ const TOKEN: &str = "";
const GATEWAY_URL: &str = "wss://gateway.old.server.spacebar.chat/"; const GATEWAY_URL: &str = "wss://gateway.old.server.spacebar.chat/";
use async_trait::async_trait; use async_trait::async_trait;
use chorus::gateway::Gateway; use chorus::gateway::{Gateway, GatewayOptions};
use chorus::{ use chorus::{
self, self,
gateway::Observer, gateway::Observer,
@ -47,8 +49,14 @@ impl Observer<GatewayReady> for ExampleObserver {
async fn main() { async fn main() {
let gateway_websocket_url = GATEWAY_URL.to_string(); let gateway_websocket_url = GATEWAY_URL.to_string();
// These options specify the encoding format, compression, etc
//
// For most cases the defaults should work, though some implementations
// might only support some formats or not support compression
let options = GatewayOptions::default();
// Initiate the gateway connection // Initiate the gateway connection
let gateway = Gateway::spawn(gateway_websocket_url).await.unwrap(); let gateway = Gateway::spawn(gateway_websocket_url, options).await.unwrap();
// Create an instance of our observer // Create an instance of our observer
let observer = ExampleObserver {}; let observer = ExampleObserver {};

View File

@ -3,7 +3,7 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
// This example showcases how to initiate a gateway connection manually // This example showcases how to initiate a gateway connection manually
// (e. g. not through ChorusUser) // (e. g. not through ChorusUser or Instance)
// //
// To properly run it, you will need to modify the token below. // To properly run it, you will need to modify the token below.
@ -14,7 +14,7 @@ const GATEWAY_URL: &str = "wss://gateway.old.server.spacebar.chat/";
use std::time::Duration; use std::time::Duration;
use chorus::gateway::Gateway; use chorus::gateway::{Gateway, GatewayOptions};
use chorus::{self, types::GatewayIdentifyPayload}; use chorus::{self, types::GatewayIdentifyPayload};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -27,8 +27,14 @@ use wasmtimer::tokio::sleep;
async fn main() { async fn main() {
let gateway_websocket_url = GATEWAY_URL.to_string(); let gateway_websocket_url = GATEWAY_URL.to_string();
// These options specify the encoding format, compression, etc
//
// For most cases the defaults should work, though some implementations
// might only support some formats or not support compression
let options = GatewayOptions::default();
// Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another // Initiate the gateway connection, starting a listener in one thread and a heartbeat handler in another
let gateway = Gateway::spawn(gateway_websocket_url).await.unwrap(); let gateway = Gateway::spawn(gateway_websocket_url, options).await.unwrap();
// At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated // At this point, we are connected to the server and are sending heartbeats, however we still haven't authenticated

View File

@ -40,12 +40,18 @@ impl ChorusUser {
/// # Reference /// # Reference
/// See <https://discord-userdoccers.vercel.app/resources/user#modify-current-user> /// See <https://discord-userdoccers.vercel.app/resources/user#modify-current-user>
pub async fn modify(&mut self, modify_schema: UserModifySchema) -> ChorusResult<User> { pub async fn modify(&mut self, modify_schema: UserModifySchema) -> ChorusResult<User> {
if modify_schema.new_password.is_some()
// See <https://docs.discord.sex/resources/user#json-params>, note 1
let requires_current_password = modify_schema.username.is_some()
|| modify_schema.discriminator.is_some()
|| modify_schema.email.is_some() || modify_schema.email.is_some()
|| modify_schema.code.is_some() || modify_schema.date_of_birth.is_some()
{ || modify_schema.new_password.is_some();
if requires_current_password && modify_schema.current_password.is_none() {
return Err(ChorusError::PasswordRequired); return Err(ChorusError::PasswordRequired);
} }
let request = Client::new() let request = Client::new()
.patch(format!( .patch(format!(
"{}/users/@me", "{}/users/@me",
@ -132,4 +138,3 @@ impl User {
} }
} }
} }

View File

@ -12,7 +12,7 @@ use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
}; };
use crate::gateway::GatewayMessage; use crate::gateway::{GatewayMessage, RawGatewayMessage};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TungsteniteBackend; pub struct TungsteniteBackend;
@ -80,3 +80,22 @@ impl From<tungstenite::Message> for GatewayMessage {
Self(value.to_string()) Self(value.to_string())
} }
} }
impl From<RawGatewayMessage> for tungstenite::Message {
fn from(message: RawGatewayMessage) -> Self {
match message {
RawGatewayMessage::Text(text) => tungstenite::Message::Text(text),
RawGatewayMessage::Bytes(bytes) => tungstenite::Message::Binary(bytes),
}
}
}
impl From<tungstenite::Message> for RawGatewayMessage {
fn from(value: tungstenite::Message) -> Self {
match value {
tungstenite::Message::Binary(bytes) => RawGatewayMessage::Bytes(bytes),
tungstenite::Message::Text(text) => RawGatewayMessage::Text(text),
_ => RawGatewayMessage::Text(value.to_string()),
}
}
}

View File

@ -9,7 +9,7 @@ use futures_util::{
use ws_stream_wasm::*; use ws_stream_wasm::*;
use crate::gateway::GatewayMessage; use crate::gateway::{GatewayMessage, RawGatewayMessage};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct WasmBackend; pub struct WasmBackend;
@ -46,3 +46,21 @@ impl From<WsMessage> for GatewayMessage {
} }
} }
} }
impl From<RawGatewayMessage> for WsMessage {
fn from(message: RawGatewayMessage) -> Self {
match message {
RawGatewayMessage::Text(text) => WsMessage::Text(text),
RawGatewayMessage::Bytes(bytes) => WsMessage::Binary(bytes),
}
}
}
impl From<WsMessage> for RawGatewayMessage {
fn from(value: WsMessage) -> Self {
match value {
WsMessage::Binary(bytes) => RawGatewayMessage::Bytes(bytes),
WsMessage::Text(text) => RawGatewayMessage::Text(text),
}
}
}

View File

@ -4,6 +4,7 @@
use std::time::Duration; use std::time::Duration;
use flate2::Decompress;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use log::*; use log::*;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -19,6 +20,9 @@ use crate::types::{
WebSocketEvent, WebSocketEvent,
}; };
/// Tells us we have received enough of the buffer to decompress it
const ZLIB_SUFFIX: [u8; 4] = [0, 0, 255, 255];
#[derive(Debug)] #[derive(Debug)]
pub struct Gateway { pub struct Gateway {
events: Arc<Mutex<Events>>, events: Arc<Mutex<Events>>,
@ -28,19 +32,34 @@ pub struct Gateway {
kill_send: tokio::sync::broadcast::Sender<()>, kill_send: tokio::sync::broadcast::Sender<()>,
kill_receive: tokio::sync::broadcast::Receiver<()>, kill_receive: tokio::sync::broadcast::Receiver<()>,
store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>, store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
/// Url which was used to initialize the gateway
url: String, url: String,
/// Options which were used to initialize the gateway
options: GatewayOptions,
zlib_inflate: Option<flate2::Decompress>,
zlib_buffer: Option<Vec<u8>>,
} }
impl Gateway { impl Gateway {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub async fn spawn(websocket_url: String) -> Result<GatewayHandle, GatewayError> { /// Creates / opens a new gateway connection.
let (websocket_send, mut websocket_receive) = ///
match WebSocketBackend::connect(&websocket_url).await { /// # Note
/// The websocket url should begin with the prefix wss:// or ws:// (for unsecure connections)
pub async fn spawn(
websocket_url: String,
options: GatewayOptions,
) -> Result<GatewayHandle, GatewayError> {
let url = options.add_to_url(websocket_url);
debug!("GW: Connecting to {}", url);
let (websocket_send, mut websocket_receive) = match WebSocketBackend::connect(&url).await {
Ok(streams) => streams, Ok(streams) => streams,
Err(e) => { Err(e) => {
return Err(GatewayError::CannotConnect { return Err(GatewayError::CannotConnect {
error: format!("{:?}", e), error: format!("{:?}", e),
}) });
} }
}; };
@ -52,10 +71,32 @@ impl Gateway {
// Wait for the first hello and then spawn both tasks so we avoid nested tasks // Wait for the first hello and then spawn both tasks so we avoid nested tasks
// This automatically spawns the heartbeat task, but from the main thread // This automatically spawns the heartbeat task, but from the main thread
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
let msg: GatewayMessage = websocket_receive.next().await.unwrap().unwrap().into(); let received: RawGatewayMessage = websocket_receive.next().await.unwrap().unwrap().into();
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
let msg: GatewayMessage = websocket_receive.next().await.unwrap().into(); let received: RawGatewayMessage = websocket_receive.next().await.unwrap().into();
let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&msg.0).unwrap();
let message: GatewayMessage;
let zlib_buffer;
let zlib_inflate;
match options.transport_compression {
GatewayTransportCompression::None => {
zlib_buffer = None;
zlib_inflate = None;
message = GatewayMessage::from_raw_json_message(received).unwrap();
}
GatewayTransportCompression::ZLibStream => {
zlib_buffer = Some(Vec::new());
let mut inflate = Decompress::new(true);
message = GatewayMessage::from_zlib_stream_json_message(received, &mut inflate).unwrap();
zlib_inflate = Some(inflate);
}
}
let gateway_payload: types::GatewayReceivePayload = serde_json::from_str(&message.0).unwrap();
if gateway_payload.op_code != GATEWAY_HELLO { if gateway_payload.op_code != GATEWAY_HELLO {
return Err(GatewayError::NonHelloOnInitiate { return Err(GatewayError::NonHelloOnInitiate {
@ -85,7 +126,10 @@ impl Gateway {
kill_send: kill_send.clone(), kill_send: kill_send.clone(),
kill_receive: kill_send.subscribe(), kill_receive: kill_send.subscribe(),
store: store.clone(), store: store.clone(),
url: websocket_url.clone(), url: url.clone(),
options,
zlib_inflate,
zlib_buffer,
}; };
// Now we can continuously check for messages in a different task, since we aren't going to receive another hello // Now we can continuously check for messages in a different task, since we aren't going to receive another hello
@ -99,7 +143,7 @@ impl Gateway {
}); });
Ok(GatewayHandle { Ok(GatewayHandle {
url: websocket_url.clone(), url: url.clone(),
events: shared_events, events: shared_events,
websocket_send: shared_websocket_send.clone(), websocket_send: shared_websocket_send.clone(),
kill_send: kill_send.clone(), kill_send: kill_send.clone(),
@ -108,7 +152,7 @@ impl Gateway {
} }
/// The main gateway listener task; /// The main gateway listener task;
pub async fn gateway_listen_task(&mut self) { async fn gateway_listen_task(&mut self) {
loop { loop {
let msg; let msg;
@ -125,12 +169,12 @@ impl Gateway {
// PRETTYFYME: Remove inline conditional compiling // PRETTYFYME: Remove inline conditional compiling
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
if let Some(Ok(message)) = msg { if let Some(Ok(message)) = msg {
self.handle_message(message.into()).await; self.handle_raw_message(message.into()).await;
continue; continue;
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
if let Some(message) = msg { if let Some(message) = msg {
self.handle_message(message.into()).await; self.handle_raw_message(message.into()).await;
continue; continue;
} }
@ -163,8 +207,41 @@ impl Gateway {
Ok(()) Ok(())
} }
/// Takes a [RawGatewayMessage], converts it to [GatewayMessage] based
/// of connection options and calls handle_message
async fn handle_raw_message(&mut self, raw_message: RawGatewayMessage) {
let message;
match self.options.transport_compression {
GatewayTransportCompression::None => {
message = GatewayMessage::from_raw_json_message(raw_message).unwrap()
}
GatewayTransportCompression::ZLibStream => {
let message_bytes = raw_message.into_bytes();
let can_decompress = message_bytes.len() > 4
&& message_bytes[message_bytes.len() - 4..] == ZLIB_SUFFIX;
let zlib_buffer = self.zlib_buffer.as_mut().unwrap();
zlib_buffer.extend(message_bytes.clone());
if !can_decompress {
return;
}
let zlib_buffer = self.zlib_buffer.as_ref().unwrap();
let inflate = self.zlib_inflate.as_mut().unwrap();
message = GatewayMessage::from_zlib_stream_json_bytes(zlib_buffer, inflate).unwrap();
self.zlib_buffer = Some(Vec::new());
}
};
self.handle_message(message).await;
}
/// This handles a message as a websocket event and updates its events along with the events' observers /// This handles a message as a websocket event and updates its events along with the events' observers
pub async fn handle_message(&mut self, msg: GatewayMessage) { async fn handle_message(&mut self, msg: GatewayMessage) {
if msg.0.is_empty() { if msg.0.is_empty() {
return; return;
} }

View File

@ -2,11 +2,41 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use std::string::FromUtf8Error;
use crate::types; use crate::types;
use super::*; use super::*;
/// Represents a message received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. /// Defines a raw gateway message, being either string json or bytes
///
/// This is used as an intermediary type between types from different websocket implementations
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum RawGatewayMessage {
Text(String),
Bytes(Vec<u8>),
}
impl RawGatewayMessage {
/// Attempt to consume the message into a String, will try to convert binary to utf8
pub fn into_text(self) -> Result<String, FromUtf8Error> {
match self {
RawGatewayMessage::Text(text) => Ok(text),
RawGatewayMessage::Bytes(bytes) => String::from_utf8(bytes),
}
}
/// Consume the message into bytes, will convert text to binary
pub fn into_bytes(self) -> Vec<u8> {
match self {
RawGatewayMessage::Text(text) => text.as_bytes().to_vec(),
RawGatewayMessage::Bytes(bytes) => bytes,
}
}
}
/// Represents a json message received from the gateway.
/// This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages. /// This struct is used internally when handling messages.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GatewayMessage(pub String); pub struct GatewayMessage(pub String);
@ -44,4 +74,41 @@ impl GatewayMessage {
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> { pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
serde_json::from_str(&self.0) serde_json::from_str(&self.0)
} }
/// Create self from an uncompressed json [RawGatewayMessage]
pub(crate) fn from_raw_json_message(
message: RawGatewayMessage,
) -> Result<GatewayMessage, FromUtf8Error> {
let text = message.into_text()?;
Ok(GatewayMessage(text))
}
/// Attempt to create self by decompressing zlib-stream bytes
// Thanks to <https://github.com/ByteAlex/zlib-stream-rs>, their
// code helped a lot with the stream implementation
pub(crate) fn from_zlib_stream_json_bytes(
bytes: &[u8],
inflate: &mut flate2::Decompress,
) -> Result<GatewayMessage, std::io::Error> {
// Note: is there a better way to handle the size of this output buffer?
//
// This used to be 10, I measured it at 11.5, so a safe bet feels like 20
let mut output = Vec::with_capacity(bytes.len() * 20);
let _status = inflate.decompress_vec(bytes, &mut output, flate2::FlushDecompress::Sync)?;
output.shrink_to_fit();
let string = String::from_utf8(output).unwrap();
Ok(GatewayMessage(string))
}
/// Attempt to create self by decompressing a zlib-stream bytes raw message
pub(crate) fn from_zlib_stream_json_message(
message: RawGatewayMessage,
inflate: &mut flate2::Decompress,
) -> Result<GatewayMessage, std::io::Error> {
Self::from_zlib_stream_json_bytes(&message.into_bytes(), inflate)
}
} }

View File

@ -10,12 +10,14 @@ pub mod gateway;
pub mod handle; pub mod handle;
pub mod heartbeat; pub mod heartbeat;
pub mod message; pub mod message;
pub mod options;
pub use backends::*; pub use backends::*;
pub use gateway::*; pub use gateway::*;
pub use handle::*; pub use handle::*;
use heartbeat::*; use heartbeat::*;
pub use message::*; pub use message::*;
pub use options::*;
use crate::errors::GatewayError; use crate::errors::GatewayError;
use crate::types::{Snowflake, WebSocketEvent}; use crate::types::{Snowflake, WebSocketEvent};

118
src/gateway/options.rs Normal file
View File

@ -0,0 +1,118 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
#[derive(Clone, PartialEq, Eq, Ord, PartialOrd, Debug, Default)]
/// Options passed when initializing the gateway connection.
///
/// E.g. compression
///
/// # Note
///
/// Discord allows specifying the api version (v10, v9, ...) as well, but chorus is built upon one
/// main version (v9).
///
/// Similarly, discord also supports etf encoding, while chorus does not (yet).
/// We are looking into supporting it as an option, since it is faster and more lightweight.
///
/// See <https://docs.discord.sex/topics/gateway#connections>
pub struct GatewayOptions {
pub encoding: GatewayEncoding,
pub transport_compression: GatewayTransportCompression,
}
impl GatewayOptions {
/// Adds the options to an existing gateway url
///
/// Returns the new url
pub(crate) fn add_to_url(&self, url: String) -> String {
let mut url = url;
let mut parameters = Vec::with_capacity(2);
let encoding = self.encoding.to_url_parameter();
parameters.push(encoding);
let compression = self.transport_compression.to_url_parameter();
if let Some(some_compression) = compression {
parameters.push(some_compression);
}
let mut has_parameters = url.contains('?') && url.contains('=');
if !has_parameters {
// Insure it ends in a /, so we don't get a 400 error
if !url.ends_with('/') {
url.push('/');
}
// Lets hope that if it already has parameters the person knew to add '/'
}
for parameter in parameters {
if !has_parameters {
url = format!("{}?{}", url, parameter);
has_parameters = true;
}
else {
url = format!("{}&{}", url, parameter);
}
}
url
}
}
#[derive(Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Debug, Default)]
/// Possible transport compression options for the gateway.
///
/// See <https://docs.discord.sex/topics/gateway#transport-compression>
pub enum GatewayTransportCompression {
/// Do not transport compress packets
None,
/// Transport compress using zlib stream
#[default]
ZLibStream,
}
impl GatewayTransportCompression {
/// Returns the option as a url parameter.
///
/// If set to [GatewayTransportCompression::None] returns [None].
///
/// If set to anything else, returns a string like "compress=zlib-stream"
pub(crate) fn to_url_parameter(self) -> Option<String> {
match self {
Self::None => None,
Self::ZLibStream => Some(String::from("compress=zlib-stream"))
}
}
}
#[derive(Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Debug, Default)]
/// See <https://docs.discord.sex/topics/gateway#encoding-and-compression>
pub enum GatewayEncoding {
/// Javascript object notation, a standard for websocket connections,
/// but contains a lot of overhead
#[default]
Json,
/// A binary format originating from Erlang
///
/// Should be lighter and faster than json.
///
/// !! Chorus does not implement ETF yet !!
ETF
}
impl GatewayEncoding {
/// Returns the option as a url parameter.
///
/// Returns a string like "encoding=json"
pub(crate) fn to_url_parameter(self) -> String {
match self {
Self::Json => String::from("encoding=json"),
Self::ETF => String::from("encoding=etf")
}
}
}

View File

@ -13,7 +13,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::errors::ChorusResult; use crate::errors::ChorusResult;
use crate::gateway::{Gateway, GatewayHandle}; use crate::gateway::{Gateway, GatewayHandle, GatewayOptions};
use crate::ratelimiter::ChorusRequest; use crate::ratelimiter::ChorusRequest;
use crate::types::types::subconfigs::limits::rates::RateLimits; use crate::types::types::subconfigs::limits::rates::RateLimits;
use crate::types::{ use crate::types::{
@ -31,6 +31,8 @@ pub struct Instance {
pub limits_information: Option<LimitsInformation>, pub limits_information: Option<LimitsInformation>,
#[serde(skip)] #[serde(skip)]
pub client: Client, pub client: Client,
#[serde(skip)]
pub gateway_options: GatewayOptions,
} }
impl PartialEq for Instance { impl PartialEq for Instance {
@ -104,6 +106,7 @@ impl Instance {
instance_info: GeneralConfiguration::default(), instance_info: GeneralConfiguration::default(),
limits_information: limit_information, limits_information: limit_information,
client: Client::new(), client: Client::new(),
gateway_options: GatewayOptions::default(),
}; };
instance.instance_info = match instance.general_configuration_schema().await { instance.instance_info = match instance.general_configuration_schema().await {
Ok(schema) => schema, Ok(schema) => schema,
@ -139,6 +142,13 @@ impl Instance {
Err(_) => Ok(None), Err(_) => Ok(None),
} }
} }
/// Sets the [`GatewayOptions`] the instance will use when spawning new connections.
///
/// These options are used on the gateways created when logging in and registering.
pub fn set_gateway_options(&mut self, options: GatewayOptions) {
self.gateway_options = options;
}
} }
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -215,7 +225,9 @@ impl ChorusUser {
let object = Arc::new(RwLock::new(User::default())); let object = Arc::new(RwLock::new(User::default()));
let wss_url = instance.read().unwrap().urls.wss.clone(); let wss_url = instance.read().unwrap().urls.wss.clone();
// Dummy gateway object // Dummy gateway object
let gateway = Gateway::spawn(wss_url).await.unwrap(); let gateway = Gateway::spawn(wss_url, GatewayOptions::default())
.await
.unwrap();
ChorusUser { ChorusUser {
token, token,
belongs_to: instance.clone(), belongs_to: instance.clone(),

View File

@ -4,7 +4,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_aux::prelude::deserialize_string_from_number;
use serde_repr::{Deserialize_repr, Serialize_repr}; use serde_repr::{Deserialize_repr, Serialize_repr};
use std::fmt::Debug; use std::fmt::Debug;

View File

@ -4,11 +4,9 @@
use crate::types::utils::Snowflake; use crate::types::utils::Snowflake;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Serialize};
use serde_aux::prelude::deserialize_option_number_from_string; use serde_aux::prelude::deserialize_option_number_from_string;
use std::fmt::Debug; use std::fmt::Debug;
use std::num::ParseIntError;
use std::str::FromStr;
#[cfg(feature = "client")] #[cfg(feature = "client")]
use crate::gateway::Updateable; use crate::gateway::Updateable;

View File

@ -20,7 +20,9 @@ pub struct UpdatePresence {
#[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, WebSocketEvent)] #[derive(Debug, Deserialize, Serialize, Default, Clone, PartialEq, WebSocketEvent)]
/// Received to tell the client that a user updated their presence / status /// Received to tell the client that a user updated their presence / status
///
/// See <https://discord.com/developers/docs/topics/gateway-events#presence-update-presence-update-event-fields> /// See <https://discord.com/developers/docs/topics/gateway-events#presence-update-presence-update-event-fields>
/// (Same structure as <https://docs.discord.sex/resources/presence#presence-object>)
pub struct PresenceUpdate { pub struct PresenceUpdate {
pub user: PublicUser, pub user: PublicUser,
#[serde(default)] #[serde(default)]

View File

@ -6,13 +6,14 @@ use serde::{Deserialize, Serialize};
use crate::types::entities::{Guild, User}; use crate::types::entities::{Guild, User};
use crate::types::events::{Session, WebSocketEvent}; use crate::types::events::{Session, WebSocketEvent};
use crate::types::interfaces::ClientStatusObject; use crate::types::{Activity, Channel, ClientStatusObject, GuildMember, PresenceUpdate, Snowflake, VoiceState};
use crate::types::{Activity, GuildMember, PresenceUpdate, VoiceState};
#[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] #[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)]
/// 1/2 half documented; /// 1/2 officially documented;
/// Received after identifying, provides initial user info; /// Received after identifying, provides initial user info;
/// See <https://discord.com/developers/docs/topics/gateway-events#ready;> ///
/// See <https://docs.discord.sex/topics/gateway-events#ready> and <https://discord.com/developers/docs/topics/gateway-events#ready>
// TODO: There are a LOT of fields missing here
pub struct GatewayReady { pub struct GatewayReady {
pub analytics_token: Option<String>, pub analytics_token: Option<String>,
pub auth_session_id_hash: Option<String>, pub auth_session_id_hash: Option<String>,
@ -32,36 +33,47 @@ pub struct GatewayReady {
#[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)] #[derive(Debug, Deserialize, Serialize, Default, Clone, WebSocketEvent)]
/// Officially Undocumented; /// Officially Undocumented;
/// Sent after the READY event when a client is a user, seems to somehow add onto the ready event; /// Sent after the READY event when a client is a user,
/// seems to somehow add onto the ready event;
///
/// See <https://docs.discord.sex/topics/gateway-events#ready-supplemental>
pub struct GatewayReadySupplemental { pub struct GatewayReadySupplemental {
/// The presences of the user's relationships and guild presences sent at startup
pub merged_presences: MergedPresences, pub merged_presences: MergedPresences,
pub merged_members: Vec<Vec<GuildMember>>, pub merged_members: Vec<Vec<GuildMember>>,
// ? pub lazy_private_channels: Vec<Channel>,
pub lazy_private_channels: Vec<serde_json::Value>,
pub guilds: Vec<SupplementalGuild>, pub guilds: Vec<SupplementalGuild>,
// ? pomelo // "Upcoming changes that the client should disclose to the user" (discord.sex)
pub disclose: Vec<String>, pub disclose: Vec<String>,
} }
#[derive(Debug, Deserialize, Serialize, Default, Clone)] #[derive(Debug, Deserialize, Serialize, Default, Clone)]
/// See <https://docs.discord.sex/topics/gateway-events#merged-presences-structure>
pub struct MergedPresences { pub struct MergedPresences {
/// "Presences of the user's guilds in the same order as the guilds array in ready"
/// (discord.sex)
pub guilds: Vec<Vec<MergedPresenceGuild>>, pub guilds: Vec<Vec<MergedPresenceGuild>>,
/// "Presences of the user's friends and implicit relationships" (discord.sex)
pub friends: Vec<MergedPresenceFriend>, pub friends: Vec<MergedPresenceFriend>,
} }
#[derive(Debug, Deserialize, Serialize, Default, Clone)] #[derive(Debug, Deserialize, Serialize, Default, Clone)]
/// Not documented even unofficially
pub struct MergedPresenceFriend { pub struct MergedPresenceFriend {
pub user_id: String, pub user_id: Snowflake,
pub status: String, pub status: String,
/// Looks like ms?? // Looks like ms??
pub last_modified: u128, //
// Not always sent
pub last_modified: Option<u128>,
pub client_status: ClientStatusObject, pub client_status: ClientStatusObject,
pub activities: Vec<Activity>, pub activities: Vec<Activity>,
} }
#[derive(Debug, Deserialize, Serialize, Default, Clone)] #[derive(Debug, Deserialize, Serialize, Default, Clone)]
/// Not documented even unofficially
pub struct MergedPresenceGuild { pub struct MergedPresenceGuild {
pub user_id: String, pub user_id: Snowflake,
pub status: String, pub status: String,
// ? // ?
pub game: Option<serde_json::Value>, pub game: Option<serde_json::Value>,
@ -70,8 +82,10 @@ pub struct MergedPresenceGuild {
} }
#[derive(Debug, Deserialize, Serialize, Default, Clone)] #[derive(Debug, Deserialize, Serialize, Default, Clone)]
/// See <https://docs.discord.sex/topics/gateway-events#supplemental-guild-structure>
pub struct SupplementalGuild { pub struct SupplementalGuild {
pub id: Snowflake,
pub voice_states: Option<Vec<VoiceState>>, pub voice_states: Option<Vec<VoiceState>>,
pub id: String, /// Field not documented even unofficially
pub embedded_activities: Vec<serde_json::Value>, pub embedded_activities: Vec<serde_json::Value>,
} }

View File

@ -3,10 +3,9 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use bitflags::bitflags; use bitflags::bitflags;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Serialize};
use serde::de::Visitor;
use crate::types::{ChannelType, DefaultReaction, Error, entities::PermissionOverwrite, Snowflake}; use crate::types::{ChannelType, DefaultReaction, entities::PermissionOverwrite, Snowflake};
#[derive(Debug, Deserialize, Serialize, Default, PartialEq, PartialOrd)] #[derive(Debug, Deserialize, Serialize, Default, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]

View File

@ -4,24 +4,91 @@
use std::collections::HashMap; use std::collections::HashMap;
use chrono::NaiveDate;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::types::Snowflake; use crate::types::Snowflake;
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
/// A schema used to modify a user. /// A schema used to modify a user.
///
/// See <https://docs.discord.sex/resources/user#json-params>
pub struct UserModifySchema { pub struct UserModifySchema {
/// The user's new username (2-32 characters)
///
/// Requires that `current_password` is set.
pub username: Option<String>, pub username: Option<String>,
// TODO: Maybe add a special discriminator type?
/// Requires that `current_password` is set.
pub discriminator: Option<String>,
/// The user's display name (1-32 characters)
///
/// # Note
///
/// This is not yet implemented on Spacebar
pub global_name: Option<String>,
// TODO: Add a CDN data type
pub avatar: Option<String>, pub avatar: Option<String>,
pub bio: Option<String>, /// Note: This is not yet implemented on Spacebar
pub accent_color: Option<u64>, pub avatar_decoration_id: Option<Snowflake>,
pub banner: Option<String>, /// Note: This is not yet implemented on Spacebar
pub current_password: Option<String>, pub avatar_decoration_sku_id: Option<Snowflake>,
pub new_password: Option<String>, /// The user's email address; if changing from a verified email, email_token must be provided
pub code: Option<String>, ///
/// Requires that `current_password` is set.
// TODO: Is ^ up to date? One would think this may not be the case, since email_token exists
pub email: Option<String>, pub email: Option<String>,
pub discriminator: Option<i16>, /// The user's email token from their previous email, required if a new email is set.
///
/// See <https://docs.discord.sex/resources/user#modify-user-email> and <https://docs.discord.sex/resources/user#verify-user-email-change>
/// for changing the user's email.
///
/// # Note
///
/// This is not yet implemented on Spacebar
pub email_token: Option<String>,
/// The user's pronouns (max 40 characters)
///
/// # Note
///
/// This is not yet implemented on Spacebar
pub pronouns: Option<String>,
/// The user's banner.
///
/// Can only be changed for premium users
pub banner: Option<String>,
/// The user's bio (max 190 characters)
pub bio: Option<String>,
/// The user's accent color, as a hex integer
pub accent_color: Option<u64>,
/// The user's [UserFlags].
///
/// Only [UserFlags::PREMIUM_PROMO_DISMISSED], [UserFlags::HAS_UNREAD_URGENT_MESSAGES]
/// and DISABLE_PREMIUM can be set.
///
/// # Note
///
/// This is not yet implemented on Spacebar
pub flags: Option<u64>,
/// The user's date of birth, can only be set once
///
/// Requires that `current_password` is set.
pub date_of_birth: Option<NaiveDate>,
/// The user's current password (if the account does not have a password, this sets it)
///
/// Required for updating `username`, `discriminator`, `email`, `date_of_birth` and
/// `new_password`
#[serde(rename = "password")]
pub current_password: Option<String>,
/// The user's new password (8-72 characters)
///
/// Requires that `current_password` is set.
///
/// Regenerates the user's token
pub new_password: Option<String>,
/// Spacebar only field, potentially same as `email_token`
pub code: Option<String>,
} }
/// A schema used to create a private channel. /// A schema used to create a private channel.
@ -33,7 +100,7 @@ pub struct UserModifySchema {
/// ///
/// # Reference: /// # Reference:
/// Read: <https://discord-userdoccers.vercel.app/resources/channel#json-params> /// Read: <https://discord-userdoccers.vercel.app/resources/channel#json-params>
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct PrivateChannelCreateSchema { pub struct PrivateChannelCreateSchema {
pub recipients: Option<Vec<Snowflake>>, pub recipients: Option<Vec<Snowflake>>,
pub access_tokens: Option<Vec<String>>, pub access_tokens: Option<Vec<String>>,

View File

@ -23,3 +23,4 @@ impl From<WsMessage> for VoiceGatewayMessage {
} }
} }
} }

View File

@ -44,7 +44,7 @@ impl VoiceGateway {
pub async fn spawn(websocket_url: String) -> Result<VoiceGatewayHandle, VoiceGatewayError> { pub async fn spawn(websocket_url: String) -> Result<VoiceGatewayHandle, VoiceGatewayError> {
// Append the needed things to the websocket url // Append the needed things to the websocket url
let processed_url = format!("wss://{}/?v=7", websocket_url); let processed_url = format!("wss://{}/?v=7", websocket_url);
trace!("Created voice socket url: {}", processed_url.clone()); trace!("VGW: Connecting to {}", processed_url.clone());
let (websocket_send, mut websocket_receive) = let (websocket_send, mut websocket_receive) =
match WebSocketBackend::connect(&processed_url).await { match WebSocketBackend::connect(&processed_url).await {

View File

@ -4,7 +4,7 @@
use std::str::FromStr; use std::str::FromStr;
use chorus::gateway::Gateway; use chorus::gateway::{Gateway, GatewayOptions};
use chorus::types::IntoShared; use chorus::types::IntoShared;
use chorus::{ use chorus::{
instance::{ChorusUser, Instance}, instance::{ChorusUser, Instance},
@ -50,7 +50,7 @@ impl TestBundle {
limits: self.user.limits.clone(), limits: self.user.limits.clone(),
settings: self.user.settings.clone(), settings: self.user.settings.clone(),
object: self.user.object.clone(), object: self.user.object.clone(),
gateway: Gateway::spawn(self.instance.urls.wss.clone()) gateway: Gateway::spawn(self.instance.urls.wss.clone(), GatewayOptions::default())
.await .await
.unwrap(), .unwrap(),
} }
@ -59,6 +59,10 @@ impl TestBundle {
// Set up a test by creating an Instance and a User. Reduces Test boilerplate. // Set up a test by creating an Instance and a User. Reduces Test boilerplate.
pub(crate) async fn setup() -> TestBundle { pub(crate) async fn setup() -> TestBundle {
// So we can get logs when tests fail
let _ = simple_logger::SimpleLogger::with_level(simple_logger::SimpleLogger::new(), log::LevelFilter::Debug).init();
let instance = Instance::new("http://localhost:3001/api").await.unwrap(); let instance = Instance::new("http://localhost:3001/api").await.unwrap();
// Requires the existence of the below user. // Requires the existence of the below user.
let reg = RegisterSchema { let reg = RegisterSchema {
@ -119,7 +123,7 @@ pub(crate) async fn setup() -> TestBundle {
let urls = UrlBundle::new( let urls = UrlBundle::new(
"http://localhost:3001/api".to_string(), "http://localhost:3001/api".to_string(),
"http://localhost:3001/api".to_string(), "http://localhost:3001/api".to_string(),
"ws://localhost:3001".to_string(), "ws://localhost:3001/".to_string(),
"http://localhost:3001".to_string(), "http://localhost:3001".to_string(),
); );
TestBundle { TestBundle {

View File

@ -30,7 +30,7 @@ use wasmtimer::tokio::sleep;
async fn test_gateway_establish() { async fn test_gateway_establish() {
let bundle = common::setup().await; let bundle = common::setup().await;
let _: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone()).await.unwrap(); let _: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default()).await.unwrap();
common::teardown(bundle).await common::teardown(bundle).await
} }
@ -52,7 +52,7 @@ impl Observer<GatewayReady> for GatewayReadyObserver {
async fn test_gateway_authenticate() { async fn test_gateway_authenticate() {
let bundle = common::setup().await; let bundle = common::setup().await;
let gateway: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone()).await.unwrap(); let gateway: GatewayHandle = Gateway::spawn(bundle.urls.wss.clone(), GatewayOptions::default()).await.unwrap();
let (ready_send, mut ready_receive) = tokio::sync::mpsc::channel(1); let (ready_send, mut ready_receive) = tokio::sync::mpsc::channel(1);
@ -79,7 +79,7 @@ async fn test_gateway_authenticate() {
println!("Timed out waiting for event, failing.."); println!("Timed out waiting for event, failing..");
assert!(false); assert!(false);
} }
// Sucess, we have received it // Success, we have received it
Some(_) = ready_receive.recv() => {} Some(_) = ready_receive.recv() => {}
}; };