Merge branch 'feature/gateway-observer' into main

This commit is contained in:
kozabrada123 2023-05-09 18:35:53 +00:00 committed by GitHub
commit 6ce7396275
6 changed files with 348 additions and 60 deletions

View File

@ -8,7 +8,7 @@ pub mod login {
use crate::errors::InstanceServerError; use crate::errors::InstanceServerError;
use crate::instance::Instance; use crate::instance::Instance;
impl Instance { impl<'a> Instance<'a> {
pub async fn login_account( pub async fn login_account(
&mut self, &mut self,
login_schema: &LoginSchema, login_schema: &LoginSchema,

View File

@ -8,7 +8,7 @@ pub mod register {
instance::Instance, instance::Instance,
}; };
impl Instance { impl<'a> Instance<'a> {
/** /**
Registers a new user on the Spacebar server. Registers a new user on the Spacebar server.
# Arguments # Arguments

View File

@ -53,4 +53,4 @@ mod instance_policies_schema_test {
let schema = test_instance.instance_policies_schema().await.unwrap(); let schema = test_instance.instance_policies_schema().await.unwrap();
} }
} }

View File

@ -132,6 +132,12 @@ pub struct Error {
pub code: String, pub code: String,
} }
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct UnavailableGuild {
id: String,
unavailable: bool
}
#[derive(Serialize, Deserialize, Debug, Default)] #[derive(Serialize, Deserialize, Debug, Default)]
pub struct UserObject { pub struct UserObject {
pub id: String, pub id: String,
@ -717,7 +723,7 @@ pub struct PresenceUpdate {
since: Option<i64>, since: Option<i64>,
activities: Vec<Activity>, activities: Vec<Activity>,
status: String, status: String,
afk: bool, afk: Option<bool>,
} }
impl WebSocketEvent for PresenceUpdate {} impl WebSocketEvent for PresenceUpdate {}
@ -785,6 +791,18 @@ pub struct GatewayResume {
impl WebSocketEvent for GatewayResume {} impl WebSocketEvent for GatewayResume {}
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct GatewayReady {
pub v: u8,
pub user: UserObject,
pub guilds: Vec<UnavailableGuild>,
pub session_id: String,
pub resume_gateway_url: Option<String>,
pub shard: Option<(u64, u64)>,
}
impl WebSocketEvent for GatewayReady {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayHello { pub struct GatewayHello {
pub op: i32, pub op: i32,
@ -795,7 +813,7 @@ impl WebSocketEvent for GatewayHello {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct HelloData { pub struct HelloData {
pub heartbeat_interval: i32, pub heartbeat_interval: u128,
} }
impl WebSocketEvent for HelloData {} impl WebSocketEvent for HelloData {}
@ -803,7 +821,7 @@ impl WebSocketEvent for HelloData {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayHeartbeat { pub struct GatewayHeartbeat {
pub op: u8, pub op: u8,
pub d: u64, pub d: Option<u64>,
} }
impl WebSocketEvent for GatewayHeartbeat {} impl WebSocketEvent for GatewayHeartbeat {}
@ -817,9 +835,9 @@ impl WebSocketEvent for GatewayHeartbeatAck {}
#[derive(Debug, Default, Deserialize, Serialize)] #[derive(Debug, Default, Deserialize, Serialize)]
pub struct GatewayPayload { pub struct GatewayPayload {
pub op: i32, pub op: u8,
pub d: Option<String>, pub d: Option<serde_json::Value>,
pub s: Option<i64>, pub s: Option<u64>,
pub t: Option<String>, pub t: Option<String>,
} }

View File

@ -1,7 +1,22 @@
use std::sync::Arc;
use crate::api::types::*; use crate::api::types::*;
use crate::api::WebSocketEvent; use crate::api::WebSocketEvent;
use crate::errors::ObserverError; use crate::errors::ObserverError;
use crate::gateway::events::Events; use crate::gateway::events::Events;
use futures_util::SinkExt;
use futures_util::StreamExt;
use futures_util::stream::SplitSink;
use native_tls::TlsConnector;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tokio::task;
use tokio::time;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{WebSocketStream, Connector, connect_async_tls_with_config};
/** /**
Represents a Gateway connection. A Gateway connection will create observable Represents a Gateway connection. A Gateway connection will create observable
@ -10,78 +25,332 @@ implemented [Types] with the trait [`WebSocketEvent`]
*/ */
pub struct Gateway<'a> { pub struct Gateway<'a> {
pub url: String, pub url: String,
pub token: String,
pub events: Events<'a>, pub events: Events<'a>,
websocket: WebSocketConnection,
heartbeat_handler: Option<HeartbeatHandler>
} }
impl<'a> Gateway<'a> { impl<'a> Gateway<'a> {
pub async fn new( pub async fn new(
websocket_url: String, websocket_url: String,
token: String,
) -> Result<Gateway<'a>, tokio_tungstenite::tungstenite::Error> { ) -> Result<Gateway<'a>, tokio_tungstenite::tungstenite::Error> {
Ok(Gateway { return Ok(Gateway {
url: websocket_url, url: websocket_url.clone(),
token,
events: Events::default(), events: Events::default(),
}) websocket: WebSocketConnection::new(websocket_url).await,
heartbeat_handler: None,
});
}
/// This function reads all messages from the gateway's websocket and updates its events along with the events' observers
pub async fn update_events(&mut self) {
while let Ok(msg) = self.websocket.rx.lock().await.try_recv() {
if msg.to_string() == String::new() {
continue;
}
let gateway_payload: GatewayPayload = serde_json::from_str(msg.to_text().unwrap()).unwrap();
// See https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes
match gateway_payload.op {
// Dispatch
// An event was dispatched, we need to look at the gateway event name t
0 => {
let gateway_payload_t = gateway_payload.t.unwrap();
println!("GW: Received {}..", gateway_payload_t);
// See https://discord.com/developers/docs/topics/gateway-events#receive-events
match gateway_payload_t.as_str() {
"READY" => {
let data: GatewayReady = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
}
"APPLICATION_COMMAND_PERMISSIONS_UPDATE" => {}
"AUTO_MODERATION_RULE_CREATE" => {}
"AUTO_MODERATION_RULE_UPDATE" => {}
"AUTO_MODERATION_RULE_DELETE" => {}
"AUTO_MODERATION_ACTION_EXECUTION" => {}
"CHANNEL_CREATE" => {}
"CHANNEL_UPDATE" => {}
"CHANNEL_DELETE" => {}
"CHANNEL_PINS_UPDATE" => {}
"THREAD_CREATE" => {}
"THREAD_UPDATE" => {}
"THREAD_DELETE" => {}
"THREAD_LIST_SYNC" => {}
"THREAD_MEMBER_UPDATE" => {}
"THREAD_MEMBERS_UPDATE" => {}
"GUILD_CREATE" => {}
"GUILD_UPDATE" => {}
"GUILD_DELETE" => {}
"GUILD_AUDIT_LOG_ENTRY_CREATE" => {}
"GUILD_BAN_ADD" => {}
"GUILD_BAN_REMOVE" => {}
"GUILD_EMOJIS_UPDATE" => {}
"GUILD_STICKERS_UPDATE" => {}
"GUILD_INTEGRATIONS_UPDATE" => {}
"GUILD_MEMBER_ADD" => {}
"GUILD_MEMBER_REMOVE" => {}
"GUILD_MEMBER_UPDATE" => {}
"GUILD_MEMBERS_CHUNK" => {}
"GUILD_ROLE_CREATE" => {}
"GUILD_ROLE_UPDATE" => {}
"GUILD_ROLE_DELETE" => {}
"GUILD_SCHEDULED_EVENT_CREATE" => {}
"GUILD_SCHEDULED_EVENT_UPDATE" => {}
"GUILD_SCHEDULED_EVENT_DELETE" => {}
"GUILD_SCHEDULED_EVENT_USER_ADD" => {}
"GUILD_SCHEDULED_EVENT_USER_REMOVE" => {}
"INTEGRATION_CREATE" => {}
"INTEGRATION_UPDATE" => {}
"INTEGRATION_DELETE" => {}
"INTERACTION_CREATE" => {}
"INVITE_CREATE" => {}
"INVITE_DELETE" => {}
"MESSAGE_CREATE" => {
let new_data: MessageCreate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.create.update_data(new_data);
}
"MESSAGE_UPDATE" => {
let new_data: MessageUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.update.update_data(new_data);
}
"MESSAGE_DELETE" => {
let new_data: MessageDelete = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.delete.update_data(new_data);
}
"MESSAGE_DELETE_BULK" => {
let new_data: MessageDeleteBulk = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.delete_bulk.update_data(new_data);
}
"MESSAGE_REACTION_ADD" => {
let new_data: MessageReactionAdd = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.reaction_add.update_data(new_data);
}
"MESSAGE_REACTION_REMOVE" => {
let new_data: MessageReactionRemove = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.reaction_remove.update_data(new_data);
}
"MESSAGE_REACTION_REMOVE_ALL" => {
let new_data: MessageReactionRemoveAll = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.reaction_remove_all.update_data(new_data);
}
"MESSAGE_REACTION_REMOVE_EMOJI" => {
let new_data: MessageReactionRemoveEmoji= serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.message.reaction_remove_emoji.update_data(new_data);
}
"PRESENCE_UPDATE" => {
let new_data: PresenceUpdate = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.user.presence_update.update_data(new_data);
}
"STAGE_INSTANCE_CREATE" => {}
"STAGE_INSTANCE_UPDATE" => {}
"STAGE_INSTANCE_DELETE" => {}
// Not documented in discord docs, I assume this isnt for bots / apps but is for users?
"SESSIONS_REPLACE" => {}
"TYPING_START" => {
let new_data: TypingStartEvent = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.events.user.typing_start_event.update_data(new_data);
}
"USER_UPDATE" => {}
"VOICE_STATE_UPDATE" => {}
"VOICE_SERVER_UPDATE" => {}
"WEBHOOKS_UPDATE" => {}
_ => {panic!("Invalid gateway event ({})", &gateway_payload_t)}
}
}
// Heartbeat
// We received a heartbeat from the server
1 => {}
// Reconnect
7 => {todo!()}
// Invalid Session
9 => {todo!()}
// Hello
// Starts our heartbeat
10 => {
println!("GW: Received Hello");
let gateway_hello: HelloData = serde_json::from_value(gateway_payload.d.unwrap()).unwrap();
self.heartbeat_handler = Some(HeartbeatHandler::new(gateway_hello.heartbeat_interval, self.websocket.tx.clone()));
}
// Heartbeat ACK
11 => {
println!("GW: Received Heartbeat ACK");
}
2 | 3 | 4 | 6 | 8 => {panic!("Received Gateway op code that's meant to be sent, not received ({})", gateway_payload.op)}
_ => {panic!("Received Invalid Gateway op code ({})", gateway_payload.op)}
}
// If we have an active heartbeat thread and we received a seq number we should let it know
if gateway_payload.s.is_some() {
if self.heartbeat_handler.is_some() {
let heartbeat_communication = HeartbeatThreadCommunication { op: gateway_payload.op, d: gateway_payload.s.unwrap() };
self.heartbeat_handler.as_mut().unwrap().tx.send(heartbeat_communication).await.unwrap();
}
}
}
}
/// Sends json to the gateway with an opcode
async fn send_json_event(&self, op: u8, to_send: serde_json::Value) {
let gateway_payload = GatewayPayload { op, d: Some(to_send), s: None, t: None };
let payload_json = serde_json::to_string(&gateway_payload).unwrap();
let message = tokio_tungstenite::tungstenite::Message::text(payload_json);
self.websocket.tx.lock().await.send(message).await.unwrap();
}
/// Sends an identify event to the gateway
pub async fn send_identify(&self, to_send: GatewayIdentifyPayload) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Identify..");
self.send_json_event(2, to_send_value).await;
}
/// Sends a resume event to the gateway
pub async fn send_resume(&self, to_send: GatewayResume) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("GW: Sending Resume..");
self.send_json_event(6, to_send_value).await;
}
/// Sends an update presence event to the gateway
pub async fn send_update_presence(&self, to_send: PresenceUpdate) {
let to_send_value = serde_json::to_value(&to_send).unwrap();
self.send_json_event(3, to_send_value).await;
} }
} }
/*struct WebSocketConnection { /**
Handles sending heartbeats to the gateway in another thread
*/
struct HeartbeatHandler {
/// The heartbeat interval in milliseconds
heartbeat_interval: u128,
tx: Sender<HeartbeatThreadCommunication>,
}
impl HeartbeatHandler {
pub fn new(heartbeat_interval: u128, websocket_tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>) -> HeartbeatHandler {
let (tx, mut rx) = mpsc::channel(32);
task::spawn(async move {
let mut last_heartbeat: Instant = time::Instant::now();
let mut last_seq_number: Option<u64> = None;
loop {
// If we received a seq number update, use that as the last seq number
let hb_communication: Result<HeartbeatThreadCommunication, TryRecvError> = rx.try_recv();
if hb_communication.is_ok() {
last_seq_number = Some(hb_communication.unwrap().d);
}
if last_heartbeat.elapsed().as_millis() > heartbeat_interval {
println!("GW: Sending Heartbeat..");
let heartbeat = GatewayHeartbeat {
op: 1,
d: last_seq_number
};
let heartbeat_json = serde_json::to_string(&heartbeat).unwrap();
let msg = tokio_tungstenite::tungstenite::Message::text(heartbeat_json);
websocket_tx.lock()
.await
.send(msg)
.await
.unwrap();
last_heartbeat = time::Instant::now();
}
}
});
Self { heartbeat_interval, tx }
}
}
/**
Used to communicate with the main thread.
Either signifies a sequence number update or a received heartbeat ack
*/
#[derive(Clone, Copy, Debug)]
struct HeartbeatThreadCommunication {
/// An opcode for the communication we received
op: u8,
/// The sequence number we got from discord
d: u64
}
struct WebSocketConnection {
rx: Arc<Mutex<Receiver<tokio_tungstenite::tungstenite::Message>>>, rx: Arc<Mutex<Receiver<tokio_tungstenite::tungstenite::Message>>>,
tx: Arc<Mutex<Sender<tokio_tungstenite::tungstenite::Message>>>, tx: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Message>>>,
} }
impl<'a> WebSocketConnection { impl<'a> WebSocketConnection {
async fn new(websocket_url: String) -> WebSocketConnection { async fn new(websocket_url: String) -> WebSocketConnection {
let parsed_url = Url::parse(&URLBundle::parse_url(websocket_url.clone())).unwrap(); let (receive_channel_write, receive_channel_read): (
if parsed_url.scheme() != "ws" && parsed_url.scheme() != "wss" {
return Err(tokio_tungstenite::tungstenite::Error::Url(
UrlError::UnsupportedUrlScheme,
));
}
let (mut channel_write, mut channel_read): (
Sender<tokio_tungstenite::tungstenite::Message>, Sender<tokio_tungstenite::tungstenite::Message>,
Receiver<tokio_tungstenite::tungstenite::Message>, Receiver<tokio_tungstenite::tungstenite::Message>,
) = channel(32); ) = channel(32);
let shared_channel_write = Arc::new(Mutex::new(channel_write)); let shared_receive_channel_read = Arc::new(Mutex::new(receive_channel_read));
let clone_shared_channel_write = shared_channel_write.clone();
let shared_channel_read = Arc::new(Mutex::new(channel_read));
let clone_shared_channel_read = shared_channel_read.clone();
task::spawn(async move { let (ws_stream, _) = match connect_async_tls_with_config(
let (mut ws_stream, _) = match connect_async_tls_with_config( &websocket_url,
&websocket_url, None,
None, Some(Connector::NativeTls(
Some(Connector::NativeTls( TlsConnector::builder().build().unwrap(),
TlsConnector::builder().build().unwrap(), )),
)), )
) .await
.await {
{ Ok(ws_stream) => ws_stream,
Ok(ws_stream) => ws_stream, Err(e) => panic!("{:?}", e),
Err(_) => return, /*return Err(tokio_tungstenite::tungstenite::Error::Io(
io::ErrorKind::ConnectionAborted.into(),
))*/
};
let (mut write_tx, mut write_rx) = ws_stream.split();
while let Some(msg) = shared_channel_read.lock().await.recv().await {
write_tx.send(msg).await.unwrap();
}
}; };
Ok(Gateway { let (ws_tx, mut ws_rx) = ws_stream.split();
url: websocket_url,
token, task::spawn(async move {
events: Events::default(), loop {
socket: ws_stream,
}) // Write received messages to the receive channel
let msg = ws_rx.next().await;
if msg.as_ref().is_some() {
let msg_unwrapped = msg.unwrap().unwrap();
receive_channel_write
.send(msg_unwrapped)
.await
.unwrap();
};
}
});
WebSocketConnection {
tx: Arc::new(Mutex::new(ws_tx)),
rx: shared_receive_channel_read,
}
} }
}*/ }
/** /**
Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to Trait which defines the behaviour of an Observer. An Observer is an object which is subscribed to
@ -236,7 +505,7 @@ mod example {
#[tokio::test] #[tokio::test]
async fn test_gateway() { async fn test_gateway() {
let _gateway = Gateway::new("ws://localhost:3001/".to_string(), "none".to_string()) let _gateway = Gateway::new("ws://localhost:3001/".to_string())
.await .await
.unwrap(); .unwrap();
} }

View File

@ -1,13 +1,13 @@
use crate::api::limits::Limits; use crate::api::limits::Limits;
use crate::api::types::{InstancePolicies}; use crate::api::types::{InstancePolicies};
use crate::errors::{FieldFormatError, InstanceServerError}; use crate::errors::{FieldFormatError, InstanceServerError};
use crate::gateway::Gateway;
use crate::limit::LimitedRequester; use crate::limit::LimitedRequester;
use crate::URLBundle; use crate::URLBundle;
use std::fmt; use std::fmt;
#[derive(Debug)]
/** /**
The [`Instance`] what you will be using to perform all sorts of actions on the Spacebar server. The [`Instance`] what you will be using to perform all sorts of actions on the Spacebar server.
*/ */
@ -16,7 +16,7 @@ pub struct Instance {
pub instance_info: InstancePolicies, pub instance_info: InstancePolicies,
pub requester: LimitedRequester, pub requester: LimitedRequester,
pub limits: Limits, pub limits: Limits,
//pub gateway: Gateway, pub gateway: Gateway,
} }
impl Instance { impl Instance {
@ -45,6 +45,7 @@ impl Instance {
), ),
limits: Limits::check_limits(urls.api).await, limits: Limits::check_limits(urls.api).await,
requester, requester,
gateway: Gateway::new(urls.wss.clone()).await.unwrap(),
}; };
instance.instance_info = match instance.instance_policies_schema().await { instance.instance_info = match instance.instance_policies_schema().await {
Ok(schema) => schema, Ok(schema) => schema,