Merge main

This commit is contained in:
kozabrada123 2023-11-16 09:59:09 +01:00
commit 5334a26ccd
15 changed files with 849 additions and 614 deletions

View File

@ -4,7 +4,7 @@ on:
push: push:
branches: [ "main", "dev" ] branches: [ "main", "dev" ]
pull_request: pull_request:
branches: [ "main" ] branches: [ "main", "dev" ]
env: env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Clone spacebar server - name: Clone spacebar server
run: | run: |
git clone https://github.com/bitfl0wer/server.git git clone https://github.com/bitfl0wer/server.git

View File

@ -14,7 +14,7 @@ on:
branches: [ "main", "preserve/*", "dev" ] branches: [ "main", "preserve/*", "dev" ]
pull_request: pull_request:
# The branches below must be a subset of the branches above # The branches below must be a subset of the branches above
branches: [ "main" ] branches: [ "main", "dev" ]
jobs: jobs:
rust-clippy-analyze: rust-clippy-analyze:
@ -26,10 +26,10 @@ jobs:
actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v4
- name: Install Rust toolchain - name: Install Rust toolchain
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af #@v1 uses: actions-rs/toolchain@v1
with: with:
profile: minimal profile: minimal
toolchain: stable toolchain: stable

105
Cargo.lock generated
View File

@ -178,7 +178,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "chorus" name = "chorus"
version = "0.9.0" version = "0.11.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"base64 0.21.5", "base64 0.21.5",
@ -193,11 +193,12 @@ dependencies = [
"jsonwebtoken", "jsonwebtoken",
"lazy_static", "lazy_static",
"log", "log",
"native-tls",
"openssl",
"poem", "poem",
"rand",
"regex", "regex",
"reqwest", "reqwest",
"rustls",
"rustls-native-certs",
"rusty-hook", "rusty-hook",
"serde", "serde",
"serde-aux", "serde-aux",
@ -963,7 +964,7 @@ checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378"
dependencies = [ dependencies = [
"base64 0.21.5", "base64 0.21.5",
"pem", "pem",
"ring", "ring 0.16.20",
"serde", "serde",
"serde_json", "serde_json",
"simple_asn1", "simple_asn1",
@ -1584,11 +1585,25 @@ dependencies = [
"libc", "libc",
"once_cell", "once_cell",
"spin 0.5.2", "spin 0.5.2",
"untrusted", "untrusted 0.7.1",
"web-sys", "web-sys",
"winapi", "winapi",
] ]
[[package]]
name = "ring"
version = "0.17.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b"
dependencies = [
"cc",
"getrandom",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
"windows-sys",
]
[[package]] [[package]]
name = "rsa" name = "rsa"
version = "0.9.3" version = "0.9.3"
@ -1617,9 +1632,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.21" version = "0.38.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" checksum = "80109a168d9bc0c7f483083244543a6eb0dba02295d33ca268145e6190d6df0c"
dependencies = [ dependencies = [
"bitflags 2.4.1", "bitflags 2.4.1",
"errno", "errno",
@ -1628,6 +1643,49 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "rustls"
version = "0.21.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c"
dependencies = [
"log",
"ring 0.17.5",
"rustls-webpki",
"sct",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
dependencies = [
"base64 0.21.5",
]
[[package]]
name = "rustls-webpki"
version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring 0.17.5",
"untrusted 0.9.0",
]
[[package]] [[package]]
name = "rusty-hook" name = "rusty-hook"
version = "0.11.2" version = "0.11.2"
@ -1661,6 +1719,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sct"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
dependencies = [
"ring 0.16.20",
"untrusted 0.7.1",
]
[[package]] [[package]]
name = "security-framework" name = "security-framework"
version = "2.9.2" version = "2.9.2"
@ -2287,6 +2355,16 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls",
"tokio",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.14" version = "0.1.14"
@ -2306,9 +2384,10 @@ checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c"
dependencies = [ dependencies = [
"futures-util", "futures-util",
"log", "log",
"native-tls", "rustls",
"rustls-native-certs",
"tokio", "tokio",
"tokio-native-tls", "tokio-rustls",
"tungstenite", "tungstenite",
] ]
@ -2408,8 +2487,8 @@ dependencies = [
"http", "http",
"httparse", "httparse",
"log", "log",
"native-tls",
"rand", "rand",
"rustls",
"sha1", "sha1",
"thiserror", "thiserror",
"url", "url",
@ -2485,6 +2564,12 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "url" name = "url"
version = "2.4.1" version = "2.4.1"

View File

@ -1,7 +1,7 @@
[package] [package]
name = "chorus" name = "chorus"
description = "A library for interacting with multiple Spacebar-compatible Instances at once." description = "A library for interacting with multiple Spacebar-compatible Instances at once."
version = "0.9.0" version = "0.11.0"
license = "AGPL-3.0" license = "AGPL-3.0"
edition = "2021" edition = "2021"
repository = "https://github.com/polyphony-chat/chorus" repository = "https://github.com/polyphony-chat/chorus"
@ -27,11 +27,12 @@ url = "2.4.0"
chrono = { version = "0.4.26", features = ["serde"] } chrono = { version = "0.4.26", features = ["serde"] }
regex = "1.9.4" regex = "1.9.4"
custom_error = "1.9.2" custom_error = "1.9.2"
native-tls = "0.2.11" tokio-tungstenite = { version = "0.20.1", features = [
tokio-tungstenite = { version = "0.20.0", features = ["native-tls"] } "rustls-tls-native-roots",
"rustls-native-certs",
] }
futures-util = "0.3.28" futures-util = "0.3.28"
http = "0.2.9" http = "0.2.9"
openssl = "0.10.56"
base64 = "0.21.3" base64 = "0.21.3"
hostname = "0.3.1" hostname = "0.3.1"
bitflags = { version = "2.4.0", features = ["serde"] } bitflags = { version = "2.4.0", features = ["serde"] }
@ -52,6 +53,9 @@ log = "0.4.20"
async-trait = "0.1.73" async-trait = "0.1.73"
chorus-macros = "0.2.0" chorus-macros = "0.2.0"
discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] } discortp = { version = "0.5.0", optional = true, features = ["rtp", "discord", "demux"] }
rustls = "0.21.8"
rustls-native-certs = "0.6.3"
rand = "0.8.5"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.32.0", features = ["full"] } tokio = { version = "1.32.0", features = ["full"] }

View File

@ -4,8 +4,6 @@
[![Build][build-shield]][build-url] [![Build][build-shield]][build-url]
[![Coverage][coverage-shield]][coverage-url] [![Coverage][coverage-shield]][coverage-url]
[![Contributors][contributors-shield]][contributors-url] [![Contributors][contributors-shield]][contributors-url]
[![Forks][forks-shield]][forks-url]
[![Issues][issues-shield]][issues-url]
<img src="https://img.shields.io/static/v1?label=Status&message=Alpha&color=blue"> <img src="https://img.shields.io/static/v1?label=Status&message=Alpha&color=blue">
</br> </br>
@ -17,7 +15,6 @@
<h3 align="center">Chorus</h3> <h3 align="center">Chorus</h3>
<p align="center"> <p align="center">
A rust library for interacting with (multiple) Spacebar-compatible APIs and Gateways (at the same time).
<br /> <br />
<a href="https://github.com/polyphony-chat/chorus"><strong>Explore the docs »</strong></a> <a href="https://github.com/polyphony-chat/chorus"><strong>Explore the docs »</strong></a>
<br /> <br />
@ -32,16 +29,95 @@
</div> </div>
## About Chorus is a Rust library that allows developers to interact with multiple Spacebar-compatible APIs and Gateways (Including
Discord.com) simultaneously. The library provides a simple and efficient way to communicate with these services, making it easier for developers to build applications that rely on them. Chorus is open-source and welcomes contributions from the community.
Chorus is a Rust library that allows developers to interact with multiple Spacebar-compatible APIs and Gateways simultaneously. The library provides a simple and efficient way to communicate with these services, making it easier for developers to build applications that rely on them. Chorus is open-source and welcomes contributions from the community. ## A Tour of Chorus
Chorus combines all the required functionalities of a user-centric Spacebar library into one package. The library
handles a lot of things for you, such as rate limiting, authentication, and more. This means that you can focus on
building your application, instead of worrying about the underlying implementation details.
To get started with Chorus, import it into your project by adding the following to your `Cargo.toml` file:
```toml
[dependencies]
chorus = "0"
```
### Establishing a Connection
To connect to a Spacebar compatible server, you need to create an [`Instance`](https://docs.rs/chorus/latest/chorus/instance/struct.Instance.html) like this:
```rs
use chorus::instance::Instance;
use chorus::UrlBundle;
#[tokio::main]
async fn main() {
let bundle = UrlBundle::new(
"https://example.com/api".to_string(),
"wss://example.com/".to_string(),
"https://example.com/cdn".to_string(),
);
let instance = Instance::new(bundle, true)
.await
.expect("Failed to connect to the Spacebar server");
// You can create as many instances of `Instance` as you want, but each `Instance` should likely be unique.
dbg!(instance.instance_info);
dbg!(instance.limits_information);
}
```
This Instance can now be used to log in, register and from there on, interact with the server in all sorts of ways.
### Logging In
Logging in correctly provides you with an instance of `ChorusUser`, with which you can interact with the server and
manipulate the account. Assuming you already have an account on the server, you can log in like this:
```rs
use chorus::types::LoginSchema;
// Assume, you already have an account created on this instance. Registering an account works
// the same way, but you'd use the Register-specific Structs and methods instead.
let login_schema = LoginSchema {
login: "user@example.com".to_string(),
password: "Correct-Horse-Battery-Staple".to_string(),
..Default::default()
};
// Each user connects to the Gateway. The Gateway connection lives on a seperate thread. Depending on
// the runtime feature you choose, this can potentially take advantage of all of your computers' threads.
let user = instance
.login_account(login_schema)
.await
.expect("An error occurred during the login process");
dbg!(user.belongs_to);
dbg!(&user.object.read().unwrap().username);
```
## Supported Platforms
All major desktop operating systems (Windows, macOS (aarch64/x86_64), Linux (aarch64/x86_64)) are supported.
We are currently working on adding full support for `wasm32-unknown-unknown`. This will allow you to use Chorus in
your browser, or in any other environment that supports WebAssembly.
We recommend checking out the examples directory, as well as the documentation for more information.
## MSRV (Minimum Supported Rust Version)
Rust **1.67.1**. This number might change at any point while Chorus is not yet at version 1.0.0.
## Versioning
This crate uses Semantic Versioning 2.0.0 as its versioning scheme. You can read the specification [here](https://semver.org/spec/v2.0.0.html).
## Contributing ## Contributing
Chorus is currently missing voice support and a lot of API endpoints, many of which should be trivial to implement,
ever since [we streamlined the process of doing so](https://github.com/polyphony-chat/chorus/discussions/401).
If you'd like to contribute new functionality, check out [The 'Meta'-issues.](https://github.com/polyphony-chat/chorus/issues?q=is%3Aissue+label%3A%22Type%3A+Meta%22+) They contain a comprehensive list of all features which are yet missing for full Discord.com compatibility. If you'd like to contribute new functionality, check out [The 'Meta'-issues.](https://github.com/polyphony-chat/chorus/issues?q=is%3Aissue+label%3A%22Type%3A+Meta%22+) They contain a comprehensive list of all features which are yet missing for full Discord.com compatibility.
If you would like to contribute, please feel free to open an Issue with the idea you have, or a Please feel free to open an Issue with the idea you have, or a Pull Request. Please keep our [contribution guidelines](https://github.com/polyphony-chat/.github/blob/main/CONTRIBUTION_GUIDELINES.md) in mind. Your contribution might not be accepted if it violates these guidelines or [our Code of Conduct](https://github.com/polyphony-chat/.github/blob/main/CODE_OF_CONDUCT.md).
Pull Request. Please keep our [contribution guidelines](https://github.com/polyphony-chat/.github/blob/main/CONTRIBUTION_GUIDELINES.md) in mind. Your contribution might not be
accepted, if it violates these guidelines or [our Code of Conduct](https://github.com/polyphony-chat/.github/blob/main/CODE_OF_CONDUCT.md).
<details> <details>
<summary>Progress Tracker/Roadmap</summary> <summary>Progress Tracker/Roadmap</summary>

16
examples/instance.rs Normal file
View File

@ -0,0 +1,16 @@
use chorus::instance::Instance;
use chorus::UrlBundle;
#[tokio::main]
async fn main() {
let bundle = UrlBundle::new(
"https://example.com/api".to_string(),
"wss://example.com/".to_string(),
"https://example.com/cdn".to_string(),
);
let instance = Instance::new(bundle, true)
.await
.expect("Failed to connect to the Spacebar server");
dbg!(instance.instance_info);
dbg!(instance.limits_information);
}

30
examples/login.rs Normal file
View File

@ -0,0 +1,30 @@
use chorus::instance::Instance;
use chorus::types::LoginSchema;
use chorus::UrlBundle;
#[tokio::main]
async fn main() {
let bundle = UrlBundle::new(
"https://example.com/api".to_string(),
"wss://example.com/".to_string(),
"https://example.com/cdn".to_string(),
);
let instance = Instance::new(bundle, true)
.await
.expect("Failed to connect to the Spacebar server");
// Assume, you already have an account created on this instance. Registering an account works
// the same way, but you'd use the Register-specific Structs and methods instead.
let login_schema = LoginSchema {
login: "user@example.com".to_string(),
password: "Correct-Horse-Battery-Staple".to_string(),
..Default::default()
};
// Each user connects to the Gateway. The Gateway connection lives on a seperate thread. Depending on
// the runtime feature you choose, this can potentially take advantage of all of your computers' threads.
let user = instance
.login_account(login_schema)
.await
.expect("An error occurred during the login process");
dbg!(user.belongs_to);
dbg!(&user.object.read().unwrap().username);
}

View File

@ -28,11 +28,11 @@ impl ChorusUser {
.header("Authorization", self.token()), .header("Authorization", self.token()),
limit_type: LimitType::Global, limit_type: LimitType::Global,
}; };
if session_id.is_some() { if let Some(session_id) = session_id {
request.request = request request.request = request
.request .request
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(to_string(session_id.unwrap()).unwrap()); .body(to_string(session_id).unwrap());
} }
request.deserialize_response::<Invite>(self).await request.deserialize_response::<Invite>(self).await
} }

View File

@ -1,332 +1,10 @@
//! Gateway connection, communication and handling, as well as object caching and updating. use self::event::Events;
use super::*;
use crate::errors::GatewayError;
use crate::gateway::events::Events;
use crate::types::{ use crate::types::{
self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete, self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete,
ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, ChannelUpdate, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, SourceUrlField,
Snowflake, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent, ThreadUpdate, UpdateMessage, WebSocketEvent,
}; };
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use log::{info, trace, warn};
use native_tls::TlsConnector;
use tokio::net::TcpStream;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tokio::task;
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
// Gateway opcodes
/// Opcode received when the server dispatches a [crate::types::WebSocketEvent]
const GATEWAY_DISPATCH: u8 = 0;
/// Opcode sent when sending a heartbeat
const GATEWAY_HEARTBEAT: u8 = 1;
/// Opcode sent to initiate a session
///
/// See [types::GatewayIdentifyPayload]
const GATEWAY_IDENTIFY: u8 = 2;
/// Opcode sent to update our presence
///
/// See [types::GatewayUpdatePresence]
const GATEWAY_UPDATE_PRESENCE: u8 = 3;
/// Opcode sent to update our state in vc
///
/// Like muting, deafening, leaving, joining..
///
/// See [types::UpdateVoiceState]
const GATEWAY_UPDATE_VOICE_STATE: u8 = 4;
/// Opcode sent to resume a session
///
/// See [types::GatewayResume]
const GATEWAY_RESUME: u8 = 6;
/// Opcode received to tell the client to reconnect
const GATEWAY_RECONNECT: u8 = 7;
/// Opcode sent to request guild member data
///
/// See [types::GatewayRequestGuildMembers]
const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8;
/// Opcode received to tell the client their token / session is invalid
const GATEWAY_INVALID_SESSION: u8 = 9;
/// Opcode received when initially connecting to the gateway, starts our heartbeat
///
/// See [types::HelloData]
const GATEWAY_HELLO: u8 = 10;
/// Opcode received to acknowledge a heartbeat
const GATEWAY_HEARTBEAT_ACK: u8 = 11;
/// Opcode sent to get the voice state of users in a given DM/group channel
///
/// See [types::CallSync]
const GATEWAY_CALL_SYNC: u8 = 13;
/// Opcode sent to get data for a server (Lazy Loading request)
///
/// Sent by the official client when switching to a server
///
/// See [types::LazyRequest]
const GATEWAY_LAZY_REQUEST: u8 = 14;
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
pub(crate) const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages.
#[derive(Clone, Debug)]
pub struct GatewayMessage {
/// The message we received from the server
message: tokio_tungstenite::tungstenite::Message,
}
impl GatewayMessage {
/// Creates self from a tungstenite message
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
Self { message }
}
/// Parses the message as an error;
/// Returns the error if succesfully parsed, None if the message isn't an error
pub fn error(&self) -> Option<GatewayError> {
let content = self.message.to_string();
// Some error strings have dots on the end, which we don't care about
let processed_content = content.to_lowercase().replace('.', "");
match processed_content.as_str() {
"unknown error" | "4000" => Some(GatewayError::Unknown),
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode),
"decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode),
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticated),
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed),
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated),
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber),
"rate limited" | "4008" => Some(GatewayError::RateLimited),
"session timed out" | "4009" => Some(GatewayError::SessionTimedOut),
"invalid shard" | "4010" => Some(GatewayError::InvalidShard),
"sharding required" | "4011" => Some(GatewayError::ShardingRequired),
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion),
"invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents),
"disallowed intent(s)" | "disallowed intents" | "4014" => {
Some(GatewayError::DisallowedIntents)
}
_ => None,
}
}
/// Returns whether or not the message is an error
pub fn is_error(&self) -> bool {
self.error().is_some()
}
/// Parses the message as a payload;
/// Returns a result of deserializing
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
return serde_json::from_str(self.message.to_text().unwrap());
}
/// Returns whether or not the message is a payload
pub fn is_payload(&self) -> bool {
// close messages are never payloads, payloads are only text messages
if self.message.is_close() | !self.message.is_text() {
return false;
}
return self.payload().is_ok();
}
/// Returns whether or not the message is empty
pub fn is_empty(&self) -> bool {
self.message.is_empty()
}
}
pub type ObservableObject = dyn Send + Sync + Any;
/// Represents a handle to a Gateway connection. A Gateway connection will create observable
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
/// implemented types with the trait [`WebSocketEvent`]
/// Using this handle you can also send Gateway Events directly.
#[derive(Debug, Clone)]
pub struct GatewayHandle {
pub url: String,
pub events: Arc<Mutex<Events>>,
pub websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
/// Tells gateway tasks to close
kill_send: tokio::sync::broadcast::Sender<()>,
pub(crate) store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
}
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
pub trait Updateable: 'static + Send + Sync {
fn id(&self) -> Snowflake;
}
impl GatewayHandle {
/// Sends json to the gateway with an opcode
async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) {
let gateway_payload = types::GatewaySendPayload {
op_code,
event_data: Some(to_send),
sequence_number: None,
};
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
self.websocket_send
.lock()
.await
.send(message)
.await
.unwrap();
}
pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> Arc<RwLock<T>> {
let mut store = self.store.lock().await;
let id = object.read().unwrap().id();
if let Some(channel) = store.get(&id) {
let object = channel.clone();
drop(store);
object
.read()
.unwrap()
.downcast_ref::<T>()
.unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
id
)
});
let ptr = Arc::into_raw(object.clone());
// SAFETY:
// - We have just checked that the typeid of the `dyn Any ...` matches that of `T`.
// - This operation doesn't read or write any shared data, and thus cannot cause a data race
// - The reference count is not being modified
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() };
let object = downcasted.read().unwrap().clone();
let watched_object = object.watch_whole(self).await;
*downcasted.write().unwrap() = watched_object;
downcasted
} else {
let id = object.read().unwrap().id();
let object = object.read().unwrap().clone();
let object = object.clone().watch_whole(self).await;
let wrapped = Arc::new(RwLock::new(object));
store.insert(id, wrapped.clone());
wrapped
}
}
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
/// with all of its observable fields being observed.
pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> T {
let channel = self.observe(object.clone()).await;
let object = channel.read().unwrap().clone();
object
}
/// Sends an identify event to the gateway
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Identify..");
self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await;
}
/// Sends a resume event to the gateway
pub async fn send_resume(&self, to_send: types::GatewayResume) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Resume..");
self.send_json_event(GATEWAY_RESUME, to_send_value).await;
}
/// Sends an update presence event to the gateway
pub async fn send_update_presence(&self, to_send: types::UpdatePresence) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Update Presence..");
self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value)
.await;
}
/// Sends a request guild members to the server
pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Request Guild Members..");
self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value)
.await;
}
/// Sends an update voice state to the server
pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) {
let to_send_value = serde_json::to_value(to_send).unwrap();
trace!("GW: Sending Update Voice State..");
self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value)
.await;
}
/// Sends a call sync to the server
pub async fn send_call_sync(&self, to_send: types::CallSync) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Call Sync..");
self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await;
}
/// Sends a Lazy Request
pub async fn send_lazy_request(&self, to_send: types::LazyRequest) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Lazy Request..");
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
.await;
}
/// Closes the websocket connection and stops all gateway tasks;
///
/// Esentially pulls the plug on the gateway, leaving it possible to resume;
pub async fn close(&self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Gateway { pub struct Gateway {
@ -349,12 +27,21 @@ pub struct Gateway {
impl Gateway { impl Gateway {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub async fn new(websocket_url: String) -> Result<GatewayHandle, GatewayError> { pub async fn new(websocket_url: String) -> Result<GatewayHandle, GatewayError> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(&rustls::Certificate(cert.0)).unwrap();
}
let (websocket_stream, _) = match connect_async_tls_with_config( let (websocket_stream, _) = match connect_async_tls_with_config(
&websocket_url, &websocket_url,
None, None,
false, false,
Some(Connector::NativeTls( Some(Connector::Rustls(
TlsConnector::builder().build().unwrap(), rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)), )),
) )
.await .await
@ -701,10 +388,10 @@ impl Gateway {
| GATEWAY_REQUEST_GUILD_MEMBERS | GATEWAY_REQUEST_GUILD_MEMBERS
| GATEWAY_CALL_SYNC | GATEWAY_CALL_SYNC
| GATEWAY_LAZY_REQUEST => { | GATEWAY_LAZY_REQUEST => {
let error = GatewayError::UnexpectedOpcodeReceived { info!(
opcode: gateway_payload.op_code, "Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.",
}; gateway_payload.op_code
Err::<(), GatewayError>(error).unwrap(); );
} }
_ => { _ => {
warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
@ -728,196 +415,7 @@ impl Gateway {
} }
} }
/// Handles sending heartbeats to the gateway in another thread pub mod event {
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
#[derive(Debug)]
struct HeartbeatHandler {
/// How ofter heartbeats need to be sent at a minimum
pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread
pub send: Sender<HeartbeatThreadCommunication>,
/// The handle of the thread
handle: JoinHandle<()>,
}
impl HeartbeatHandler {
pub fn new(
heartbeat_interval: Duration,
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move {
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
});
Self {
heartbeat_interval,
send,
handle,
}
}
/// The main heartbeat task;
///
/// Can be killed by the kill broadcast;
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
pub async fn heartbeat_task(
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
}
/// Used for communications between the heartbeat and gateway thread.
/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server
#[derive(Clone, Copy, Debug)]
struct HeartbeatThreadCommunication {
/// The opcode for the communication we received, if relevant
op_code: Option<u8>,
/// The sequence number we got from discord, if any
sequence_number: Option<u64>,
}
/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
/// an Observable. The Observer is notified when the Observable's data changes.
/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing.
#[async_trait]
pub trait Observer<T>: Sync + Send + std::fmt::Debug {
async fn update(&self, data: &T);
}
/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
/// change in the WebSocketEvent. GatewayEvents are observable.
#[derive(Default, Debug)]
pub struct GatewayEvent<T: WebSocketEvent> {
observers: Vec<Arc<dyn Observer<T>>>,
}
impl<T: WebSocketEvent> GatewayEvent<T> {
/// Returns true if the GatewayEvent is observed by at least one Observer.
pub fn is_observed(&self) -> bool {
!self.observers.is_empty()
}
/// Subscribes an Observer to the GatewayEvent.
pub fn subscribe(&mut self, observable: Arc<dyn Observer<T>>) {
self.observers.push(observable);
}
/// Unsubscribes an Observer from the GatewayEvent.
pub fn unsubscribe(&mut self, observable: &dyn Observer<T>) {
// .retain()'s closure retains only those elements of the vector, which have a different
// pointer value than observable.
// The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
// anddd there is no way to do that without using format
let to_remove = format!("{:?}", observable);
self.observers
.retain(|obs| format!("{:?}", obs) != to_remove);
}
/// Notifies the observers of the GatewayEvent.
pub async fn notify(&self, new_event_data: T) {
for observer in &self.observers {
observer.update(&new_event_data).await;
}
}
}
pub mod events {
use super::*; use super::*;
#[derive(Default, Debug)] #[derive(Default, Debug)]
@ -1078,52 +576,3 @@ pub mod events {
pub update: GatewayEvent<types::WebhooksUpdate>, pub update: GatewayEvent<types::WebhooksUpdate>,
} }
} }
#[cfg(test)]
mod example {
use super::*;
use std::sync::atomic::{AtomicI32, Ordering::Relaxed};
#[derive(Debug)]
struct Consumer {
_name: String,
events_received: AtomicI32,
}
#[async_trait]
impl Observer<types::GatewayResume> for Consumer {
async fn update(&self, _data: &types::GatewayResume) {
self.events_received.fetch_add(1, Relaxed);
}
}
#[tokio::test]
async fn test_observer_behavior() {
let mut event = GatewayEvent::default();
let new_data = types::GatewayResume {
token: "token_3276ha37am3".to_string(),
session_id: "89346671230".to_string(),
seq: "3".to_string(),
};
let consumer = Arc::new(Consumer {
_name: "first".into(),
events_received: 0.into(),
});
event.subscribe(consumer.clone());
let second_consumer = Arc::new(Consumer {
_name: "second".into(),
events_received: 0.into(),
});
event.subscribe(second_consumer.clone());
event.notify(new_data.clone()).await;
event.unsubscribe(&*consumer);
event.notify(new_data).await;
assert_eq!(consumer.events_received.load(Relaxed), 1);
assert_eq!(second_consumer.events_received.load(Relaxed), 2);
}
}

171
src/gateway/handle.rs Normal file
View File

@ -0,0 +1,171 @@
use super::{event::Events, *};
use crate::types::{self, Composite};
/// Represents a handle to a Gateway connection. A Gateway connection will create observable
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
/// implemented types with the trait [`WebSocketEvent`]
/// Using this handle you can also send Gateway Events directly.
#[derive(Debug, Clone)]
pub struct GatewayHandle {
pub url: String,
pub events: Arc<Mutex<Events>>,
pub websocket_send: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
/// Tells gateway tasks to close
pub(super) kill_send: tokio::sync::broadcast::Sender<()>,
pub(crate) store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
}
impl GatewayHandle {
/// Sends json to the gateway with an opcode
async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) {
let gateway_payload = types::GatewaySendPayload {
op_code,
event_data: Some(to_send),
sequence_number: None,
};
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
self.websocket_send
.lock()
.await
.send(message)
.await
.unwrap();
}
pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> Arc<RwLock<T>> {
let mut store = self.store.lock().await;
let id = object.read().unwrap().id();
if let Some(channel) = store.get(&id) {
let object = channel.clone();
drop(store);
object
.read()
.unwrap()
.downcast_ref::<T>()
.unwrap_or_else(|| {
panic!(
"Snowflake {} already exists in the store, but it is not of type T.",
id
)
});
let ptr = Arc::into_raw(object.clone());
// SAFETY:
// - We have just checked that the typeid of the `dyn Any ...` matches that of `T`.
// - This operation doesn't read or write any shared data, and thus cannot cause a data race
// - The reference count is not being modified
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() };
let object = downcasted.read().unwrap().clone();
let watched_object = object.watch_whole(self).await;
*downcasted.write().unwrap() = watched_object;
downcasted
} else {
let id = object.read().unwrap().id();
let object = object.read().unwrap().clone();
let object = object.clone().watch_whole(self).await;
let wrapped = Arc::new(RwLock::new(object));
store.insert(id, wrapped.clone());
wrapped
}
}
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
/// with all of its observable fields being observed.
pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
&self,
object: Arc<RwLock<T>>,
) -> T {
let channel = self.observe(object.clone()).await;
let object = channel.read().unwrap().clone();
object
}
/// Sends an identify event to the gateway
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Identify..");
self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await;
}
/// Sends a resume event to the gateway
pub async fn send_resume(&self, to_send: types::GatewayResume) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Resume..");
self.send_json_event(GATEWAY_RESUME, to_send_value).await;
}
/// Sends an update presence event to the gateway
pub async fn send_update_presence(&self, to_send: types::UpdatePresence) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Update Presence..");
self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value)
.await;
}
/// Sends a request guild members to the server
pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Request Guild Members..");
self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value)
.await;
}
/// Sends an update voice state to the server
pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) {
let to_send_value = serde_json::to_value(to_send).unwrap();
trace!("GW: Sending Update Voice State..");
self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value)
.await;
}
/// Sends a call sync to the server
pub async fn send_call_sync(&self, to_send: types::CallSync) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Call Sync..");
self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await;
}
/// Sends a Lazy Request
pub async fn send_lazy_request(&self, to_send: types::LazyRequest) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
trace!("GW: Sending Lazy Request..");
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
.await;
}
/// Closes the websocket connection and stops all gateway tasks;
///
/// Esentially pulls the plug on the gateway, leaving it possible to resume;
pub async fn close(&self) {
self.kill_send.send(()).unwrap();
self.websocket_send.lock().await.close().await.unwrap();
}
}

149
src/gateway/heartbeat.rs Normal file
View File

@ -0,0 +1,149 @@
use crate::types;
use super::*;
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
/// Handles sending heartbeats to the gateway in another thread
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
#[derive(Debug)]
pub(super) struct HeartbeatHandler {
/// How ofter heartbeats need to be sent at a minimum
pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread
pub send: Sender<HeartbeatThreadCommunication>,
/// The handle of the thread
handle: JoinHandle<()>,
}
impl HeartbeatHandler {
pub fn new(
heartbeat_interval: Duration,
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
kill_rc: tokio::sync::broadcast::Receiver<()>,
) -> HeartbeatHandler {
let (send, receive) = tokio::sync::mpsc::channel(32);
let kill_receive = kill_rc.resubscribe();
let handle: JoinHandle<()> = task::spawn(async move {
HeartbeatHandler::heartbeat_task(
websocket_tx,
heartbeat_interval,
receive,
kill_receive,
)
.await;
});
Self {
heartbeat_interval,
send,
handle,
}
}
/// The main heartbeat task;
///
/// Can be killed by the kill broadcast;
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
pub async fn heartbeat_task(
websocket_tx: Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
>,
>,
heartbeat_interval: Duration,
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
let mut last_heartbeat_acknowledged = true;
let mut last_seq_number: Option<u64> = None;
loop {
if kill_receive.try_recv().is_ok() {
trace!("GW: Closing heartbeat task");
break;
}
let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let mut should_send = false;
tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
Some(communication) = receive.recv() => {
// If we received a seq number update, use that as the last seq number
if communication.sequence_number.is_some() {
last_seq_number = communication.sequence_number;
}
if let Some(op_code) = communication.op_code {
match op_code {
GATEWAY_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true;
}
GATEWAY_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
}
}
}
}
if should_send {
trace!("GW: Sending Heartbeat..");
let heartbeat = types::GatewayHeartbeat {
op: GATEWAY_HEARTBEAT,
d: last_seq_number,
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() {
// We couldn't send, the websocket is broken
warn!("GW: Couldnt send heartbeat, websocket seems broken");
break;
}
last_heartbeat_timestamp = time::Instant::now();
last_heartbeat_acknowledged = false;
}
}
}
}
/// Used for communications between the heartbeat and gateway thread.
/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server
#[derive(Clone, Copy, Debug)]
pub(super) struct HeartbeatThreadCommunication {
/// The opcode for the communication we received, if relevant
pub(super) op_code: Option<u8>,
/// The sequence number we got from discord, if any
pub(super) sequence_number: Option<u64>,
}

73
src/gateway/message.rs Normal file
View File

@ -0,0 +1,73 @@
use crate::types;
use super::*;
/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
/// This struct is used internally when handling messages.
#[derive(Clone, Debug)]
pub struct GatewayMessage {
/// The message we received from the server
pub(super) message: tokio_tungstenite::tungstenite::Message,
}
impl GatewayMessage {
/// Creates self from a tungstenite message
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
Self { message }
}
/// Parses the message as an error;
/// Returns the error if succesfully parsed, None if the message isn't an error
pub fn error(&self) -> Option<GatewayError> {
let content = self.message.to_string();
// Some error strings have dots on the end, which we don't care about
let processed_content = content.to_lowercase().replace('.', "");
match processed_content.as_str() {
"unknown error" | "4000" => Some(GatewayError::Unknown),
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode),
"decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode),
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticated),
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed),
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated),
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber),
"rate limited" | "4008" => Some(GatewayError::RateLimited),
"session timed out" | "4009" => Some(GatewayError::SessionTimedOut),
"invalid shard" | "4010" => Some(GatewayError::InvalidShard),
"sharding required" | "4011" => Some(GatewayError::ShardingRequired),
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion),
"invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents),
"disallowed intent(s)" | "disallowed intents" | "4014" => {
Some(GatewayError::DisallowedIntents)
}
_ => None,
}
}
/// Returns whether or not the message is an error
pub fn is_error(&self) -> bool {
self.error().is_some()
}
/// Parses the message as a payload;
/// Returns a result of deserializing
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
return serde_json::from_str(self.message.to_text().unwrap());
}
/// Returns whether or not the message is a payload
pub fn is_payload(&self) -> bool {
// close messages are never payloads, payloads are only text messages
if self.message.is_close() | !self.message.is_text() {
return false;
}
return self.payload().is_ok();
}
/// Returns whether or not the message is empty
pub fn is_empty(&self) -> bool {
self.message.is_empty()
}
}

187
src/gateway/mod.rs Normal file
View File

@ -0,0 +1,187 @@
pub mod gateway;
pub mod handle;
pub mod heartbeat;
pub mod message;
pub use gateway::*;
pub use handle::*;
use heartbeat::*;
pub use message::*;
use crate::errors::GatewayError;
use crate::types::{Snowflake, WebSocketEvent};
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::time::sleep_until;
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use log::{info, trace, warn};
use tokio::net::TcpStream;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tokio::task;
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
// Gateway opcodes
/// Opcode received when the server dispatches a [crate::types::WebSocketEvent]
const GATEWAY_DISPATCH: u8 = 0;
/// Opcode sent when sending a heartbeat
const GATEWAY_HEARTBEAT: u8 = 1;
/// Opcode sent to initiate a session
///
/// See [types::GatewayIdentifyPayload]
const GATEWAY_IDENTIFY: u8 = 2;
/// Opcode sent to update our presence
///
/// See [types::GatewayUpdatePresence]
const GATEWAY_UPDATE_PRESENCE: u8 = 3;
/// Opcode sent to update our state in vc
///
/// Like muting, deafening, leaving, joining..
///
/// See [types::UpdateVoiceState]
const GATEWAY_UPDATE_VOICE_STATE: u8 = 4;
/// Opcode sent to resume a session
///
/// See [types::GatewayResume]
const GATEWAY_RESUME: u8 = 6;
/// Opcode received to tell the client to reconnect
const GATEWAY_RECONNECT: u8 = 7;
/// Opcode sent to request guild member data
///
/// See [types::GatewayRequestGuildMembers]
const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8;
/// Opcode received to tell the client their token / session is invalid
const GATEWAY_INVALID_SESSION: u8 = 9;
/// Opcode received when initially connecting to the gateway, starts our heartbeat
///
/// See [types::HelloData]
const GATEWAY_HELLO: u8 = 10;
/// Opcode received to acknowledge a heartbeat
const GATEWAY_HEARTBEAT_ACK: u8 = 11;
/// Opcode sent to get the voice state of users in a given DM/group channel
///
/// See [types::CallSync]
const GATEWAY_CALL_SYNC: u8 = 13;
/// Opcode sent to get data for a server (Lazy Loading request)
///
/// Sent by the official client when switching to a server
///
/// See [types::LazyRequest]
const GATEWAY_LAZY_REQUEST: u8 = 14;
pub type ObservableObject = dyn Send + Sync + Any;
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
pub trait Updateable: 'static + Send + Sync {
fn id(&self) -> Snowflake;
}
/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
/// an Observable. The Observer is notified when the Observable's data changes.
/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing.
#[async_trait]
pub trait Observer<T>: Sync + Send + std::fmt::Debug {
async fn update(&self, data: &T);
}
/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
/// change in the WebSocketEvent. GatewayEvents are observable.
#[derive(Default, Debug)]
pub struct GatewayEvent<T: WebSocketEvent> {
observers: Vec<Arc<dyn Observer<T>>>,
}
impl<T: WebSocketEvent> GatewayEvent<T> {
/// Returns true if the GatewayEvent is observed by at least one Observer.
pub fn is_observed(&self) -> bool {
!self.observers.is_empty()
}
/// Subscribes an Observer to the GatewayEvent.
pub fn subscribe(&mut self, observable: Arc<dyn Observer<T>>) {
self.observers.push(observable);
}
/// Unsubscribes an Observer from the GatewayEvent.
pub fn unsubscribe(&mut self, observable: &dyn Observer<T>) {
// .retain()'s closure retains only those elements of the vector, which have a different
// pointer value than observable.
// The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
// anddd there is no way to do that without using format
let to_remove = format!("{:?}", observable);
self.observers
.retain(|obs| format!("{:?}", obs) != to_remove);
}
/// Notifies the observers of the GatewayEvent.
async fn notify(&self, new_event_data: T) {
for observer in &self.observers {
observer.update(&new_event_data).await;
}
}
}
#[cfg(test)]
mod example {
use crate::types;
use super::*;
use std::sync::atomic::{AtomicI32, Ordering::Relaxed};
#[derive(Debug)]
struct Consumer {
_name: String,
events_received: AtomicI32,
}
#[async_trait]
impl Observer<types::GatewayResume> for Consumer {
async fn update(&self, _data: &types::GatewayResume) {
self.events_received.fetch_add(1, Relaxed);
}
}
#[tokio::test]
async fn test_observer_behavior() {
let mut event = GatewayEvent::default();
let new_data = types::GatewayResume {
token: "token_3276ha37am3".to_string(),
session_id: "89346671230".to_string(),
seq: "3".to_string(),
};
let consumer = Arc::new(Consumer {
_name: "first".into(),
events_received: 0.into(),
});
event.subscribe(consumer.clone());
let second_consumer = Arc::new(Consumer {
_name: "second".into(),
events_received: 0.into(),
});
event.subscribe(second_consumer.clone());
event.notify(new_data.clone()).await;
event.unsubscribe(&*consumer);
event.notify(new_data).await;
assert_eq!(consumer.events_received.load(Relaxed), 1);
assert_eq!(second_consumer.events_received.load(Relaxed), 2);
}
}

View File

@ -1,4 +1,5 @@
use base64::Engine; use base64::Engine;
use rand::Fill;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::types::config::types::subconfigs::security::{ use crate::types::config::types::subconfigs::security::{
@ -22,10 +23,15 @@ pub struct SecurityConfiguration {
impl Default for SecurityConfiguration { impl Default for SecurityConfiguration {
fn default() -> Self { fn default() -> Self {
let mut rng: rand::rngs::ThreadRng = rand::thread_rng();
let mut req_sig: [u8; 32] = [0; 32]; let mut req_sig: [u8; 32] = [0; 32];
let _ = openssl::rand::rand_bytes(&mut req_sig);
let mut jwt_secret: [u8; 256] = [0; 256]; let mut jwt_secret: [u8; 256] = [0; 256];
let _ = openssl::rand::rand_bytes(&mut jwt_secret); req_sig
.try_fill(&mut rng)
.expect("Unable to generate cryptographically safe secrets.");
jwt_secret
.try_fill(&mut rng)
.expect("Unable to generate cryptographically safe secrets.");
Self { Self {
captcha: Default::default(), captcha: Default::default(),
two_factor: Default::default(), two_factor: Default::default(),

View File

@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
use crate::types::{Snowflake, WelcomeScreenObject}; use crate::types::{Snowflake, WelcomeScreenObject};
use super::guild::GuildScheduledEvent; use super::guild::GuildScheduledEvent;
use super::{Application, Channel, GuildMember, User}; use super::{Application, Channel, GuildMember, NSFWLevel, User};
/// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users. /// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users.
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-object> /// See <https://discord-userdoccers.vercel.app/resources/invite#invite-object>
@ -56,17 +56,6 @@ pub struct InviteGuild {
pub welcome_screen: Option<WelcomeScreenObject>, pub welcome_screen: Option<WelcomeScreenObject>,
} }
/// See <https://discord-userdoccers.vercel.app/resources/guild#nsfw-level> for an explanation on what
/// the levels mean.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum NSFWLevel {
Default = 0,
Explicit = 1,
Safe = 2,
AgeRestricted = 3,
}
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-stage-instance-object> /// See <https://discord-userdoccers.vercel.app/resources/invite#invite-stage-instance-object>
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct InviteStageInstance { pub struct InviteStageInstance {