diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index b192a3c..31962c2 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -4,7 +4,7 @@ on: push: branches: [ "main", "dev" ] pull_request: - branches: [ "main" ] + branches: [ "main", "dev" ] env: CARGO_TERM_COLOR: always @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Clone spacebar server run: | git clone https://github.com/bitfl0wer/server.git diff --git a/.github/workflows/rust-clippy.yml b/.github/workflows/rust-clippy.yml index 0c3840f..15dedd4 100644 --- a/.github/workflows/rust-clippy.yml +++ b/.github/workflows/rust-clippy.yml @@ -14,7 +14,7 @@ on: branches: [ "main", "preserve/*", "dev" ] pull_request: # The branches below must be a subset of the branches above - branches: [ "main" ] + branches: [ "main", "dev" ] jobs: rust-clippy-analyze: @@ -26,10 +26,10 @@ jobs: actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af #@v1 + uses: actions-rs/toolchain@v1 with: profile: minimal toolchain: stable diff --git a/Cargo.lock b/Cargo.lock index bac7fc2..2834823 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,7 +178,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chorus" -version = "0.9.0" +version = "0.11.0" dependencies = [ "async-trait", "base64 0.21.5", @@ -193,11 +193,12 @@ dependencies = [ "jsonwebtoken", "lazy_static", "log", - "native-tls", - "openssl", "poem", + "rand", "regex", "reqwest", + "rustls", + "rustls-native-certs", "rusty-hook", "serde", "serde-aux", @@ -963,7 +964,7 @@ checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ "base64 0.21.5", "pem", - "ring", + "ring 0.16.20", "serde", "serde_json", "simple_asn1", @@ -1584,11 +1585,25 @@ dependencies = [ "libc", "once_cell", "spin 0.5.2", - "untrusted", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys", +] + [[package]] name = "rsa" version = "0.9.3" @@ -1617,9 +1632,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "80109a168d9bc0c7f483083244543a6eb0dba02295d33ca268145e6190d6df0c" dependencies = [ "bitflags 2.4.1", "errno", @@ -1628,6 +1643,49 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "rustls" +version = "0.21.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" +dependencies = [ + "log", + "ring 0.17.5", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +dependencies = [ + "base64 0.21.5", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.5", + "untrusted 0.9.0", +] + [[package]] name = "rusty-hook" version = "0.11.2" @@ -1661,6 +1719,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring 0.16.20", + "untrusted 0.7.1", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -2287,6 +2355,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -2306,9 +2384,10 @@ checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" dependencies = [ "futures-util", "log", - "native-tls", + "rustls", + "rustls-native-certs", "tokio", - "tokio-native-tls", + "tokio-rustls", "tungstenite", ] @@ -2408,8 +2487,8 @@ dependencies = [ "http", "httparse", "log", - "native-tls", "rand", + "rustls", "sha1", "thiserror", "url", @@ -2485,6 +2564,12 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.4.1" diff --git a/Cargo.toml b/Cargo.toml index d824218..ecdb3ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "chorus" description = "A library for interacting with multiple Spacebar-compatible Instances at once." -version = "0.9.0" +version = "0.11.0" license = "AGPL-3.0" edition = "2021" repository = "https://github.com/polyphony-chat/chorus" @@ -27,11 +27,12 @@ url = "2.4.0" chrono = { version = "0.4.26", features = ["serde"] } regex = "1.9.4" custom_error = "1.9.2" -native-tls = "0.2.11" -tokio-tungstenite = { version = "0.20.0", features = ["native-tls"] } +tokio-tungstenite = { version = "0.20.1", features = [ + "rustls-tls-native-roots", + "rustls-native-certs", +] } futures-util = "0.3.28" http = "0.2.9" -openssl = "0.10.56" base64 = "0.21.3" hostname = "0.3.1" bitflags = { version = "2.4.0", features = ["serde"] } @@ -52,6 +53,9 @@ log = "0.4.20" async-trait = "0.1.73" chorus-macros = "0.2.0" discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] } +rustls = "0.21.8" +rustls-native-certs = "0.6.3" +rand = "0.8.5" [dev-dependencies] tokio = { version = "1.32.0", features = ["full"] } diff --git a/README.md b/README.md index 74dff4c..e654610 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ [![Build][build-shield]][build-url] [![Coverage][coverage-shield]][coverage-url] [![Contributors][contributors-shield]][contributors-url] -[![Forks][forks-shield]][forks-url] -[![Issues][issues-shield]][issues-url]
@@ -17,7 +15,6 @@

Chorus

- A rust library for interacting with (multiple) Spacebar-compatible APIs and Gateways (at the same time).
Explore the docs ยป
@@ -32,16 +29,95 @@ -## About +Chorus is a Rust library that allows developers to interact with multiple Spacebar-compatible APIs and Gateways (Including +Discord.com) simultaneously. The library provides a simple and efficient way to communicate with these services, making it easier for developers to build applications that rely on them. Chorus is open-source and welcomes contributions from the community. -Chorus is a Rust library that allows developers to interact with multiple Spacebar-compatible APIs and Gateways simultaneously. The library provides a simple and efficient way to communicate with these services, making it easier for developers to build applications that rely on them. Chorus is open-source and welcomes contributions from the community. +## A Tour of Chorus + +Chorus combines all the required functionalities of a user-centric Spacebar library into one package. The library +handles a lot of things for you, such as rate limiting, authentication, and more. This means that you can focus on +building your application, instead of worrying about the underlying implementation details. + +To get started with Chorus, import it into your project by adding the following to your `Cargo.toml` file: + +```toml +[dependencies] +chorus = "0" +``` + +### Establishing a Connection + +To connect to a Spacebar compatible server, you need to create an [`Instance`](https://docs.rs/chorus/latest/chorus/instance/struct.Instance.html) like this: + +```rs +use chorus::instance::Instance; +use chorus::UrlBundle; + +#[tokio::main] +async fn main() { + let bundle = UrlBundle::new( + "https://example.com/api".to_string(), + "wss://example.com/".to_string(), + "https://example.com/cdn".to_string(), + ); + let instance = Instance::new(bundle, true) + .await + .expect("Failed to connect to the Spacebar server"); + // You can create as many instances of `Instance` as you want, but each `Instance` should likely be unique. + dbg!(instance.instance_info); + dbg!(instance.limits_information); +} +``` + +This Instance can now be used to log in, register and from there on, interact with the server in all sorts of ways. + +### Logging In + +Logging in correctly provides you with an instance of `ChorusUser`, with which you can interact with the server and +manipulate the account. Assuming you already have an account on the server, you can log in like this: + +```rs +use chorus::types::LoginSchema; +// Assume, you already have an account created on this instance. Registering an account works +// the same way, but you'd use the Register-specific Structs and methods instead. +let login_schema = LoginSchema { + login: "user@example.com".to_string(), + password: "Correct-Horse-Battery-Staple".to_string(), + ..Default::default() +}; +// Each user connects to the Gateway. The Gateway connection lives on a seperate thread. Depending on +// the runtime feature you choose, this can potentially take advantage of all of your computers' threads. +let user = instance + .login_account(login_schema) + .await + .expect("An error occurred during the login process"); +dbg!(user.belongs_to); +dbg!(&user.object.read().unwrap().username); +``` + +## Supported Platforms + +All major desktop operating systems (Windows, macOS (aarch64/x86_64), Linux (aarch64/x86_64)) are supported. +We are currently working on adding full support for `wasm32-unknown-unknown`. This will allow you to use Chorus in +your browser, or in any other environment that supports WebAssembly. + +We recommend checking out the examples directory, as well as the documentation for more information. + +## MSRV (Minimum Supported Rust Version) + +Rust **1.67.1**. This number might change at any point while Chorus is not yet at version 1.0.0. + +## Versioning + +This crate uses Semantic Versioning 2.0.0 as its versioning scheme. You can read the specification [here](https://semver.org/spec/v2.0.0.html). ## Contributing +Chorus is currently missing voice support and a lot of API endpoints, many of which should be trivial to implement, +ever since [we streamlined the process of doing so](https://github.com/polyphony-chat/chorus/discussions/401). + If you'd like to contribute new functionality, check out [The 'Meta'-issues.](https://github.com/polyphony-chat/chorus/issues?q=is%3Aissue+label%3A%22Type%3A+Meta%22+) They contain a comprehensive list of all features which are yet missing for full Discord.com compatibility. -If you would like to contribute, please feel free to open an Issue with the idea you have, or a -Pull Request. Please keep our [contribution guidelines](https://github.com/polyphony-chat/.github/blob/main/CONTRIBUTION_GUIDELINES.md) in mind. Your contribution might not be -accepted, if it violates these guidelines or [our Code of Conduct](https://github.com/polyphony-chat/.github/blob/main/CODE_OF_CONDUCT.md). +Please feel free to open an Issue with the idea you have, or a Pull Request. Please keep our [contribution guidelines](https://github.com/polyphony-chat/.github/blob/main/CONTRIBUTION_GUIDELINES.md) in mind. Your contribution might not be accepted if it violates these guidelines or [our Code of Conduct](https://github.com/polyphony-chat/.github/blob/main/CODE_OF_CONDUCT.md).

Progress Tracker/Roadmap diff --git a/examples/instance.rs b/examples/instance.rs new file mode 100644 index 0000000..337482b --- /dev/null +++ b/examples/instance.rs @@ -0,0 +1,16 @@ +use chorus::instance::Instance; +use chorus::UrlBundle; + +#[tokio::main] +async fn main() { + let bundle = UrlBundle::new( + "https://example.com/api".to_string(), + "wss://example.com/".to_string(), + "https://example.com/cdn".to_string(), + ); + let instance = Instance::new(bundle, true) + .await + .expect("Failed to connect to the Spacebar server"); + dbg!(instance.instance_info); + dbg!(instance.limits_information); +} diff --git a/examples/login.rs b/examples/login.rs new file mode 100644 index 0000000..4595a06 --- /dev/null +++ b/examples/login.rs @@ -0,0 +1,30 @@ +use chorus::instance::Instance; +use chorus::types::LoginSchema; +use chorus::UrlBundle; + +#[tokio::main] +async fn main() { + let bundle = UrlBundle::new( + "https://example.com/api".to_string(), + "wss://example.com/".to_string(), + "https://example.com/cdn".to_string(), + ); + let instance = Instance::new(bundle, true) + .await + .expect("Failed to connect to the Spacebar server"); + // Assume, you already have an account created on this instance. Registering an account works + // the same way, but you'd use the Register-specific Structs and methods instead. + let login_schema = LoginSchema { + login: "user@example.com".to_string(), + password: "Correct-Horse-Battery-Staple".to_string(), + ..Default::default() + }; + // Each user connects to the Gateway. The Gateway connection lives on a seperate thread. Depending on + // the runtime feature you choose, this can potentially take advantage of all of your computers' threads. + let user = instance + .login_account(login_schema) + .await + .expect("An error occurred during the login process"); + dbg!(user.belongs_to); + dbg!(&user.object.read().unwrap().username); +} diff --git a/src/api/invites/mod.rs b/src/api/invites/mod.rs index 332570b..80b47d2 100644 --- a/src/api/invites/mod.rs +++ b/src/api/invites/mod.rs @@ -28,11 +28,11 @@ impl ChorusUser { .header("Authorization", self.token()), limit_type: LimitType::Global, }; - if session_id.is_some() { + if let Some(session_id) = session_id { request.request = request .request .header("Content-Type", "application/json") - .body(to_string(session_id.unwrap()).unwrap()); + .body(to_string(session_id).unwrap()); } request.deserialize_response::(self).await } diff --git a/src/gateway.rs b/src/gateway/gateway.rs similarity index 54% rename from src/gateway.rs rename to src/gateway/gateway.rs index def0df1..30d0610 100644 --- a/src/gateway.rs +++ b/src/gateway/gateway.rs @@ -1,332 +1,10 @@ -//! Gateway connection, communication and handling, as well as object caching and updating. - -use crate::errors::GatewayError; -use crate::gateway::events::Events; +use self::event::Events; +use super::*; use crate::types::{ self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, - ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, - Snowflake, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent, + ChannelUpdate, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, SourceUrlField, + ThreadUpdate, UpdateMessage, WebSocketEvent, }; -use async_trait::async_trait; -use std::any::Any; -use std::collections::HashMap; -use std::fmt::Debug; -use std::sync::{Arc, RwLock}; -use std::time::Duration; -use tokio::time::sleep_until; - -use futures_util::stream::SplitSink; -use futures_util::stream::SplitStream; -use futures_util::SinkExt; -use futures_util::StreamExt; -use log::{info, trace, warn}; -use native_tls::TlsConnector; -use tokio::net::TcpStream; -use tokio::sync::mpsc::Sender; -use tokio::sync::Mutex; -use tokio::task; -use tokio::task::JoinHandle; -use tokio::time; -use tokio::time::Instant; -use tokio_tungstenite::MaybeTlsStream; -use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; - -// Gateway opcodes -/// Opcode received when the server dispatches a [crate::types::WebSocketEvent] -const GATEWAY_DISPATCH: u8 = 0; -/// Opcode sent when sending a heartbeat -const GATEWAY_HEARTBEAT: u8 = 1; -/// Opcode sent to initiate a session -/// -/// See [types::GatewayIdentifyPayload] -const GATEWAY_IDENTIFY: u8 = 2; -/// Opcode sent to update our presence -/// -/// See [types::GatewayUpdatePresence] -const GATEWAY_UPDATE_PRESENCE: u8 = 3; -/// Opcode sent to update our state in vc -/// -/// Like muting, deafening, leaving, joining.. -/// -/// See [types::UpdateVoiceState] -const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; -/// Opcode sent to resume a session -/// -/// See [types::GatewayResume] -const GATEWAY_RESUME: u8 = 6; -/// Opcode received to tell the client to reconnect -const GATEWAY_RECONNECT: u8 = 7; -/// Opcode sent to request guild member data -/// -/// See [types::GatewayRequestGuildMembers] -const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; -/// Opcode received to tell the client their token / session is invalid -const GATEWAY_INVALID_SESSION: u8 = 9; -/// Opcode received when initially connecting to the gateway, starts our heartbeat -/// -/// See [types::HelloData] -const GATEWAY_HELLO: u8 = 10; -/// Opcode received to acknowledge a heartbeat -const GATEWAY_HEARTBEAT_ACK: u8 = 11; -/// Opcode sent to get the voice state of users in a given DM/group channel -/// -/// See [types::CallSync] -const GATEWAY_CALL_SYNC: u8 = 13; -/// Opcode sent to get data for a server (Lazy Loading request) -/// -/// Sent by the official client when switching to a server -/// -/// See [types::LazyRequest] -const GATEWAY_LAZY_REQUEST: u8 = 14; - -/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms -pub(crate) const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; - -/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. -/// This struct is used internally when handling messages. -#[derive(Clone, Debug)] -pub struct GatewayMessage { - /// The message we received from the server - message: tokio_tungstenite::tungstenite::Message, -} - -impl GatewayMessage { - /// Creates self from a tungstenite message - pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { - Self { message } - } - - /// Parses the message as an error; - /// Returns the error if succesfully parsed, None if the message isn't an error - pub fn error(&self) -> Option { - let content = self.message.to_string(); - - // Some error strings have dots on the end, which we don't care about - let processed_content = content.to_lowercase().replace('.', ""); - - match processed_content.as_str() { - "unknown error" | "4000" => Some(GatewayError::Unknown), - "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), - "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), - "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), - "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), - "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), - "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), - "rate limited" | "4008" => Some(GatewayError::RateLimited), - "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), - "invalid shard" | "4010" => Some(GatewayError::InvalidShard), - "sharding required" | "4011" => Some(GatewayError::ShardingRequired), - "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), - "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), - "disallowed intent(s)" | "disallowed intents" | "4014" => { - Some(GatewayError::DisallowedIntents) - } - _ => None, - } - } - - /// Returns whether or not the message is an error - pub fn is_error(&self) -> bool { - self.error().is_some() - } - - /// Parses the message as a payload; - /// Returns a result of deserializing - pub fn payload(&self) -> Result { - return serde_json::from_str(self.message.to_text().unwrap()); - } - - /// Returns whether or not the message is a payload - pub fn is_payload(&self) -> bool { - // close messages are never payloads, payloads are only text messages - if self.message.is_close() | !self.message.is_text() { - return false; - } - - return self.payload().is_ok(); - } - - /// Returns whether or not the message is empty - pub fn is_empty(&self) -> bool { - self.message.is_empty() - } -} - -pub type ObservableObject = dyn Send + Sync + Any; - -/// Represents a handle to a Gateway connection. A Gateway connection will create observable -/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently -/// implemented types with the trait [`WebSocketEvent`] -/// Using this handle you can also send Gateway Events directly. -#[derive(Debug, Clone)] -pub struct GatewayHandle { - pub url: String, - pub events: Arc>, - pub websocket_send: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - /// Tells gateway tasks to close - kill_send: tokio::sync::broadcast::Sender<()>, - pub(crate) store: Arc>>>>, -} - -/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. -pub trait Updateable: 'static + Send + Sync { - fn id(&self) -> Snowflake; -} - -impl GatewayHandle { - /// Sends json to the gateway with an opcode - async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { - let gateway_payload = types::GatewaySendPayload { - op_code, - event_data: Some(to_send), - sequence_number: None, - }; - - let payload_json = serde_json::to_string(&gateway_payload).unwrap(); - - let message = tokio_tungstenite::tungstenite::Message::text(payload_json); - - self.websocket_send - .lock() - .await - .send(message) - .await - .unwrap(); - } - - pub async fn observe>( - &self, - object: Arc>, - ) -> Arc> { - let mut store = self.store.lock().await; - let id = object.read().unwrap().id(); - if let Some(channel) = store.get(&id) { - let object = channel.clone(); - drop(store); - object - .read() - .unwrap() - .downcast_ref::() - .unwrap_or_else(|| { - panic!( - "Snowflake {} already exists in the store, but it is not of type T.", - id - ) - }); - let ptr = Arc::into_raw(object.clone()); - // SAFETY: - // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. - // - This operation doesn't read or write any shared data, and thus cannot cause a data race - // - The reference count is not being modified - let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; - let object = downcasted.read().unwrap().clone(); - - let watched_object = object.watch_whole(self).await; - *downcasted.write().unwrap() = watched_object; - downcasted - } else { - let id = object.read().unwrap().id(); - let object = object.read().unwrap().clone(); - let object = object.clone().watch_whole(self).await; - let wrapped = Arc::new(RwLock::new(object)); - store.insert(id, wrapped.clone()); - wrapped - } - } - - /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` - /// with all of its observable fields being observed. - pub async fn observe_and_into_inner>( - &self, - object: Arc>, - ) -> T { - let channel = self.observe(object.clone()).await; - let object = channel.read().unwrap().clone(); - object - } - - /// Sends an identify event to the gateway - pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Identify.."); - - self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; - } - - /// Sends a resume event to the gateway - pub async fn send_resume(&self, to_send: types::GatewayResume) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Resume.."); - - self.send_json_event(GATEWAY_RESUME, to_send_value).await; - } - - /// Sends an update presence event to the gateway - pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Update Presence.."); - - self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) - .await; - } - - /// Sends a request guild members to the server - pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Request Guild Members.."); - - self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) - .await; - } - - /// Sends an update voice state to the server - pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { - let to_send_value = serde_json::to_value(to_send).unwrap(); - - trace!("GW: Sending Update Voice State.."); - - self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) - .await; - } - - /// Sends a call sync to the server - pub async fn send_call_sync(&self, to_send: types::CallSync) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Call Sync.."); - - self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; - } - - /// Sends a Lazy Request - pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { - let to_send_value = serde_json::to_value(&to_send).unwrap(); - - trace!("GW: Sending Lazy Request.."); - - self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) - .await; - } - - /// Closes the websocket connection and stops all gateway tasks; - /// - /// Esentially pulls the plug on the gateway, leaving it possible to resume; - pub async fn close(&self) { - self.kill_send.send(()).unwrap(); - self.websocket_send.lock().await.close().await.unwrap(); - } -} #[derive(Debug)] pub struct Gateway { @@ -349,12 +27,21 @@ pub struct Gateway { impl Gateway { #[allow(clippy::new_ret_no_self)] pub async fn new(websocket_url: String) -> Result { + let mut roots = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + roots.add(&rustls::Certificate(cert.0)).unwrap(); + } let (websocket_stream, _) = match connect_async_tls_with_config( &websocket_url, None, false, - Some(Connector::NativeTls( - TlsConnector::builder().build().unwrap(), + Some(Connector::Rustls( + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth() + .into(), )), ) .await @@ -701,10 +388,10 @@ impl Gateway { | GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_CALL_SYNC | GATEWAY_LAZY_REQUEST => { - let error = GatewayError::UnexpectedOpcodeReceived { - opcode: gateway_payload.op_code, - }; - Err::<(), GatewayError>(error).unwrap(); + info!( + "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.", + gateway_payload.op_code + ); } _ => { warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); @@ -728,196 +415,7 @@ impl Gateway { } } -/// Handles sending heartbeats to the gateway in another thread -#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used -#[derive(Debug)] -struct HeartbeatHandler { - /// How ofter heartbeats need to be sent at a minimum - pub heartbeat_interval: Duration, - /// The send channel for the heartbeat thread - pub send: Sender, - /// The handle of the thread - handle: JoinHandle<()>, -} - -impl HeartbeatHandler { - pub fn new( - heartbeat_interval: Duration, - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - kill_rc: tokio::sync::broadcast::Receiver<()>, - ) -> HeartbeatHandler { - let (send, receive) = tokio::sync::mpsc::channel(32); - let kill_receive = kill_rc.resubscribe(); - - let handle: JoinHandle<()> = task::spawn(async move { - HeartbeatHandler::heartbeat_task( - websocket_tx, - heartbeat_interval, - receive, - kill_receive, - ) - .await; - }); - - Self { - heartbeat_interval, - send, - handle, - } - } - - /// The main heartbeat task; - /// - /// Can be killed by the kill broadcast; - /// If the websocket is closed, will die out next time it tries to send a heartbeat; - pub async fn heartbeat_task( - websocket_tx: Arc< - Mutex< - SplitSink< - WebSocketStream>, - tokio_tungstenite::tungstenite::Message, - >, - >, - >, - heartbeat_interval: Duration, - mut receive: tokio::sync::mpsc::Receiver, - mut kill_receive: tokio::sync::broadcast::Receiver<()>, - ) { - let mut last_heartbeat_timestamp: Instant = time::Instant::now(); - let mut last_heartbeat_acknowledged = true; - let mut last_seq_number: Option = None; - - loop { - if kill_receive.try_recv().is_ok() { - trace!("GW: Closing heartbeat task"); - break; - } - - let timeout = if last_heartbeat_acknowledged { - heartbeat_interval - } else { - // If the server hasn't acknowledged our heartbeat we should resend it - Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) - }; - - let mut should_send = false; - - tokio::select! { - () = sleep_until(last_heartbeat_timestamp + timeout) => { - should_send = true; - } - Some(communication) = receive.recv() => { - // If we received a seq number update, use that as the last seq number - if communication.sequence_number.is_some() { - last_seq_number = communication.sequence_number; - } - - if let Some(op_code) = communication.op_code { - match op_code { - GATEWAY_HEARTBEAT => { - // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately - should_send = true; - } - GATEWAY_HEARTBEAT_ACK => { - // The server received our heartbeat - last_heartbeat_acknowledged = true; - } - _ => {} - } - } - } - } - - if should_send { - trace!("GW: Sending Heartbeat.."); - - let heartbeat = types::GatewayHeartbeat { - op: GATEWAY_HEARTBEAT, - d: last_seq_number, - }; - - let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); - - let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); - - let send_result = websocket_tx.lock().await.send(msg).await; - if send_result.is_err() { - // We couldn't send, the websocket is broken - warn!("GW: Couldnt send heartbeat, websocket seems broken"); - break; - } - - last_heartbeat_timestamp = time::Instant::now(); - last_heartbeat_acknowledged = false; - } - } - } -} - -/// Used for communications between the heartbeat and gateway thread. -/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server -#[derive(Clone, Copy, Debug)] -struct HeartbeatThreadCommunication { - /// The opcode for the communication we received, if relevant - op_code: Option, - /// The sequence number we got from discord, if any - sequence_number: Option, -} - -/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to -/// an Observable. The Observer is notified when the Observable's data changes. -/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. -/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing. -#[async_trait] -pub trait Observer: Sync + Send + std::fmt::Debug { - async fn update(&self, data: &T); -} - -/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a -/// change in the WebSocketEvent. GatewayEvents are observable. -#[derive(Default, Debug)] -pub struct GatewayEvent { - observers: Vec>>, -} - -impl GatewayEvent { - /// Returns true if the GatewayEvent is observed by at least one Observer. - pub fn is_observed(&self) -> bool { - !self.observers.is_empty() - } - - /// Subscribes an Observer to the GatewayEvent. - pub fn subscribe(&mut self, observable: Arc>) { - self.observers.push(observable); - } - - /// Unsubscribes an Observer from the GatewayEvent. - pub fn unsubscribe(&mut self, observable: &dyn Observer) { - // .retain()'s closure retains only those elements of the vector, which have a different - // pointer value than observable. - // The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same - // anddd there is no way to do that without using format - let to_remove = format!("{:?}", observable); - self.observers - .retain(|obs| format!("{:?}", obs) != to_remove); - } - - /// Notifies the observers of the GatewayEvent. - pub async fn notify(&self, new_event_data: T) { - for observer in &self.observers { - observer.update(&new_event_data).await; - } - } -} - -pub mod events { +pub mod event { use super::*; #[derive(Default, Debug)] @@ -1078,52 +576,3 @@ pub mod events { pub update: GatewayEvent, } } - -#[cfg(test)] -mod example { - use super::*; - use std::sync::atomic::{AtomicI32, Ordering::Relaxed}; - - #[derive(Debug)] - struct Consumer { - _name: String, - events_received: AtomicI32, - } - - #[async_trait] - impl Observer for Consumer { - async fn update(&self, _data: &types::GatewayResume) { - self.events_received.fetch_add(1, Relaxed); - } - } - - #[tokio::test] - async fn test_observer_behavior() { - let mut event = GatewayEvent::default(); - - let new_data = types::GatewayResume { - token: "token_3276ha37am3".to_string(), - session_id: "89346671230".to_string(), - seq: "3".to_string(), - }; - - let consumer = Arc::new(Consumer { - _name: "first".into(), - events_received: 0.into(), - }); - event.subscribe(consumer.clone()); - - let second_consumer = Arc::new(Consumer { - _name: "second".into(), - events_received: 0.into(), - }); - event.subscribe(second_consumer.clone()); - - event.notify(new_data.clone()).await; - event.unsubscribe(&*consumer); - event.notify(new_data).await; - - assert_eq!(consumer.events_received.load(Relaxed), 1); - assert_eq!(second_consumer.events_received.load(Relaxed), 2); - } -} diff --git a/src/gateway/handle.rs b/src/gateway/handle.rs new file mode 100644 index 0000000..1200b30 --- /dev/null +++ b/src/gateway/handle.rs @@ -0,0 +1,171 @@ +use super::{event::Events, *}; +use crate::types::{self, Composite}; + +/// Represents a handle to a Gateway connection. A Gateway connection will create observable +/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently +/// implemented types with the trait [`WebSocketEvent`] +/// Using this handle you can also send Gateway Events directly. +#[derive(Debug, Clone)] +pub struct GatewayHandle { + pub url: String, + pub events: Arc>, + pub websocket_send: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + /// Tells gateway tasks to close + pub(super) kill_send: tokio::sync::broadcast::Sender<()>, + pub(crate) store: Arc>>>>, +} + +impl GatewayHandle { + /// Sends json to the gateway with an opcode + async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) { + let gateway_payload = types::GatewaySendPayload { + op_code, + event_data: Some(to_send), + sequence_number: None, + }; + + let payload_json = serde_json::to_string(&gateway_payload).unwrap(); + + let message = tokio_tungstenite::tungstenite::Message::text(payload_json); + + self.websocket_send + .lock() + .await + .send(message) + .await + .unwrap(); + } + + pub async fn observe>( + &self, + object: Arc>, + ) -> Arc> { + let mut store = self.store.lock().await; + let id = object.read().unwrap().id(); + if let Some(channel) = store.get(&id) { + let object = channel.clone(); + drop(store); + object + .read() + .unwrap() + .downcast_ref::() + .unwrap_or_else(|| { + panic!( + "Snowflake {} already exists in the store, but it is not of type T.", + id + ) + }); + let ptr = Arc::into_raw(object.clone()); + // SAFETY: + // - We have just checked that the typeid of the `dyn Any ...` matches that of `T`. + // - This operation doesn't read or write any shared data, and thus cannot cause a data race + // - The reference count is not being modified + let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock).clone() }; + let object = downcasted.read().unwrap().clone(); + + let watched_object = object.watch_whole(self).await; + *downcasted.write().unwrap() = watched_object; + downcasted + } else { + let id = object.read().unwrap().id(); + let object = object.read().unwrap().clone(); + let object = object.clone().watch_whole(self).await; + let wrapped = Arc::new(RwLock::new(object)); + store.insert(id, wrapped.clone()); + wrapped + } + } + + /// Recursively observes and updates all updateable fields on the struct T. Returns an object `T` + /// with all of its observable fields being observed. + pub async fn observe_and_into_inner>( + &self, + object: Arc>, + ) -> T { + let channel = self.observe(object.clone()).await; + let object = channel.read().unwrap().clone(); + object + } + + /// Sends an identify event to the gateway + pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Identify.."); + + self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await; + } + + /// Sends a resume event to the gateway + pub async fn send_resume(&self, to_send: types::GatewayResume) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Resume.."); + + self.send_json_event(GATEWAY_RESUME, to_send_value).await; + } + + /// Sends an update presence event to the gateway + pub async fn send_update_presence(&self, to_send: types::UpdatePresence) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Update Presence.."); + + self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value) + .await; + } + + /// Sends a request guild members to the server + pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Request Guild Members.."); + + self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value) + .await; + } + + /// Sends an update voice state to the server + pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) { + let to_send_value = serde_json::to_value(to_send).unwrap(); + + trace!("GW: Sending Update Voice State.."); + + self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value) + .await; + } + + /// Sends a call sync to the server + pub async fn send_call_sync(&self, to_send: types::CallSync) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Call Sync.."); + + self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await; + } + + /// Sends a Lazy Request + pub async fn send_lazy_request(&self, to_send: types::LazyRequest) { + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + trace!("GW: Sending Lazy Request.."); + + self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value) + .await; + } + + /// Closes the websocket connection and stops all gateway tasks; + /// + /// Esentially pulls the plug on the gateway, leaving it possible to resume; + pub async fn close(&self) { + self.kill_send.send(()).unwrap(); + self.websocket_send.lock().await.close().await.unwrap(); + } +} diff --git a/src/gateway/heartbeat.rs b/src/gateway/heartbeat.rs new file mode 100644 index 0000000..dd162b7 --- /dev/null +++ b/src/gateway/heartbeat.rs @@ -0,0 +1,149 @@ +use crate::types; + +use super::*; + +/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms +const HEARTBEAT_ACK_TIMEOUT: u64 = 2000; + +/// Handles sending heartbeats to the gateway in another thread +#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used +#[derive(Debug)] +pub(super) struct HeartbeatHandler { + /// How ofter heartbeats need to be sent at a minimum + pub heartbeat_interval: Duration, + /// The send channel for the heartbeat thread + pub send: Sender, + /// The handle of the thread + handle: JoinHandle<()>, +} + +impl HeartbeatHandler { + pub fn new( + heartbeat_interval: Duration, + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + kill_rc: tokio::sync::broadcast::Receiver<()>, + ) -> HeartbeatHandler { + let (send, receive) = tokio::sync::mpsc::channel(32); + let kill_receive = kill_rc.resubscribe(); + + let handle: JoinHandle<()> = task::spawn(async move { + HeartbeatHandler::heartbeat_task( + websocket_tx, + heartbeat_interval, + receive, + kill_receive, + ) + .await; + }); + + Self { + heartbeat_interval, + send, + handle, + } + } + + /// The main heartbeat task; + /// + /// Can be killed by the kill broadcast; + /// If the websocket is closed, will die out next time it tries to send a heartbeat; + pub async fn heartbeat_task( + websocket_tx: Arc< + Mutex< + SplitSink< + WebSocketStream>, + tokio_tungstenite::tungstenite::Message, + >, + >, + >, + heartbeat_interval: Duration, + mut receive: tokio::sync::mpsc::Receiver, + mut kill_receive: tokio::sync::broadcast::Receiver<()>, + ) { + let mut last_heartbeat_timestamp: Instant = time::Instant::now(); + let mut last_heartbeat_acknowledged = true; + let mut last_seq_number: Option = None; + + loop { + if kill_receive.try_recv().is_ok() { + trace!("GW: Closing heartbeat task"); + break; + } + + let timeout = if last_heartbeat_acknowledged { + heartbeat_interval + } else { + // If the server hasn't acknowledged our heartbeat we should resend it + Duration::from_millis(HEARTBEAT_ACK_TIMEOUT) + }; + + let mut should_send = false; + + tokio::select! { + () = sleep_until(last_heartbeat_timestamp + timeout) => { + should_send = true; + } + Some(communication) = receive.recv() => { + // If we received a seq number update, use that as the last seq number + if communication.sequence_number.is_some() { + last_seq_number = communication.sequence_number; + } + + if let Some(op_code) = communication.op_code { + match op_code { + GATEWAY_HEARTBEAT => { + // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately + should_send = true; + } + GATEWAY_HEARTBEAT_ACK => { + // The server received our heartbeat + last_heartbeat_acknowledged = true; + } + _ => {} + } + } + } + } + + if should_send { + trace!("GW: Sending Heartbeat.."); + + let heartbeat = types::GatewayHeartbeat { + op: GATEWAY_HEARTBEAT, + d: last_seq_number, + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + let send_result = websocket_tx.lock().await.send(msg).await; + if send_result.is_err() { + // We couldn't send, the websocket is broken + warn!("GW: Couldnt send heartbeat, websocket seems broken"); + break; + } + + last_heartbeat_timestamp = time::Instant::now(); + last_heartbeat_acknowledged = false; + } + } + } +} + +/// Used for communications between the heartbeat and gateway thread. +/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server +#[derive(Clone, Copy, Debug)] +pub(super) struct HeartbeatThreadCommunication { + /// The opcode for the communication we received, if relevant + pub(super) op_code: Option, + /// The sequence number we got from discord, if any + pub(super) sequence_number: Option, +} diff --git a/src/gateway/message.rs b/src/gateway/message.rs new file mode 100644 index 0000000..edee9dd --- /dev/null +++ b/src/gateway/message.rs @@ -0,0 +1,73 @@ +use crate::types; + +use super::*; + +/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError]. +/// This struct is used internally when handling messages. +#[derive(Clone, Debug)] +pub struct GatewayMessage { + /// The message we received from the server + pub(super) message: tokio_tungstenite::tungstenite::Message, +} + +impl GatewayMessage { + /// Creates self from a tungstenite message + pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self { + Self { message } + } + + /// Parses the message as an error; + /// Returns the error if succesfully parsed, None if the message isn't an error + pub fn error(&self) -> Option { + let content = self.message.to_string(); + + // Some error strings have dots on the end, which we don't care about + let processed_content = content.to_lowercase().replace('.', ""); + + match processed_content.as_str() { + "unknown error" | "4000" => Some(GatewayError::Unknown), + "unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode), + "decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode), + "not authenticated" | "4003" => Some(GatewayError::NotAuthenticated), + "authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed), + "already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated), + "invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber), + "rate limited" | "4008" => Some(GatewayError::RateLimited), + "session timed out" | "4009" => Some(GatewayError::SessionTimedOut), + "invalid shard" | "4010" => Some(GatewayError::InvalidShard), + "sharding required" | "4011" => Some(GatewayError::ShardingRequired), + "invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion), + "invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents), + "disallowed intent(s)" | "disallowed intents" | "4014" => { + Some(GatewayError::DisallowedIntents) + } + _ => None, + } + } + + /// Returns whether or not the message is an error + pub fn is_error(&self) -> bool { + self.error().is_some() + } + + /// Parses the message as a payload; + /// Returns a result of deserializing + pub fn payload(&self) -> Result { + return serde_json::from_str(self.message.to_text().unwrap()); + } + + /// Returns whether or not the message is a payload + pub fn is_payload(&self) -> bool { + // close messages are never payloads, payloads are only text messages + if self.message.is_close() | !self.message.is_text() { + return false; + } + + return self.payload().is_ok(); + } + + /// Returns whether or not the message is empty + pub fn is_empty(&self) -> bool { + self.message.is_empty() + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs new file mode 100644 index 0000000..ebd06cc --- /dev/null +++ b/src/gateway/mod.rs @@ -0,0 +1,187 @@ +pub mod gateway; +pub mod handle; +pub mod heartbeat; +pub mod message; + +pub use gateway::*; +pub use handle::*; +use heartbeat::*; +pub use message::*; + +use crate::errors::GatewayError; +use crate::types::{Snowflake, WebSocketEvent}; + +use async_trait::async_trait; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; +use std::time::Duration; +use tokio::time::sleep_until; + +use futures_util::stream::SplitSink; +use futures_util::stream::SplitStream; +use futures_util::SinkExt; +use futures_util::StreamExt; +use log::{info, trace, warn}; +use tokio::net::TcpStream; +use tokio::sync::mpsc::Sender; +use tokio::sync::Mutex; +use tokio::task; +use tokio::task::JoinHandle; +use tokio::time; +use tokio::time::Instant; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; + +// Gateway opcodes +/// Opcode received when the server dispatches a [crate::types::WebSocketEvent] +const GATEWAY_DISPATCH: u8 = 0; +/// Opcode sent when sending a heartbeat +const GATEWAY_HEARTBEAT: u8 = 1; +/// Opcode sent to initiate a session +/// +/// See [types::GatewayIdentifyPayload] +const GATEWAY_IDENTIFY: u8 = 2; +/// Opcode sent to update our presence +/// +/// See [types::GatewayUpdatePresence] +const GATEWAY_UPDATE_PRESENCE: u8 = 3; +/// Opcode sent to update our state in vc +/// +/// Like muting, deafening, leaving, joining.. +/// +/// See [types::UpdateVoiceState] +const GATEWAY_UPDATE_VOICE_STATE: u8 = 4; +/// Opcode sent to resume a session +/// +/// See [types::GatewayResume] +const GATEWAY_RESUME: u8 = 6; +/// Opcode received to tell the client to reconnect +const GATEWAY_RECONNECT: u8 = 7; +/// Opcode sent to request guild member data +/// +/// See [types::GatewayRequestGuildMembers] +const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8; +/// Opcode received to tell the client their token / session is invalid +const GATEWAY_INVALID_SESSION: u8 = 9; +/// Opcode received when initially connecting to the gateway, starts our heartbeat +/// +/// See [types::HelloData] +const GATEWAY_HELLO: u8 = 10; +/// Opcode received to acknowledge a heartbeat +const GATEWAY_HEARTBEAT_ACK: u8 = 11; +/// Opcode sent to get the voice state of users in a given DM/group channel +/// +/// See [types::CallSync] +const GATEWAY_CALL_SYNC: u8 = 13; +/// Opcode sent to get data for a server (Lazy Loading request) +/// +/// Sent by the official client when switching to a server +/// +/// See [types::LazyRequest] +const GATEWAY_LAZY_REQUEST: u8 = 14; + +pub type ObservableObject = dyn Send + Sync + Any; + +/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake. +pub trait Updateable: 'static + Send + Sync { + fn id(&self) -> Snowflake; +} + +/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to +/// an Observable. The Observer is notified when the Observable's data changes. +/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent. +/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing. +#[async_trait] +pub trait Observer: Sync + Send + std::fmt::Debug { + async fn update(&self, data: &T); +} + +/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a +/// change in the WebSocketEvent. GatewayEvents are observable. +#[derive(Default, Debug)] +pub struct GatewayEvent { + observers: Vec>>, +} + +impl GatewayEvent { + /// Returns true if the GatewayEvent is observed by at least one Observer. + pub fn is_observed(&self) -> bool { + !self.observers.is_empty() + } + + /// Subscribes an Observer to the GatewayEvent. + pub fn subscribe(&mut self, observable: Arc>) { + self.observers.push(observable); + } + + /// Unsubscribes an Observer from the GatewayEvent. + pub fn unsubscribe(&mut self, observable: &dyn Observer) { + // .retain()'s closure retains only those elements of the vector, which have a different + // pointer value than observable. + // The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same + // anddd there is no way to do that without using format + let to_remove = format!("{:?}", observable); + self.observers + .retain(|obs| format!("{:?}", obs) != to_remove); + } + + /// Notifies the observers of the GatewayEvent. + async fn notify(&self, new_event_data: T) { + for observer in &self.observers { + observer.update(&new_event_data).await; + } + } +} + +#[cfg(test)] +mod example { + use crate::types; + + use super::*; + use std::sync::atomic::{AtomicI32, Ordering::Relaxed}; + + #[derive(Debug)] + struct Consumer { + _name: String, + events_received: AtomicI32, + } + + #[async_trait] + impl Observer for Consumer { + async fn update(&self, _data: &types::GatewayResume) { + self.events_received.fetch_add(1, Relaxed); + } + } + + #[tokio::test] + async fn test_observer_behavior() { + let mut event = GatewayEvent::default(); + + let new_data = types::GatewayResume { + token: "token_3276ha37am3".to_string(), + session_id: "89346671230".to_string(), + seq: "3".to_string(), + }; + + let consumer = Arc::new(Consumer { + _name: "first".into(), + events_received: 0.into(), + }); + event.subscribe(consumer.clone()); + + let second_consumer = Arc::new(Consumer { + _name: "second".into(), + events_received: 0.into(), + }); + event.subscribe(second_consumer.clone()); + + event.notify(new_data.clone()).await; + event.unsubscribe(&*consumer); + event.notify(new_data).await; + + assert_eq!(consumer.events_received.load(Relaxed), 1); + assert_eq!(second_consumer.events_received.load(Relaxed), 2); + } +} diff --git a/src/types/config/types/security_configuration.rs b/src/types/config/types/security_configuration.rs index d025a4b..caeb72c 100644 --- a/src/types/config/types/security_configuration.rs +++ b/src/types/config/types/security_configuration.rs @@ -1,4 +1,5 @@ use base64::Engine; +use rand::Fill; use serde::{Deserialize, Serialize}; use crate::types::config::types::subconfigs::security::{ @@ -22,10 +23,15 @@ pub struct SecurityConfiguration { impl Default for SecurityConfiguration { fn default() -> Self { + let mut rng: rand::rngs::ThreadRng = rand::thread_rng(); let mut req_sig: [u8; 32] = [0; 32]; - let _ = openssl::rand::rand_bytes(&mut req_sig); let mut jwt_secret: [u8; 256] = [0; 256]; - let _ = openssl::rand::rand_bytes(&mut jwt_secret); + req_sig + .try_fill(&mut rng) + .expect("Unable to generate cryptographically safe secrets."); + jwt_secret + .try_fill(&mut rng) + .expect("Unable to generate cryptographically safe secrets."); Self { captcha: Default::default(), two_factor: Default::default(), diff --git a/src/types/entities/invite.rs b/src/types/entities/invite.rs index 842db3f..d08ad1d 100644 --- a/src/types/entities/invite.rs +++ b/src/types/entities/invite.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::types::{Snowflake, WelcomeScreenObject}; use super::guild::GuildScheduledEvent; -use super::{Application, Channel, GuildMember, User}; +use super::{Application, Channel, GuildMember, NSFWLevel, User}; /// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users. /// See @@ -56,17 +56,6 @@ pub struct InviteGuild { pub welcome_screen: Option, } -/// See for an explanation on what -/// the levels mean. -#[derive(Debug, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum NSFWLevel { - Default = 0, - Explicit = 1, - Safe = 2, - AgeRestricted = 3, -} - /// See #[derive(Debug, Serialize, Deserialize)] pub struct InviteStageInstance {