Merge with main, but better this time

This commit is contained in:
kozabrada123 2023-07-28 18:01:52 +02:00
commit 2337ed78a6
2 changed files with 58 additions and 60 deletions

View File

@ -100,9 +100,9 @@ custom_error! {
UnknownEncryptionModeError = "Server failed to decrypt data", UnknownEncryptionModeError = "Server failed to decrypt data",
// Errors when initiating a gateway connection // Errors when initiating a gateway connection
CannotConnectError{error: String} = "Cannot connect due to a tungstenite error: {error}", CannotConnect{error: String} = "Cannot connect due to a tungstenite error: {error}",
NonHelloOnInitiateError{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
// Other misc errors // Other misc errors
UnexpectedOpcodeReceivedError{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}" UnexpectedOpcodeReceived{opcode: u8} = "Received an opcode we weren't expecting to receive: {opcode}",
} }

View File

@ -1,16 +1,17 @@
use futures_util::stream::{SplitSink, SplitStream}; use futures_util::stream::{SplitSink, SplitStream};
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt; use futures_util::StreamExt;
use log::{debug, info, trace, warn};
use native_tls::TlsConnector; use native_tls::TlsConnector;
use serde_json::json; use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Instant; use tokio::time::Instant;
use tokio::time::{self, sleep_until};
use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream}; use tokio_tungstenite::{connect_async_tls_with_config, Connector, WebSocketStream};
@ -138,7 +139,7 @@ impl VoiceGatewayHandle {
pub async fn send_identify(&self, to_send: VoiceIdentify) { pub async fn send_identify(&self, to_send: VoiceIdentify) {
let to_send_value = serde_json::to_value(&to_send).unwrap(); let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("VGW: Sending Identify.."); trace!("VGW: Sending Identify..");
self.send_json(VOICE_IDENTIFY, to_send_value).await; self.send_json(VOICE_IDENTIFY, to_send_value).await;
} }
@ -147,7 +148,7 @@ impl VoiceGatewayHandle {
pub async fn send_select_protocol(&self, to_send: SelectProtocol) { pub async fn send_select_protocol(&self, to_send: SelectProtocol) {
let to_send_value = serde_json::to_value(&to_send).unwrap(); let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("VGW: Sending Select Protocol"); trace!("VGW: Sending Select Protocol");
self.send_json(VOICE_SELECT_PROTOCOL, to_send_value).await; self.send_json(VOICE_SELECT_PROTOCOL, to_send_value).await;
} }
@ -156,7 +157,7 @@ impl VoiceGatewayHandle {
pub async fn send_speaking(&self, to_send: Speaking) { pub async fn send_speaking(&self, to_send: Speaking) {
let to_send_value = serde_json::to_value(&to_send).unwrap(); let to_send_value = serde_json::to_value(&to_send).unwrap();
println!("VGW: Sending Speaking"); trace!("VGW: Sending Speaking");
self.send_json(VOICE_SPEAKING, to_send_value).await; self.send_json(VOICE_SPEAKING, to_send_value).await;
} }
@ -165,7 +166,7 @@ impl VoiceGatewayHandle {
pub async fn send_voice_backend_version_request(&self) { pub async fn send_voice_backend_version_request(&self) {
let data_empty_object = json!("{}"); let data_empty_object = json!("{}");
println!("VGW: Requesting voice backend version"); trace!("VGW: Requesting voice backend version");
self.send_json(VOICE_BACKEND_VERSION, data_empty_object) self.send_json(VOICE_BACKEND_VERSION, data_empty_object)
.await; .await;
@ -180,9 +181,9 @@ impl VoiceGatewayHandle {
} }
} }
pub struct VoiceGateway { pub struct VoiceGateway {
pub events: Arc<Mutex<voice_events::VoiceEvents>>, events: Arc<Mutex<voice_events::VoiceEvents>>,
heartbeat_handler: VoiceHeartbeatHandler, heartbeat_handler: VoiceHeartbeatHandler,
pub websocket_send: Arc< websocket_send: Arc<
Mutex< Mutex<
SplitSink< SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>, WebSocketStream<MaybeTlsStream<TcpStream>>,
@ -190,7 +191,7 @@ pub struct VoiceGateway {
>, >,
>, >,
>, >,
pub websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, websocket_receive: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
kill_send: tokio::sync::broadcast::Sender<()>, kill_send: tokio::sync::broadcast::Sender<()>,
} }
@ -212,7 +213,7 @@ impl VoiceGateway {
{ {
Ok(websocket_stream) => websocket_stream, Ok(websocket_stream) => websocket_stream,
Err(e) => { Err(e) => {
return Err(VoiceGatewayError::CannotConnectError { return Err(VoiceGatewayError::CannotConnect {
error: e.to_string(), error: e.to_string(),
}) })
} }
@ -232,12 +233,12 @@ impl VoiceGateway {
serde_json::from_str(msg.to_text().unwrap()).unwrap(); serde_json::from_str(msg.to_text().unwrap()).unwrap();
if gateway_payload.op_code != VOICE_HELLO { if gateway_payload.op_code != VOICE_HELLO {
return Err(VoiceGatewayError::NonHelloOnInitiateError { return Err(VoiceGatewayError::NonHelloOnInitiate {
opcode: gateway_payload.op_code, opcode: gateway_payload.op_code,
}); });
} }
println!("VGW: Received Hello"); info!("VGW: Received Hello");
// The hello data is the same on voice and normal gateway // The hello data is the same on voice and normal gateway
let gateway_hello: types::HelloData = let gateway_hello: types::HelloData =
@ -249,7 +250,7 @@ impl VoiceGateway {
let mut gateway = VoiceGateway { let mut gateway = VoiceGateway {
events: shared_events.clone(), events: shared_events.clone(),
heartbeat_handler: VoiceHeartbeatHandler::new( heartbeat_handler: VoiceHeartbeatHandler::new(
gateway_hello.heartbeat_interval, Duration::from_millis(gateway_hello.heartbeat_interval),
1, // to:do actually compute nonce 1, // to:do actually compute nonce
shared_websocket_send.clone(), shared_websocket_send.clone(),
kill_send.subscribe(), kill_send.subscribe(),
@ -287,7 +288,7 @@ impl VoiceGateway {
} }
// We couldn't receive the next message or it was an error, something is wrong with the websocket, close // We couldn't receive the next message or it was an error, something is wrong with the websocket, close
println!("VGW: Websocket is broken, stopping gateway"); warn!("VGW: Websocket is broken, stopping gateway");
break; break;
} }
} }
@ -321,7 +322,7 @@ impl VoiceGateway {
} }
if !msg.is_error() && !msg.is_payload() { if !msg.is_error() && !msg.is_payload() {
println!( warn!(
"Message unrecognised: {:?}, please open an issue on the chorus github", "Message unrecognised: {:?}, please open an issue on the chorus github",
msg.message.to_string() msg.message.to_string()
); );
@ -330,12 +331,10 @@ impl VoiceGateway {
// To:do: handle errors in a good way, maybe observers like events? // To:do: handle errors in a good way, maybe observers like events?
if msg.is_error() { if msg.is_error() {
println!("VGW: Received error, connection will close.."); warn!("VGW: Received error, connection will close..");
let _error = msg.error(); let _error = msg.error();
{}
self.close().await; self.close().await;
return; return;
} }
@ -347,7 +346,7 @@ impl VoiceGateway {
let event = &mut self.events.lock().await.voice_ready; let event = &mut self.events.lock().await.voice_ready;
let result = VoiceGateway::handle_event(gateway_payload.data.get(), event).await; let result = VoiceGateway::handle_event(gateway_payload.data.get(), event).await;
if result.is_err() { if result.is_err() {
println!("Failed to parse VOICE_READY ({})", result.err().unwrap()); warn!("Failed to parse VOICE_READY ({})", result.err().unwrap());
return; return;
} }
} }
@ -355,7 +354,7 @@ impl VoiceGateway {
let event = &mut self.events.lock().await.session_description; let event = &mut self.events.lock().await.session_description;
let result = VoiceGateway::handle_event(gateway_payload.data.get(), event).await; let result = VoiceGateway::handle_event(gateway_payload.data.get(), event).await;
if result.is_err() { if result.is_err() {
println!( warn!(
"Failed to parse VOICE_SELECT_PROTOCOL ({})", "Failed to parse VOICE_SELECT_PROTOCOL ({})",
result.err().unwrap() result.err().unwrap()
); );
@ -365,7 +364,7 @@ impl VoiceGateway {
// We received a heartbeat from the server // We received a heartbeat from the server
// "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately." // "Discord may send the app a Heartbeat (opcode 1) event, in which case the app should send a Heartbeat event immediately."
VOICE_HEARTBEAT => { VOICE_HEARTBEAT => {
println!("VGW: Received Heartbeat // Heartbeat Request"); trace!("VGW: Received Heartbeat // Heartbeat Request");
// Tell the heartbeat handler it should send a heartbeat right away // Tell the heartbeat handler it should send a heartbeat right away
let heartbeat_communication = VoiceHeartbeatThreadCommunication { let heartbeat_communication = VoiceHeartbeatThreadCommunication {
@ -380,7 +379,7 @@ impl VoiceGateway {
.unwrap(); .unwrap();
} }
VOICE_HEARTBEAT_ACK => { VOICE_HEARTBEAT_ACK => {
println!("VGW: Received Heartbeat ACK"); debug!("VGW: Received Heartbeat ACK");
// Tell the heartbeat handler we received an ack // Tell the heartbeat handler we received an ack
@ -396,13 +395,13 @@ impl VoiceGateway {
.unwrap(); .unwrap();
} }
VOICE_IDENTIFY | VOICE_SELECT_PROTOCOL | VOICE_RESUME => { VOICE_IDENTIFY | VOICE_SELECT_PROTOCOL | VOICE_RESUME => {
let error = VoiceGatewayError::UnexpectedOpcodeReceivedError { let error = VoiceGatewayError::UnexpectedOpcodeReceived {
opcode: gateway_payload.op_code, opcode: gateway_payload.op_code,
}; };
Err::<(), VoiceGatewayError>(error).unwrap(); Err::<(), VoiceGatewayError>(error).unwrap();
} }
_ => { _ => {
println!("Received unrecognized voice gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code); warn!("Received unrecognized voice gateway op code ({})! Please open an issue on the chorus github so we can implement it", gateway_payload.op_code);
} }
} }
} }
@ -411,7 +410,7 @@ impl VoiceGateway {
/// Handles sending heartbeats to the voice gateway in another thread /// Handles sending heartbeats to the voice gateway in another thread
struct VoiceHeartbeatHandler { struct VoiceHeartbeatHandler {
/// The heartbeat interval in milliseconds /// The heartbeat interval in milliseconds
pub heartbeat_interval: u128, pub heartbeat_interval: Duration,
/// The send channel for the heartbeat thread /// The send channel for the heartbeat thread
pub send: Sender<VoiceHeartbeatThreadCommunication>, pub send: Sender<VoiceHeartbeatThreadCommunication>,
/// The handle of the thread /// The handle of the thread
@ -420,7 +419,7 @@ struct VoiceHeartbeatHandler {
impl VoiceHeartbeatHandler { impl VoiceHeartbeatHandler {
pub fn new( pub fn new(
heartbeat_interval: u128, heartbeat_interval: Duration,
starting_nonce: u64, starting_nonce: u64,
websocket_tx: Arc< websocket_tx: Arc<
Mutex< Mutex<
@ -466,7 +465,7 @@ impl VoiceHeartbeatHandler {
>, >,
>, >,
>, >,
heartbeat_interval: u128, heartbeat_interval: Duration,
starting_nonce: u64, starting_nonce: u64,
mut receive: tokio::sync::mpsc::Receiver<VoiceHeartbeatThreadCommunication>, mut receive: tokio::sync::mpsc::Receiver<VoiceHeartbeatThreadCommunication>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>, mut kill_receive: tokio::sync::broadcast::Receiver<()>,
@ -476,52 +475,51 @@ impl VoiceHeartbeatHandler {
let mut nonce: u64 = starting_nonce; let mut nonce: u64 = starting_nonce;
loop { loop {
let should_shutdown = kill_receive.try_recv().is_ok(); if kill_receive.try_recv().is_ok() {
if should_shutdown { trace!("VGW: Closing heartbeat task");
break; break;
} }
let mut should_send; let timeout = if last_heartbeat_acknowledged {
heartbeat_interval
} else {
// If the server hasn't acknowledged our heartbeat we should resend it
Duration::from_millis(HEARTBEAT_ACK_TIMEOUT)
};
let time_to_send = last_heartbeat_timestamp.elapsed().as_millis() >= heartbeat_interval; let mut should_send = false;
should_send = time_to_send; tokio::select! {
() = sleep_until(last_heartbeat_timestamp + timeout) => {
should_send = true;
}
let received_communication: Result<VoiceHeartbeatThreadCommunication, TryRecvError> =
receive.try_recv();
if received_communication.is_ok() {
let communication = received_communication.unwrap();
// If we received a nonce update, use that nonce now Some(communication) = receive.recv() => {
// If we received a nonce update, use that nonce now
if communication.updated_nonce.is_some() { if communication.updated_nonce.is_some() {
nonce = communication.updated_nonce.unwrap(); nonce = communication.updated_nonce.unwrap();
} }
if communication.op_code.is_some() { if let Some(op_code) = communication.op_code {
match communication.op_code.unwrap() { match op_code {
VOICE_HEARTBEAT => { VOICE_HEARTBEAT => {
// As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately // As per the api docs, if the server sends us a Heartbeat, that means we need to respond with a heartbeat immediately
should_send = true; should_send = true;
}
VOICE_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
} }
VOICE_HEARTBEAT_ACK => {
// The server received our heartbeat
last_heartbeat_acknowledged = true;
}
_ => {}
} }
} }
} }
// If the server hasn't acknowledged our heartbeat we should resend it
if !last_heartbeat_acknowledged
&& last_heartbeat_timestamp.elapsed().as_millis() > HEARTBEAT_ACK_TIMEOUT
{
should_send = true;
println!("VGW: Timed out waiting for a heartbeat ack, resending");
}
if should_send { if should_send {
println!("VGW: Sending Heartbeat.."); trace!("VGW: Sending Heartbeat..");
let heartbeat = VoiceGatewaySendPayload { let heartbeat = VoiceGatewaySendPayload {
op_code: VOICE_HEARTBEAT, op_code: VOICE_HEARTBEAT,
@ -535,7 +533,7 @@ impl VoiceHeartbeatHandler {
let send_result = websocket_tx.lock().await.send(msg).await; let send_result = websocket_tx.lock().await.send(msg).await;
if send_result.is_err() { if send_result.is_err() {
// We couldn't send, the websocket is broken // We couldn't send, the websocket is broken
println!("VGW: Couldnt send heartbeat, websocket seems broken"); warn!("VGW: Couldnt send heartbeat, websocket seems broken");
break; break;
} }