diff --git a/src/api/auth/login.rs b/src/api/auth/login.rs index 67647f7..2f77158 100644 --- a/src/api/auth/login.rs +++ b/src/api/auth/login.rs @@ -8,7 +8,7 @@ pub mod login { use crate::errors::InstanceServerError; use crate::instance::Instance; - impl Instance { + impl<'a> Instance<'a> { pub async fn login_account( &mut self, login_schema: &LoginSchema, diff --git a/src/api/auth/register.rs b/src/api/auth/register.rs index 186be78..4a3c71a 100644 --- a/src/api/auth/register.rs +++ b/src/api/auth/register.rs @@ -8,7 +8,7 @@ pub mod register { instance::Instance, }; - impl Instance { + impl<'a> Instance<'a> { /** Registers a new user on the Spacebar server. # Arguments diff --git a/src/api/policies/instance/instance.rs b/src/api/policies/instance/instance.rs index b13aa89..e08841f 100644 --- a/src/api/policies/instance/instance.rs +++ b/src/api/policies/instance/instance.rs @@ -53,4 +53,4 @@ mod instance_policies_schema_test { let schema = test_instance.instance_policies_schema().await.unwrap(); } -} +} \ No newline at end of file diff --git a/src/api/types.rs b/src/api/types.rs index e52a2f8..f94b458 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -132,6 +132,12 @@ pub struct Error { pub code: String, } +#[derive(Serialize, Deserialize, Debug, Default)] +pub struct UnavailableGuild { + id: String, + unavailable: bool +} + #[derive(Serialize, Deserialize, Debug, Default)] pub struct UserObject { pub id: String, @@ -717,7 +723,7 @@ pub struct PresenceUpdate { since: Option, activities: Vec, status: String, - afk: bool, + afk: Option, } impl WebSocketEvent for PresenceUpdate {} @@ -785,6 +791,18 @@ pub struct GatewayResume { impl WebSocketEvent for GatewayResume {} +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct GatewayReady { + pub v: u8, + pub user: UserObject, + pub guilds: Vec, + pub session_id: String, + pub resume_gateway_url: Option, + pub shard: Option<(u64, u64)>, +} + +impl WebSocketEvent for GatewayReady {} + #[derive(Debug, Default, Deserialize, Serialize)] pub struct GatewayHello { pub op: i32, @@ -795,7 +813,7 @@ impl WebSocketEvent for GatewayHello {} #[derive(Debug, Default, Deserialize, Serialize)] pub struct HelloData { - pub heartbeat_interval: i32, + pub heartbeat_interval: u128, } impl WebSocketEvent for HelloData {} @@ -803,7 +821,7 @@ impl WebSocketEvent for HelloData {} #[derive(Debug, Default, Deserialize, Serialize)] pub struct GatewayHeartbeat { pub op: u8, - pub d: u64, + pub d: Option, } impl WebSocketEvent for GatewayHeartbeat {} @@ -817,9 +835,9 @@ impl WebSocketEvent for GatewayHeartbeatAck {} #[derive(Debug, Default, Deserialize, Serialize)] pub struct GatewayPayload { - pub op: i32, - pub d: Option, - pub s: Option, + pub op: u8, + pub d: Option, + pub s: Option, pub t: Option, } diff --git a/src/gateway.rs b/src/gateway.rs index 57860bb..7f31fe2 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -1,7 +1,22 @@ +use std::sync::Arc; use crate::api::types::*; use crate::api::WebSocketEvent; use crate::errors::ObserverError; use crate::gateway::events::Events; +use futures_util::SinkExt; +use futures_util::StreamExt; +use futures_util::stream::SplitSink; +use native_tls::TlsConnector; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TryRecvError; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::Mutex; +use tokio::task; +use tokio::time; +use tokio::time::Instant; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::{WebSocketStream, Connector, connect_async_tls_with_config}; /** Represents a Gateway connection. A Gateway connection will create observable @@ -10,78 +25,332 @@ implemented [Types] with the trait [`WebSocketEvent`] */ pub struct Gateway<'a> { pub url: String, - pub token: String, pub events: Events<'a>, + websocket: WebSocketConnection, + heartbeat_handler: Option } impl<'a> Gateway<'a> { pub async fn new( websocket_url: String, - token: String, ) -> Result, tokio_tungstenite::tungstenite::Error> { - Ok(Gateway { - url: websocket_url, - token, + return Ok(Gateway { + url: websocket_url.clone(), events: Events::default(), - }) + websocket: WebSocketConnection::new(websocket_url).await, + heartbeat_handler: None, + }); + } + + /// This function reads all messages from the gateway's websocket and updates its events along with the events' observers + pub async fn update_events(&mut self) { + + while let Ok(msg) = self.websocket.rx.lock().await.try_recv() { + + if msg.to_string() == String::new() { + continue; + } + + let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap(); + + // See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes + match gateway_payload.op { + // Dispatch + // An event was dispatched, we need to look at the gateway event name t + 0 => { + let gateway_payload_t = gateway_payload.t.unwrap(); + + println!("GW: Received {}..", gateway_payload_t); + + // See https://discord.com/developers/docs/topics/gateway-events#receive-events + match gateway_payload_t.as_str() { + "READY" => { + let data: GatewayReady = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + } + "APPLICATION_COMMAND_PERMISSIONS_UPDATE" => {} + "AUTO_MODERATION_RULE_CREATE" => {} + "AUTO_MODERATION_RULE_UPDATE" => {} + "AUTO_MODERATION_RULE_DELETE" => {} + "AUTO_MODERATION_ACTION_EXECUTION" => {} + "CHANNEL_CREATE" => {} + "CHANNEL_UPDATE" => {} + "CHANNEL_DELETE" => {} + "CHANNEL_PINS_UPDATE" => {} + "THREAD_CREATE" => {} + "THREAD_UPDATE" => {} + "THREAD_DELETE" => {} + "THREAD_LIST_SYNC" => {} + "THREAD_MEMBER_UPDATE" => {} + "THREAD_MEMBERS_UPDATE" => {} + "GUILD_CREATE" => {} + "GUILD_UPDATE" => {} + "GUILD_DELETE" => {} + "GUILD_AUDIT_LOG_ENTRY_CREATE" => {} + "GUILD_BAN_ADD" => {} + "GUILD_BAN_REMOVE" => {} + "GUILD_EMOJIS_UPDATE" => {} + "GUILD_STICKERS_UPDATE" => {} + "GUILD_INTEGRATIONS_UPDATE" => {} + "GUILD_MEMBER_ADD" => {} + "GUILD_MEMBER_REMOVE" => {} + "GUILD_MEMBER_UPDATE" => {} + "GUILD_MEMBERS_CHUNK" => {} + "GUILD_ROLE_CREATE" => {} + "GUILD_ROLE_UPDATE" => {} + "GUILD_ROLE_DELETE" => {} + "GUILD_SCHEDULED_EVENT_CREATE" => {} + "GUILD_SCHEDULED_EVENT_UPDATE" => {} + "GUILD_SCHEDULED_EVENT_DELETE" => {} + "GUILD_SCHEDULED_EVENT_USER_ADD" => {} + "GUILD_SCHEDULED_EVENT_USER_REMOVE" => {} + "INTEGRATION_CREATE" => {} + "INTEGRATION_UPDATE" => {} + "INTEGRATION_DELETE" => {} + "INTERACTION_CREATE" => {} + "INVITE_CREATE" => {} + "INVITE_DELETE" => {} + "MESSAGE_CREATE" => { + let new_data: MessageCreate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.create.update_data(new_data); + } + "MESSAGE_UPDATE" => { + let new_data: MessageUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.update.update_data(new_data); + } + "MESSAGE_DELETE" => { + let new_data: MessageDelete = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.delete.update_data(new_data); + } + "MESSAGE_DELETE_BULK" => { + let new_data: MessageDeleteBulk = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.delete_bulk.update_data(new_data); + } + "MESSAGE_REACTION_ADD" => { + let new_data: MessageReactionAdd = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.reaction_add.update_data(new_data); + } + "MESSAGE_REACTION_REMOVE" => { + let new_data: MessageReactionRemove = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.reaction_remove.update_data(new_data); + } + "MESSAGE_REACTION_REMOVE_ALL" => { + let new_data: MessageReactionRemoveAll = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.reaction_remove_all.update_data(new_data); + } + "MESSAGE_REACTION_REMOVE_EMOJI" => { + let new_data: MessageReactionRemoveEmoji= serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.message.reaction_remove_emoji.update_data(new_data); + } + "PRESENCE_UPDATE" => { + let new_data: PresenceUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.user.presence_update.update_data(new_data); + } + "STAGE_INSTANCE_CREATE" => {} + "STAGE_INSTANCE_UPDATE" => {} + "STAGE_INSTANCE_DELETE" => {} + // Not documented in discord docs, I assume this isnt for bots / apps but is for users? + "SESSIONS_REPLACE" => {} + "TYPING_START" => { + let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.events.user.typing_start_event.update_data(new_data); + } + "USER_UPDATE" => {} + "VOICE_STATE_UPDATE" => {} + "VOICE_SERVER_UPDATE" => {} + "WEBHOOKS_UPDATE" => {} + _ => {panic!("Invalid gateway event ({})", &gateway_payload_t)} + } + } + // Heartbeat + // We received a heartbeat from the server + 1 => {} + // Reconnect + 7 => {todo!()} + // Invalid Session + 9 => {todo!()} + // Hello + // Starts our heartbeat + 10 => { + println!("GW: Received Hello"); + let gateway_hello: HelloData = serde_json::from_value(gateway_payload.d.unwrap()).unwrap(); + self.heartbeat_handler = Some(HeartbeatHandler::new(gateway_hello.heartbeat_interval, self.websocket.tx.clone())); + } + // Heartbeat ACK + 11 => { + println!("GW: Received Heartbeat ACK"); + } + 2 | 3 | 4 | 6 | 8 => {panic!("Received Gateway op code that's meant to be sent, not received ({})", gateway_payload.op)} + _ => {panic!("Received Invalid Gateway op code ({})", gateway_payload.op)} + } + + // If we have an active heartbeat thread and we received a seq number we should let it know + if gateway_payload.s.is_some() { + if self.heartbeat_handler.is_some() { + + let heartbeat_communication = HeartbeatThreadCommunication { op: gateway_payload.op, d: gateway_payload.s.unwrap() }; + + self.heartbeat_handler.as_mut().unwrap().tx.send(heartbeat_communication).await.unwrap(); + } + } + } + } + + /// Sends json to the gateway with an opcode + async fn send_json_event(&self, op: u8, to_send: serde_json::Value) { + + let gateway_payload = GatewayPayload { op, d: Some(to_send), s: None, t: None }; + + let payload_json = serde_json::to_string(&gateway_payload).unwrap(); + + let message = tokio_tungstenite::tungstenite::Message::text(payload_json); + + self.websocket.tx.lock().await.send(message).await.unwrap(); + } + + /// Sends an identify event to the gateway + pub async fn send_identify(&self, to_send: GatewayIdentifyPayload) { + + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + println!("GW: Sending Identify.."); + + self.send_json_event(2, to_send_value).await; + } + + /// Sends a resume event to the gateway + pub async fn send_resume(&self, to_send: GatewayResume) { + + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + println!("GW: Sending Resume.."); + + self.send_json_event(6, to_send_value).await; + } + + /// Sends an update presence event to the gateway + pub async fn send_update_presence(&self, to_send: PresenceUpdate) { + + let to_send_value = serde_json::to_value(&to_send).unwrap(); + + self.send_json_event(3, to_send_value).await; } } -/*struct WebSocketConnection { +/** +Handles sending heartbeats to the gateway in another thread +*/ +struct HeartbeatHandler { + /// The heartbeat interval in milliseconds + heartbeat_interval: u128, + tx: Sender, +} + +impl HeartbeatHandler { + pub fn new(heartbeat_interval: u128, websocket_tx: Arc>, tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler { + let (tx, mut rx) = mpsc::channel(32); + + task::spawn(async move { + let mut last_heartbeat: Instant = time::Instant::now(); + let mut last_seq_number: Option = None; + + loop { + + // If we received a seq number update, use that as the last seq number + let hb_communication: Result = rx.try_recv(); + if hb_communication.is_ok() { + last_seq_number = Some(hb_communication.unwrap().d); + } + + if last_heartbeat.elapsed().as_millis() > heartbeat_interval { + + println!("GW: Sending Heartbeat.."); + + let heartbeat = GatewayHeartbeat { + op: 1, + d: last_seq_number + }; + + let heartbeat_json = serde_json::to_string(&heartbeat).unwrap(); + + let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json); + + websocket_tx.lock() + .await + .send(msg) + .await + .unwrap(); + + last_heartbeat = time::Instant::now(); + } + } + }); + + Self { heartbeat_interval, tx } + } +} + +/** +Used to communicate with the main thread. +Either signifies a sequence number update or a received heartbeat ack +*/ +#[derive(Clone, Copy, Debug)] +struct HeartbeatThreadCommunication { + /// An opcode for the communication we received + op: u8, + /// The sequence number we got from discord + d: u64 +} + +struct WebSocketConnection { rx: Arc>>, - tx: Arc>>, + tx: Arc>, tokio_tungstenite::tungstenite::Message>>>, } impl<'a> WebSocketConnection { async fn new(websocket_url: String) -> WebSocketConnection { - let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap(); - if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" { - return Err(tokio_tungstenite::tungstenite::Error::Url( - UrlError::UnsupportedUrlScheme, - )); - } - - let (mut channel_write, mut channel_read): ( + let (receive_channel_write, receive_channel_read): ( Sender, Receiver, ) = channel(32); - let shared_channel_write = Arc::new(Mutex::new(channel_write)); - let clone_shared_channel_write = shared_channel_write.clone(); - let shared_channel_read = Arc::new(Mutex::new(channel_read)); - let clone_shared_channel_read = shared_channel_read.clone(); + let shared_receive_channel_read = Arc::new(Mutex::new(receive_channel_read)); - task::spawn(async move { - let (mut ws_stream, _) = match connect_async_tls_with_config( - &websocket_url, - None, - Some(Connector::NativeTls( - TlsConnector::builder().build().unwrap(), - )), - ) - .await - { - Ok(ws_stream) => ws_stream, - Err(_) => return, /*return Err(tokio_tungstenite::tungstenite::Error::Io( - io::ErrorKind::ConnectionAborted.into(), - ))*/ - }; - - let (mut write_tx, mut write_rx) = ws_stream.split(); - - while let Some(msg) = shared_channel_read.lock().await.recv().await { - write_tx.send(msg).await.unwrap(); - } + let (ws_stream, _) = match connect_async_tls_with_config( + &websocket_url, + None, + Some(Connector::NativeTls( + TlsConnector::builder().build().unwrap(), + )), + ) + .await + { + Ok(ws_stream) => ws_stream, + Err(e) => panic!("{:?}", e), }; - Ok(Gateway { - url: websocket_url, - token, - events: Events::default(), - socket: ws_stream, - }) + let (ws_tx, mut ws_rx) = ws_stream.split(); + + task::spawn(async move { + loop { + + // Write received messages to the receive channel + let msg = ws_rx.next().await; + if msg.as_ref().is_some() { + let msg_unwrapped = msg.unwrap().unwrap(); + receive_channel_write + .send(msg_unwrapped) + .await + .unwrap(); + }; + } + }); + + WebSocketConnection { + tx: Arc::new(Mutex::new(ws_tx)), + rx: shared_receive_channel_read, + } } -}*/ +} /** Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to @@ -236,7 +505,7 @@ mod example { #[tokio::test] async fn test_gateway() { - let _gateway = Gateway::new("ws://localhost:3001/".to_string(), "none".to_string()) + let _gateway = Gateway::new("ws://localhost:3001/".to_string()) .await .unwrap(); } diff --git a/src/instance.rs b/src/instance.rs index 79e2944..ae8388a 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1,13 +1,13 @@ use crate::api::limits::Limits; use crate::api::types::{InstancePolicies}; use crate::errors::{FieldFormatError, InstanceServerError}; +use crate::gateway::Gateway; use crate::limit::LimitedRequester; use crate::URLBundle; use std::fmt; -#[derive(Debug)] /** The [`Instance`] what you will be using to perform all sorts of actions on the Spacebar server. */ @@ -16,7 +16,7 @@ pub struct Instance { pub instance_info: InstancePolicies, pub requester: LimitedRequester, pub limits: Limits, - //pub gateway: Gateway, + pub gateway: Gateway, } impl Instance { @@ -45,6 +45,7 @@ impl Instance { ), limits: Limits::check_limits(urls.api).await, requester, + gateway: Gateway::new(urls.wss.clone()).await.unwrap(), }; instance.instance_info = match instance.instance_policies_schema().await { Ok(schema) => schema,