Split up gateway, fix a few clippy warnings (#436)
As the title suggests, `Gateway` has been split up into several files, due to it being several hundred locs in length and, in my opinion, becoming difficult/frustrating to navigate. I have also fixed 2 clippy warnings: I have refactored an `if x.is_some()... { x.unwrap() }` into the more idiomatic `if let Some(y) = x ...`, and I have removed the duplicate definition of `NSFWLevel` in `types/invite`.
This commit is contained in:
commit
08b663c114
|
@ -28,11 +28,11 @@ impl ChorusUser {
|
||||||
.header("Authorization", self.token()),
|
.header("Authorization", self.token()),
|
||||||
limit_type: LimitType::Global,
|
limit_type: LimitType::Global,
|
||||||
};
|
};
|
||||||
if session_id.is_some() {
|
if let Some(session_id) = session_id {
|
||||||
request.request = request
|
request.request = request
|
||||||
.request
|
.request
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.body(to_string(session_id.unwrap()).unwrap());
|
.body(to_string(session_id).unwrap());
|
||||||
}
|
}
|
||||||
request.deserialize_response::<Invite>(self).await
|
request.deserialize_response::<Invite>(self).await
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,331 +1,10 @@
|
||||||
//! Gateway connection, communication and handling, as well as object caching and updating.
|
use self::event::Events;
|
||||||
|
use super::*;
|
||||||
use crate::errors::GatewayError;
|
|
||||||
use crate::gateway::events::Events;
|
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete,
|
self, AutoModerationRule, AutoModerationRuleUpdate, Channel, ChannelCreate, ChannelDelete,
|
||||||
ChannelUpdate, Composite, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject,
|
ChannelUpdate, Guild, GuildRoleCreate, GuildRoleUpdate, JsonField, RoleObject, SourceUrlField,
|
||||||
Snowflake, SourceUrlField, ThreadUpdate, UpdateMessage, WebSocketEvent,
|
ThreadUpdate, UpdateMessage, WebSocketEvent,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
|
||||||
use std::any::Any;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fmt::Debug;
|
|
||||||
use std::sync::{Arc, RwLock};
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::time::sleep_until;
|
|
||||||
|
|
||||||
use futures_util::stream::SplitSink;
|
|
||||||
use futures_util::stream::SplitStream;
|
|
||||||
use futures_util::SinkExt;
|
|
||||||
use futures_util::StreamExt;
|
|
||||||
use log::{info, trace, warn};
|
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use tokio::sync::mpsc::Sender;
|
|
||||||
use tokio::sync::Mutex;
|
|
||||||
use tokio::task;
|
|
||||||
use tokio::task::JoinHandle;
|
|
||||||
use tokio::time;
|
|
||||||
use tokio::time::Instant;
|
|
||||||
use tokio_tungstenite::MaybeTlsStream;
|
|
||||||
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
|
|
||||||
|
|
||||||
// Gateway opcodes
|
|
||||||
/// Opcode received when the server dispatches a [crate::types::WebSocketEvent]
|
|
||||||
const GATEWAY_DISPATCH: u8 = 0;
|
|
||||||
/// Opcode sent when sending a heartbeat
|
|
||||||
const GATEWAY_HEARTBEAT: u8 = 1;
|
|
||||||
/// Opcode sent to initiate a session
|
|
||||||
///
|
|
||||||
/// See [types::GatewayIdentifyPayload]
|
|
||||||
const GATEWAY_IDENTIFY: u8 = 2;
|
|
||||||
/// Opcode sent to update our presence
|
|
||||||
///
|
|
||||||
/// See [types::GatewayUpdatePresence]
|
|
||||||
const GATEWAY_UPDATE_PRESENCE: u8 = 3;
|
|
||||||
/// Opcode sent to update our state in vc
|
|
||||||
///
|
|
||||||
/// Like muting, deafening, leaving, joining..
|
|
||||||
///
|
|
||||||
/// See [types::UpdateVoiceState]
|
|
||||||
const GATEWAY_UPDATE_VOICE_STATE: u8 = 4;
|
|
||||||
/// Opcode sent to resume a session
|
|
||||||
///
|
|
||||||
/// See [types::GatewayResume]
|
|
||||||
const GATEWAY_RESUME: u8 = 6;
|
|
||||||
/// Opcode received to tell the client to reconnect
|
|
||||||
const GATEWAY_RECONNECT: u8 = 7;
|
|
||||||
/// Opcode sent to request guild member data
|
|
||||||
///
|
|
||||||
/// See [types::GatewayRequestGuildMembers]
|
|
||||||
const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8;
|
|
||||||
/// Opcode received to tell the client their token / session is invalid
|
|
||||||
const GATEWAY_INVALID_SESSION: u8 = 9;
|
|
||||||
/// Opcode received when initially connecting to the gateway, starts our heartbeat
|
|
||||||
///
|
|
||||||
/// See [types::HelloData]
|
|
||||||
const GATEWAY_HELLO: u8 = 10;
|
|
||||||
/// Opcode received to acknowledge a heartbeat
|
|
||||||
const GATEWAY_HEARTBEAT_ACK: u8 = 11;
|
|
||||||
/// Opcode sent to get the voice state of users in a given DM/group channel
|
|
||||||
///
|
|
||||||
/// See [types::CallSync]
|
|
||||||
const GATEWAY_CALL_SYNC: u8 = 13;
|
|
||||||
/// Opcode sent to get data for a server (Lazy Loading request)
|
|
||||||
///
|
|
||||||
/// Sent by the official client when switching to a server
|
|
||||||
///
|
|
||||||
/// See [types::LazyRequest]
|
|
||||||
const GATEWAY_LAZY_REQUEST: u8 = 14;
|
|
||||||
|
|
||||||
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
|
|
||||||
const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
|
|
||||||
|
|
||||||
/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
|
|
||||||
/// This struct is used internally when handling messages.
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct GatewayMessage {
|
|
||||||
/// The message we received from the server
|
|
||||||
message: tokio_tungstenite::tungstenite::Message,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GatewayMessage {
|
|
||||||
/// Creates self from a tungstenite message
|
|
||||||
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
|
|
||||||
Self { message }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses the message as an error;
|
|
||||||
/// Returns the error if succesfully parsed, None if the message isn't an error
|
|
||||||
pub fn error(&self) -> Option<GatewayError> {
|
|
||||||
let content = self.message.to_string();
|
|
||||||
|
|
||||||
// Some error strings have dots on the end, which we don't care about
|
|
||||||
let processed_content = content.to_lowercase().replace('.', "");
|
|
||||||
|
|
||||||
match processed_content.as_str() {
|
|
||||||
"unknown error" | "4000" => Some(GatewayError::Unknown),
|
|
||||||
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode),
|
|
||||||
"decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode),
|
|
||||||
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticated),
|
|
||||||
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed),
|
|
||||||
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated),
|
|
||||||
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber),
|
|
||||||
"rate limited" | "4008" => Some(GatewayError::RateLimited),
|
|
||||||
"session timed out" | "4009" => Some(GatewayError::SessionTimedOut),
|
|
||||||
"invalid shard" | "4010" => Some(GatewayError::InvalidShard),
|
|
||||||
"sharding required" | "4011" => Some(GatewayError::ShardingRequired),
|
|
||||||
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion),
|
|
||||||
"invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents),
|
|
||||||
"disallowed intent(s)" | "disallowed intents" | "4014" => {
|
|
||||||
Some(GatewayError::DisallowedIntents)
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns whether or not the message is an error
|
|
||||||
pub fn is_error(&self) -> bool {
|
|
||||||
self.error().is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses the message as a payload;
|
|
||||||
/// Returns a result of deserializing
|
|
||||||
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
|
|
||||||
return serde_json::from_str(self.message.to_text().unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns whether or not the message is a payload
|
|
||||||
pub fn is_payload(&self) -> bool {
|
|
||||||
// close messages are never payloads, payloads are only text messages
|
|
||||||
if self.message.is_close() | !self.message.is_text() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return self.payload().is_ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns whether or not the message is empty
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.message.is_empty()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type ObservableObject = dyn Send + Sync + Any;
|
|
||||||
|
|
||||||
/// Represents a handle to a Gateway connection. A Gateway connection will create observable
|
|
||||||
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
|
|
||||||
/// implemented types with the trait [`WebSocketEvent`]
|
|
||||||
/// Using this handle you can also send Gateway Events directly.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct GatewayHandle {
|
|
||||||
pub url: String,
|
|
||||||
pub events: Arc<Mutex<Events>>,
|
|
||||||
pub websocket_send: Arc<
|
|
||||||
Mutex<
|
|
||||||
SplitSink<
|
|
||||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
|
||||||
tokio_tungstenite::tungstenite::Message,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
/// Tells gateway tasks to close
|
|
||||||
kill_send: tokio::sync::broadcast::Sender<()>,
|
|
||||||
pub(crate) store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
|
|
||||||
pub trait Updateable: 'static + Send + Sync {
|
|
||||||
fn id(&self) -> Snowflake;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GatewayHandle {
|
|
||||||
/// Sends json to the gateway with an opcode
|
|
||||||
async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) {
|
|
||||||
let gateway_payload = types::GatewaySendPayload {
|
|
||||||
op_code,
|
|
||||||
event_data: Some(to_send),
|
|
||||||
sequence_number: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
|
|
||||||
|
|
||||||
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
|
|
||||||
|
|
||||||
self.websocket_send
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.send(message)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
|
|
||||||
&self,
|
|
||||||
object: Arc<RwLock<T>>,
|
|
||||||
) -> Arc<RwLock<T>> {
|
|
||||||
let mut store = self.store.lock().await;
|
|
||||||
let id = object.read().unwrap().id();
|
|
||||||
if let Some(channel) = store.get(&id) {
|
|
||||||
let object = channel.clone();
|
|
||||||
drop(store);
|
|
||||||
object
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.downcast_ref::<T>()
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
panic!(
|
|
||||||
"Snowflake {} already exists in the store, but it is not of type T.",
|
|
||||||
id
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let ptr = Arc::into_raw(object.clone());
|
|
||||||
// SAFETY:
|
|
||||||
// - We have just checked that the typeid of the `dyn Any ...` matches that of `T`.
|
|
||||||
// - This operation doesn't read or write any shared data, and thus cannot cause a data race
|
|
||||||
// - The reference count is not being modified
|
|
||||||
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() };
|
|
||||||
let object = downcasted.read().unwrap().clone();
|
|
||||||
|
|
||||||
let watched_object = object.watch_whole(self).await;
|
|
||||||
*downcasted.write().unwrap() = watched_object;
|
|
||||||
downcasted
|
|
||||||
} else {
|
|
||||||
let id = object.read().unwrap().id();
|
|
||||||
let object = object.read().unwrap().clone();
|
|
||||||
let object = object.clone().watch_whole(self).await;
|
|
||||||
let wrapped = Arc::new(RwLock::new(object));
|
|
||||||
store.insert(id, wrapped.clone());
|
|
||||||
wrapped
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
|
|
||||||
/// with all of its observable fields being observed.
|
|
||||||
pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
|
|
||||||
&self,
|
|
||||||
object: Arc<RwLock<T>>,
|
|
||||||
) -> T {
|
|
||||||
let channel = self.observe(object.clone()).await;
|
|
||||||
let object = channel.read().unwrap().clone();
|
|
||||||
object
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends an identify event to the gateway
|
|
||||||
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Identify..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a resume event to the gateway
|
|
||||||
pub async fn send_resume(&self, to_send: types::GatewayResume) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Resume..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_RESUME, to_send_value).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends an update presence event to the gateway
|
|
||||||
pub async fn send_update_presence(&self, to_send: types::UpdatePresence) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Update Presence..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a request guild members to the server
|
|
||||||
pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Request Guild Members..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends an update voice state to the server
|
|
||||||
pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) {
|
|
||||||
let to_send_value = serde_json::to_value(to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Update Voice State..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a call sync to the server
|
|
||||||
pub async fn send_call_sync(&self, to_send: types::CallSync) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Call Sync..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a Lazy Request
|
|
||||||
pub async fn send_lazy_request(&self, to_send: types::LazyRequest) {
|
|
||||||
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
|
||||||
|
|
||||||
trace!("GW: Sending Lazy Request..");
|
|
||||||
|
|
||||||
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Closes the websocket connection and stops all gateway tasks;
|
|
||||||
///
|
|
||||||
/// Esentially pulls the plug on the gateway, leaving it possible to resume;
|
|
||||||
pub async fn close(&self) {
|
|
||||||
self.kill_send.send(()).unwrap();
|
|
||||||
self.websocket_send.lock().await.close().await.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Gateway {
|
pub struct Gateway {
|
||||||
|
@ -709,10 +388,10 @@ impl Gateway {
|
||||||
| GATEWAY_REQUEST_GUILD_MEMBERS
|
| GATEWAY_REQUEST_GUILD_MEMBERS
|
||||||
| GATEWAY_CALL_SYNC
|
| GATEWAY_CALL_SYNC
|
||||||
| GATEWAY_LAZY_REQUEST => {
|
| GATEWAY_LAZY_REQUEST => {
|
||||||
let error = GatewayError::UnexpectedOpcodeReceived {
|
info!(
|
||||||
opcode: gateway_payload.op_code,
|
"Received unexpected opcode ({}) for current state. This might be due to a faulty server implementation and is likely not the fault of chorus.",
|
||||||
};
|
gateway_payload.op_code
|
||||||
Err::<(), GatewayError>(error).unwrap();
|
);
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
|
warn!("Received unrecognized gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
|
||||||
|
@ -736,196 +415,7 @@ impl Gateway {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handles sending heartbeats to the gateway in another thread
|
pub mod event {
|
||||||
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct HeartbeatHandler {
|
|
||||||
/// How ofter heartbeats need to be sent at a minimum
|
|
||||||
pub heartbeat_interval: Duration,
|
|
||||||
/// The send channel for the heartbeat thread
|
|
||||||
pub send: Sender<HeartbeatThreadCommunication>,
|
|
||||||
/// The handle of the thread
|
|
||||||
handle: JoinHandle<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HeartbeatHandler {
|
|
||||||
pub fn new(
|
|
||||||
heartbeat_interval: Duration,
|
|
||||||
websocket_tx: Arc<
|
|
||||||
Mutex<
|
|
||||||
SplitSink<
|
|
||||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
|
||||||
tokio_tungstenite::tungstenite::Message,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
kill_rc: tokio::sync::broadcast::Receiver<()>,
|
|
||||||
) -> HeartbeatHandler {
|
|
||||||
let (send, receive) = tokio::sync::mpsc::channel(32);
|
|
||||||
let kill_receive = kill_rc.resubscribe();
|
|
||||||
|
|
||||||
let handle: JoinHandle<()> = task::spawn(async move {
|
|
||||||
HeartbeatHandler::heartbeat_task(
|
|
||||||
websocket_tx,
|
|
||||||
heartbeat_interval,
|
|
||||||
receive,
|
|
||||||
kill_receive,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
});
|
|
||||||
|
|
||||||
Self {
|
|
||||||
heartbeat_interval,
|
|
||||||
send,
|
|
||||||
handle,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The main heartbeat task;
|
|
||||||
///
|
|
||||||
/// Can be killed by the kill broadcast;
|
|
||||||
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
|
|
||||||
pub async fn heartbeat_task(
|
|
||||||
websocket_tx: Arc<
|
|
||||||
Mutex<
|
|
||||||
SplitSink<
|
|
||||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
|
||||||
tokio_tungstenite::tungstenite::Message,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
heartbeat_interval: Duration,
|
|
||||||
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
|
|
||||||
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
|
|
||||||
) {
|
|
||||||
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
|
|
||||||
let mut last_heartbeat_acknowledged = true;
|
|
||||||
let mut last_seq_number: Option<u64> = None;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
if kill_receive.try_recv().is_ok() {
|
|
||||||
trace!("GW: Closing heartbeat task");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let timeout = if last_heartbeat_acknowledged {
|
|
||||||
heartbeat_interval
|
|
||||||
} else {
|
|
||||||
// If the server hasn't acknowledged our heartbeat we should resend it
|
|
||||||
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut should_send = false;
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
() = sleep_until(last_heartbeat_timestamp + timeout) => {
|
|
||||||
should_send = true;
|
|
||||||
}
|
|
||||||
Some(communication) = receive.recv() => {
|
|
||||||
// If we received a seq number update, use that as the last seq number
|
|
||||||
if communication.sequence_number.is_some() {
|
|
||||||
last_seq_number = communication.sequence_number;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(op_code) = communication.op_code {
|
|
||||||
match op_code {
|
|
||||||
GATEWAY_HEARTBEAT => {
|
|
||||||
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
|
|
||||||
should_send = true;
|
|
||||||
}
|
|
||||||
GATEWAY_HEARTBEAT_ACK => {
|
|
||||||
// The server received our heartbeat
|
|
||||||
last_heartbeat_acknowledged = true;
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if should_send {
|
|
||||||
trace!("GW: Sending Heartbeat..");
|
|
||||||
|
|
||||||
let heartbeat = types::GatewayHeartbeat {
|
|
||||||
op: GATEWAY_HEARTBEAT,
|
|
||||||
d: last_seq_number,
|
|
||||||
};
|
|
||||||
|
|
||||||
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
|
|
||||||
|
|
||||||
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
|
|
||||||
|
|
||||||
let send_result = websocket_tx.lock().await.send(msg).await;
|
|
||||||
if send_result.is_err() {
|
|
||||||
// We couldn't send, the websocket is broken
|
|
||||||
warn!("GW: Couldnt send heartbeat, websocket seems broken");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
last_heartbeat_timestamp = time::Instant::now();
|
|
||||||
last_heartbeat_acknowledged = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Used for communications between the heartbeat and gateway thread.
|
|
||||||
/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
|
||||||
struct HeartbeatThreadCommunication {
|
|
||||||
/// The opcode for the communication we received, if relevant
|
|
||||||
op_code: Option<u8>,
|
|
||||||
/// The sequence number we got from discord, if any
|
|
||||||
sequence_number: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
|
|
||||||
/// an Observable. The Observer is notified when the Observable's data changes.
|
|
||||||
/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
|
|
||||||
/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing.
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Observer<T>: Sync + Send + std::fmt::Debug {
|
|
||||||
async fn update(&self, data: &T);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
|
|
||||||
/// change in the WebSocketEvent. GatewayEvents are observable.
|
|
||||||
#[derive(Default, Debug)]
|
|
||||||
pub struct GatewayEvent<T: WebSocketEvent> {
|
|
||||||
observers: Vec<Arc<dyn Observer<T>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: WebSocketEvent> GatewayEvent<T> {
|
|
||||||
/// Returns true if the GatewayEvent is observed by at least one Observer.
|
|
||||||
pub fn is_observed(&self) -> bool {
|
|
||||||
!self.observers.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Subscribes an Observer to the GatewayEvent.
|
|
||||||
pub fn subscribe(&mut self, observable: Arc<dyn Observer<T>>) {
|
|
||||||
self.observers.push(observable);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Unsubscribes an Observer from the GatewayEvent.
|
|
||||||
pub fn unsubscribe(&mut self, observable: &dyn Observer<T>) {
|
|
||||||
// .retain()'s closure retains only those elements of the vector, which have a different
|
|
||||||
// pointer value than observable.
|
|
||||||
// The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
|
|
||||||
// anddd there is no way to do that without using format
|
|
||||||
let to_remove = format!("{:?}", observable);
|
|
||||||
self.observers
|
|
||||||
.retain(|obs| format!("{:?}", obs) != to_remove);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Notifies the observers of the GatewayEvent.
|
|
||||||
async fn notify(&self, new_event_data: T) {
|
|
||||||
for observer in &self.observers {
|
|
||||||
observer.update(&new_event_data).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod events {
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
|
@ -1086,52 +576,3 @@ pub mod events {
|
||||||
pub update: GatewayEvent<types::WebhooksUpdate>,
|
pub update: GatewayEvent<types::WebhooksUpdate>,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod example {
|
|
||||||
use super::*;
|
|
||||||
use std::sync::atomic::{AtomicI32, Ordering::Relaxed};
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct Consumer {
|
|
||||||
_name: String,
|
|
||||||
events_received: AtomicI32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Observer<types::GatewayResume> for Consumer {
|
|
||||||
async fn update(&self, _data: &types::GatewayResume) {
|
|
||||||
self.events_received.fetch_add(1, Relaxed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_observer_behavior() {
|
|
||||||
let mut event = GatewayEvent::default();
|
|
||||||
|
|
||||||
let new_data = types::GatewayResume {
|
|
||||||
token: "token_3276ha37am3".to_string(),
|
|
||||||
session_id: "89346671230".to_string(),
|
|
||||||
seq: "3".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let consumer = Arc::new(Consumer {
|
|
||||||
_name: "first".into(),
|
|
||||||
events_received: 0.into(),
|
|
||||||
});
|
|
||||||
event.subscribe(consumer.clone());
|
|
||||||
|
|
||||||
let second_consumer = Arc::new(Consumer {
|
|
||||||
_name: "second".into(),
|
|
||||||
events_received: 0.into(),
|
|
||||||
});
|
|
||||||
event.subscribe(second_consumer.clone());
|
|
||||||
|
|
||||||
event.notify(new_data.clone()).await;
|
|
||||||
event.unsubscribe(&*consumer);
|
|
||||||
event.notify(new_data).await;
|
|
||||||
|
|
||||||
assert_eq!(consumer.events_received.load(Relaxed), 1);
|
|
||||||
assert_eq!(second_consumer.events_received.load(Relaxed), 2);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
use super::{event::Events, *};
|
||||||
|
use crate::types::{self, Composite};
|
||||||
|
|
||||||
|
/// Represents a handle to a Gateway connection. A Gateway connection will create observable
|
||||||
|
/// [`GatewayEvents`](GatewayEvent), which you can subscribe to. Gateway events include all currently
|
||||||
|
/// implemented types with the trait [`WebSocketEvent`]
|
||||||
|
/// Using this handle you can also send Gateway Events directly.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct GatewayHandle {
|
||||||
|
pub url: String,
|
||||||
|
pub events: Arc<Mutex<Events>>,
|
||||||
|
pub websocket_send: Arc<
|
||||||
|
Mutex<
|
||||||
|
SplitSink<
|
||||||
|
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
tokio_tungstenite::tungstenite::Message,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
/// Tells gateway tasks to close
|
||||||
|
pub(super) kill_send: tokio::sync::broadcast::Sender<()>,
|
||||||
|
pub(crate) store: Arc<Mutex<HashMap<Snowflake, Arc<RwLock<ObservableObject>>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatewayHandle {
|
||||||
|
/// Sends json to the gateway with an opcode
|
||||||
|
async fn send_json_event(&self, op_code: u8, to_send: serde_json::Value) {
|
||||||
|
let gateway_payload = types::GatewaySendPayload {
|
||||||
|
op_code,
|
||||||
|
event_data: Some(to_send),
|
||||||
|
sequence_number: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
|
||||||
|
|
||||||
|
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
|
||||||
|
|
||||||
|
self.websocket_send
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.send(message)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn observe<T: Updateable + Clone + Debug + Composite<T>>(
|
||||||
|
&self,
|
||||||
|
object: Arc<RwLock<T>>,
|
||||||
|
) -> Arc<RwLock<T>> {
|
||||||
|
let mut store = self.store.lock().await;
|
||||||
|
let id = object.read().unwrap().id();
|
||||||
|
if let Some(channel) = store.get(&id) {
|
||||||
|
let object = channel.clone();
|
||||||
|
drop(store);
|
||||||
|
object
|
||||||
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
.downcast_ref::<T>()
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
panic!(
|
||||||
|
"Snowflake {} already exists in the store, but it is not of type T.",
|
||||||
|
id
|
||||||
|
)
|
||||||
|
});
|
||||||
|
let ptr = Arc::into_raw(object.clone());
|
||||||
|
// SAFETY:
|
||||||
|
// - We have just checked that the typeid of the `dyn Any ...` matches that of `T`.
|
||||||
|
// - This operation doesn't read or write any shared data, and thus cannot cause a data race
|
||||||
|
// - The reference count is not being modified
|
||||||
|
let downcasted = unsafe { Arc::from_raw(ptr as *const RwLock<T>).clone() };
|
||||||
|
let object = downcasted.read().unwrap().clone();
|
||||||
|
|
||||||
|
let watched_object = object.watch_whole(self).await;
|
||||||
|
*downcasted.write().unwrap() = watched_object;
|
||||||
|
downcasted
|
||||||
|
} else {
|
||||||
|
let id = object.read().unwrap().id();
|
||||||
|
let object = object.read().unwrap().clone();
|
||||||
|
let object = object.clone().watch_whole(self).await;
|
||||||
|
let wrapped = Arc::new(RwLock::new(object));
|
||||||
|
store.insert(id, wrapped.clone());
|
||||||
|
wrapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recursively observes and updates all updateable fields on the struct T. Returns an object `T`
|
||||||
|
/// with all of its observable fields being observed.
|
||||||
|
pub async fn observe_and_into_inner<T: Updateable + Clone + Debug + Composite<T>>(
|
||||||
|
&self,
|
||||||
|
object: Arc<RwLock<T>>,
|
||||||
|
) -> T {
|
||||||
|
let channel = self.observe(object.clone()).await;
|
||||||
|
let object = channel.read().unwrap().clone();
|
||||||
|
object
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends an identify event to the gateway
|
||||||
|
pub async fn send_identify(&self, to_send: types::GatewayIdentifyPayload) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Identify..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_IDENTIFY, to_send_value).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a resume event to the gateway
|
||||||
|
pub async fn send_resume(&self, to_send: types::GatewayResume) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Resume..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_RESUME, to_send_value).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends an update presence event to the gateway
|
||||||
|
pub async fn send_update_presence(&self, to_send: types::UpdatePresence) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Update Presence..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_UPDATE_PRESENCE, to_send_value)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a request guild members to the server
|
||||||
|
pub async fn send_request_guild_members(&self, to_send: types::GatewayRequestGuildMembers) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Request Guild Members..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_REQUEST_GUILD_MEMBERS, to_send_value)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends an update voice state to the server
|
||||||
|
pub async fn send_update_voice_state(&self, to_send: types::UpdateVoiceState) {
|
||||||
|
let to_send_value = serde_json::to_value(to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Update Voice State..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_UPDATE_VOICE_STATE, to_send_value)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a call sync to the server
|
||||||
|
pub async fn send_call_sync(&self, to_send: types::CallSync) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Call Sync..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_CALL_SYNC, to_send_value).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a Lazy Request
|
||||||
|
pub async fn send_lazy_request(&self, to_send: types::LazyRequest) {
|
||||||
|
let to_send_value = serde_json::to_value(&to_send).unwrap();
|
||||||
|
|
||||||
|
trace!("GW: Sending Lazy Request..");
|
||||||
|
|
||||||
|
self.send_json_event(GATEWAY_LAZY_REQUEST, to_send_value)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Closes the websocket connection and stops all gateway tasks;
|
||||||
|
///
|
||||||
|
/// Esentially pulls the plug on the gateway, leaving it possible to resume;
|
||||||
|
pub async fn close(&self) {
|
||||||
|
self.kill_send.send(()).unwrap();
|
||||||
|
self.websocket_send.lock().await.close().await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
use crate::types;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// The amount of time we wait for a heartbeat ack before resending our heartbeat in ms
|
||||||
|
const HEARTBEAT_ACK_TIMEOUT: u64 = 2000;
|
||||||
|
|
||||||
|
/// Handles sending heartbeats to the gateway in another thread
|
||||||
|
#[allow(dead_code)] // FIXME: Remove this, once HeartbeatHandler is used
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) struct HeartbeatHandler {
|
||||||
|
/// How ofter heartbeats need to be sent at a minimum
|
||||||
|
pub heartbeat_interval: Duration,
|
||||||
|
/// The send channel for the heartbeat thread
|
||||||
|
pub send: Sender<HeartbeatThreadCommunication>,
|
||||||
|
/// The handle of the thread
|
||||||
|
handle: JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HeartbeatHandler {
|
||||||
|
pub fn new(
|
||||||
|
heartbeat_interval: Duration,
|
||||||
|
websocket_tx: Arc<
|
||||||
|
Mutex<
|
||||||
|
SplitSink<
|
||||||
|
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
tokio_tungstenite::tungstenite::Message,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
kill_rc: tokio::sync::broadcast::Receiver<()>,
|
||||||
|
) -> HeartbeatHandler {
|
||||||
|
let (send, receive) = tokio::sync::mpsc::channel(32);
|
||||||
|
let kill_receive = kill_rc.resubscribe();
|
||||||
|
|
||||||
|
let handle: JoinHandle<()> = task::spawn(async move {
|
||||||
|
HeartbeatHandler::heartbeat_task(
|
||||||
|
websocket_tx,
|
||||||
|
heartbeat_interval,
|
||||||
|
receive,
|
||||||
|
kill_receive,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
heartbeat_interval,
|
||||||
|
send,
|
||||||
|
handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The main heartbeat task;
|
||||||
|
///
|
||||||
|
/// Can be killed by the kill broadcast;
|
||||||
|
/// If the websocket is closed, will die out next time it tries to send a heartbeat;
|
||||||
|
pub async fn heartbeat_task(
|
||||||
|
websocket_tx: Arc<
|
||||||
|
Mutex<
|
||||||
|
SplitSink<
|
||||||
|
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
tokio_tungstenite::tungstenite::Message,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
heartbeat_interval: Duration,
|
||||||
|
mut receive: tokio::sync::mpsc::Receiver<HeartbeatThreadCommunication>,
|
||||||
|
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
|
||||||
|
) {
|
||||||
|
let mut last_heartbeat_timestamp: Instant = time::Instant::now();
|
||||||
|
let mut last_heartbeat_acknowledged = true;
|
||||||
|
let mut last_seq_number: Option<u64> = None;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if kill_receive.try_recv().is_ok() {
|
||||||
|
trace!("GW: Closing heartbeat task");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let timeout = if last_heartbeat_acknowledged {
|
||||||
|
heartbeat_interval
|
||||||
|
} else {
|
||||||
|
// If the server hasn't acknowledged our heartbeat we should resend it
|
||||||
|
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut should_send = false;
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
() = sleep_until(last_heartbeat_timestamp + timeout) => {
|
||||||
|
should_send = true;
|
||||||
|
}
|
||||||
|
Some(communication) = receive.recv() => {
|
||||||
|
// If we received a seq number update, use that as the last seq number
|
||||||
|
if communication.sequence_number.is_some() {
|
||||||
|
last_seq_number = communication.sequence_number;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(op_code) = communication.op_code {
|
||||||
|
match op_code {
|
||||||
|
GATEWAY_HEARTBEAT => {
|
||||||
|
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
|
||||||
|
should_send = true;
|
||||||
|
}
|
||||||
|
GATEWAY_HEARTBEAT_ACK => {
|
||||||
|
// The server received our heartbeat
|
||||||
|
last_heartbeat_acknowledged = true;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_send {
|
||||||
|
trace!("GW: Sending Heartbeat..");
|
||||||
|
|
||||||
|
let heartbeat = types::GatewayHeartbeat {
|
||||||
|
op: GATEWAY_HEARTBEAT,
|
||||||
|
d: last_seq_number,
|
||||||
|
};
|
||||||
|
|
||||||
|
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
|
||||||
|
|
||||||
|
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
|
||||||
|
|
||||||
|
let send_result = websocket_tx.lock().await.send(msg).await;
|
||||||
|
if send_result.is_err() {
|
||||||
|
// We couldn't send, the websocket is broken
|
||||||
|
warn!("GW: Couldnt send heartbeat, websocket seems broken");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
last_heartbeat_timestamp = time::Instant::now();
|
||||||
|
last_heartbeat_acknowledged = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used for communications between the heartbeat and gateway thread.
|
||||||
|
/// Either signifies a sequence number update, a heartbeat ACK or a Heartbeat request by the server
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub(super) struct HeartbeatThreadCommunication {
|
||||||
|
/// The opcode for the communication we received, if relevant
|
||||||
|
pub(super) op_code: Option<u8>,
|
||||||
|
/// The sequence number we got from discord, if any
|
||||||
|
pub(super) sequence_number: Option<u64>,
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
use crate::types;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/// Represents a messsage received from the gateway. This will be either a [types::GatewayReceivePayload], containing events, or a [GatewayError].
|
||||||
|
/// This struct is used internally when handling messages.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct GatewayMessage {
|
||||||
|
/// The message we received from the server
|
||||||
|
pub(super) message: tokio_tungstenite::tungstenite::Message,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatewayMessage {
|
||||||
|
/// Creates self from a tungstenite message
|
||||||
|
pub fn from_tungstenite_message(message: tokio_tungstenite::tungstenite::Message) -> Self {
|
||||||
|
Self { message }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parses the message as an error;
|
||||||
|
/// Returns the error if succesfully parsed, None if the message isn't an error
|
||||||
|
pub fn error(&self) -> Option<GatewayError> {
|
||||||
|
let content = self.message.to_string();
|
||||||
|
|
||||||
|
// Some error strings have dots on the end, which we don't care about
|
||||||
|
let processed_content = content.to_lowercase().replace('.', "");
|
||||||
|
|
||||||
|
match processed_content.as_str() {
|
||||||
|
"unknown error" | "4000" => Some(GatewayError::Unknown),
|
||||||
|
"unknown opcode" | "4001" => Some(GatewayError::UnknownOpcode),
|
||||||
|
"decode error" | "error while decoding payload" | "4002" => Some(GatewayError::Decode),
|
||||||
|
"not authenticated" | "4003" => Some(GatewayError::NotAuthenticated),
|
||||||
|
"authentication failed" | "4004" => Some(GatewayError::AuthenticationFailed),
|
||||||
|
"already authenticated" | "4005" => Some(GatewayError::AlreadyAuthenticated),
|
||||||
|
"invalid seq" | "4007" => Some(GatewayError::InvalidSequenceNumber),
|
||||||
|
"rate limited" | "4008" => Some(GatewayError::RateLimited),
|
||||||
|
"session timed out" | "4009" => Some(GatewayError::SessionTimedOut),
|
||||||
|
"invalid shard" | "4010" => Some(GatewayError::InvalidShard),
|
||||||
|
"sharding required" | "4011" => Some(GatewayError::ShardingRequired),
|
||||||
|
"invalid api version" | "4012" => Some(GatewayError::InvalidAPIVersion),
|
||||||
|
"invalid intent(s)" | "invalid intent" | "4013" => Some(GatewayError::InvalidIntents),
|
||||||
|
"disallowed intent(s)" | "disallowed intents" | "4014" => {
|
||||||
|
Some(GatewayError::DisallowedIntents)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether or not the message is an error
|
||||||
|
pub fn is_error(&self) -> bool {
|
||||||
|
self.error().is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parses the message as a payload;
|
||||||
|
/// Returns a result of deserializing
|
||||||
|
pub fn payload(&self) -> Result<types::GatewayReceivePayload, serde_json::Error> {
|
||||||
|
return serde_json::from_str(self.message.to_text().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether or not the message is a payload
|
||||||
|
pub fn is_payload(&self) -> bool {
|
||||||
|
// close messages are never payloads, payloads are only text messages
|
||||||
|
if self.message.is_close() | !self.message.is_text() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.payload().is_ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether or not the message is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.message.is_empty()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,187 @@
|
||||||
|
pub mod gateway;
|
||||||
|
pub mod handle;
|
||||||
|
pub mod heartbeat;
|
||||||
|
pub mod message;
|
||||||
|
|
||||||
|
pub use gateway::*;
|
||||||
|
pub use handle::*;
|
||||||
|
use heartbeat::*;
|
||||||
|
pub use message::*;
|
||||||
|
|
||||||
|
use crate::errors::GatewayError;
|
||||||
|
use crate::types::{Snowflake, WebSocketEvent};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::any::Any;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::sleep_until;
|
||||||
|
|
||||||
|
use futures_util::stream::SplitSink;
|
||||||
|
use futures_util::stream::SplitStream;
|
||||||
|
use futures_util::SinkExt;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use log::{info, trace, warn};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::task;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time;
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_tungstenite::MaybeTlsStream;
|
||||||
|
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
|
||||||
|
|
||||||
|
// Gateway opcodes
|
||||||
|
/// Opcode received when the server dispatches a [crate::types::WebSocketEvent]
|
||||||
|
const GATEWAY_DISPATCH: u8 = 0;
|
||||||
|
/// Opcode sent when sending a heartbeat
|
||||||
|
const GATEWAY_HEARTBEAT: u8 = 1;
|
||||||
|
/// Opcode sent to initiate a session
|
||||||
|
///
|
||||||
|
/// See [types::GatewayIdentifyPayload]
|
||||||
|
const GATEWAY_IDENTIFY: u8 = 2;
|
||||||
|
/// Opcode sent to update our presence
|
||||||
|
///
|
||||||
|
/// See [types::GatewayUpdatePresence]
|
||||||
|
const GATEWAY_UPDATE_PRESENCE: u8 = 3;
|
||||||
|
/// Opcode sent to update our state in vc
|
||||||
|
///
|
||||||
|
/// Like muting, deafening, leaving, joining..
|
||||||
|
///
|
||||||
|
/// See [types::UpdateVoiceState]
|
||||||
|
const GATEWAY_UPDATE_VOICE_STATE: u8 = 4;
|
||||||
|
/// Opcode sent to resume a session
|
||||||
|
///
|
||||||
|
/// See [types::GatewayResume]
|
||||||
|
const GATEWAY_RESUME: u8 = 6;
|
||||||
|
/// Opcode received to tell the client to reconnect
|
||||||
|
const GATEWAY_RECONNECT: u8 = 7;
|
||||||
|
/// Opcode sent to request guild member data
|
||||||
|
///
|
||||||
|
/// See [types::GatewayRequestGuildMembers]
|
||||||
|
const GATEWAY_REQUEST_GUILD_MEMBERS: u8 = 8;
|
||||||
|
/// Opcode received to tell the client their token / session is invalid
|
||||||
|
const GATEWAY_INVALID_SESSION: u8 = 9;
|
||||||
|
/// Opcode received when initially connecting to the gateway, starts our heartbeat
|
||||||
|
///
|
||||||
|
/// See [types::HelloData]
|
||||||
|
const GATEWAY_HELLO: u8 = 10;
|
||||||
|
/// Opcode received to acknowledge a heartbeat
|
||||||
|
const GATEWAY_HEARTBEAT_ACK: u8 = 11;
|
||||||
|
/// Opcode sent to get the voice state of users in a given DM/group channel
|
||||||
|
///
|
||||||
|
/// See [types::CallSync]
|
||||||
|
const GATEWAY_CALL_SYNC: u8 = 13;
|
||||||
|
/// Opcode sent to get data for a server (Lazy Loading request)
|
||||||
|
///
|
||||||
|
/// Sent by the official client when switching to a server
|
||||||
|
///
|
||||||
|
/// See [types::LazyRequest]
|
||||||
|
const GATEWAY_LAZY_REQUEST: u8 = 14;
|
||||||
|
|
||||||
|
pub type ObservableObject = dyn Send + Sync + Any;
|
||||||
|
|
||||||
|
/// An entity type which is supposed to be updateable via the Gateway. This is implemented for all such types chorus supports, implementing it for your own types is likely a mistake.
|
||||||
|
pub trait Updateable: 'static + Send + Sync {
|
||||||
|
fn id(&self) -> Snowflake;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait which defines the behavior of an Observer. An Observer is an object which is subscribed to
|
||||||
|
/// an Observable. The Observer is notified when the Observable's data changes.
|
||||||
|
/// In this case, the Observable is a [`GatewayEvent`], which is a wrapper around a WebSocketEvent.
|
||||||
|
/// Note that `Debug` is used to tell `Observer`s apart when unsubscribing.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Observer<T>: Sync + Send + std::fmt::Debug {
|
||||||
|
async fn update(&self, data: &T);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GatewayEvent is a wrapper around a WebSocketEvent. It is used to notify the observers of a
|
||||||
|
/// change in the WebSocketEvent. GatewayEvents are observable.
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
pub struct GatewayEvent<T: WebSocketEvent> {
|
||||||
|
observers: Vec<Arc<dyn Observer<T>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: WebSocketEvent> GatewayEvent<T> {
|
||||||
|
/// Returns true if the GatewayEvent is observed by at least one Observer.
|
||||||
|
pub fn is_observed(&self) -> bool {
|
||||||
|
!self.observers.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Subscribes an Observer to the GatewayEvent.
|
||||||
|
pub fn subscribe(&mut self, observable: Arc<dyn Observer<T>>) {
|
||||||
|
self.observers.push(observable);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unsubscribes an Observer from the GatewayEvent.
|
||||||
|
pub fn unsubscribe(&mut self, observable: &dyn Observer<T>) {
|
||||||
|
// .retain()'s closure retains only those elements of the vector, which have a different
|
||||||
|
// pointer value than observable.
|
||||||
|
// The usage of the debug format to compare the generic T of observers is quite stupid, but the only thing to compare between them is T and if T == T they are the same
|
||||||
|
// anddd there is no way to do that without using format
|
||||||
|
let to_remove = format!("{:?}", observable);
|
||||||
|
self.observers
|
||||||
|
.retain(|obs| format!("{:?}", obs) != to_remove);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Notifies the observers of the GatewayEvent.
|
||||||
|
async fn notify(&self, new_event_data: T) {
|
||||||
|
for observer in &self.observers {
|
||||||
|
observer.update(&new_event_data).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod example {
|
||||||
|
use crate::types;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicI32, Ordering::Relaxed};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Consumer {
|
||||||
|
_name: String,
|
||||||
|
events_received: AtomicI32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Observer<types::GatewayResume> for Consumer {
|
||||||
|
async fn update(&self, _data: &types::GatewayResume) {
|
||||||
|
self.events_received.fetch_add(1, Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_observer_behavior() {
|
||||||
|
let mut event = GatewayEvent::default();
|
||||||
|
|
||||||
|
let new_data = types::GatewayResume {
|
||||||
|
token: "token_3276ha37am3".to_string(),
|
||||||
|
session_id: "89346671230".to_string(),
|
||||||
|
seq: "3".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let consumer = Arc::new(Consumer {
|
||||||
|
_name: "first".into(),
|
||||||
|
events_received: 0.into(),
|
||||||
|
});
|
||||||
|
event.subscribe(consumer.clone());
|
||||||
|
|
||||||
|
let second_consumer = Arc::new(Consumer {
|
||||||
|
_name: "second".into(),
|
||||||
|
events_received: 0.into(),
|
||||||
|
});
|
||||||
|
event.subscribe(second_consumer.clone());
|
||||||
|
|
||||||
|
event.notify(new_data.clone()).await;
|
||||||
|
event.unsubscribe(&*consumer);
|
||||||
|
event.notify(new_data).await;
|
||||||
|
|
||||||
|
assert_eq!(consumer.events_received.load(Relaxed), 1);
|
||||||
|
assert_eq!(second_consumer.events_received.load(Relaxed), 2);
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use crate::types::{Snowflake, WelcomeScreenObject};
|
use crate::types::{Snowflake, WelcomeScreenObject};
|
||||||
|
|
||||||
use super::guild::GuildScheduledEvent;
|
use super::guild::GuildScheduledEvent;
|
||||||
use super::{Application, Channel, GuildMember, User};
|
use super::{Application, Channel, GuildMember, NSFWLevel, User};
|
||||||
|
|
||||||
/// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users.
|
/// Represents a code that when used, adds a user to a guild or group DM channel, or creates a relationship between two users.
|
||||||
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-object>
|
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-object>
|
||||||
|
@ -56,17 +56,6 @@ pub struct InviteGuild {
|
||||||
pub welcome_screen: Option<WelcomeScreenObject>,
|
pub welcome_screen: Option<WelcomeScreenObject>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See <https://discord-userdoccers.vercel.app/resources/guild#nsfw-level> for an explanation on what
|
|
||||||
/// the levels mean.
|
|
||||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
|
||||||
pub enum NSFWLevel {
|
|
||||||
Default = 0,
|
|
||||||
Explicit = 1,
|
|
||||||
Safe = 2,
|
|
||||||
AgeRestricted = 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-stage-instance-object>
|
/// See <https://discord-userdoccers.vercel.app/resources/invite#invite-stage-instance-object>
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct InviteStageInstance {
|
pub struct InviteStageInstance {
|
||||||
|
|
Loading…
Reference in New Issue