Reuse gateway backends, don't duplicate them for voice gateway (#493)

This commit is contained in:
kozabrada123 2024-04-30 16:17:47 +02:00 committed by GitHub
parent b2fd3e18cc
commit aac31726ec
10 changed files with 50 additions and 142 deletions

View File

@ -93,7 +93,7 @@ custom_error! {
DisallowedIntents = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for", DisallowedIntents = "You sent a disallowed intent. You may have tried to specify an intent that you have not enabled or are not approved for",
// Errors when initiating a gateway connection // Errors when initiating a gateway connection
CannotConnect{error: String} = "Cannot connect due to a tungstenite error: {error}", CannotConnect{error: String} = "Cannot connect due to a websocket error: {error}",
NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
// Other misc errors // Other misc errors
@ -124,7 +124,7 @@ custom_error! {
UnknownEncryptionMode = "Server failed to decrypt data", UnknownEncryptionMode = "Server failed to decrypt data",
// Errors when initiating a gateway connection // Errors when initiating a gateway connection
CannotConnect{error: String} = "Cannot connect due to a tungstenite error: {error}", CannotConnect{error: String} = "Cannot connect due to a websocket error: {error}",
NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong", NonHelloOnInitiate{opcode: u8} = "Received non hello on initial gateway connection ({opcode}), something is definitely wrong",
// Other misc errors // Other misc errors

View File

@ -2,6 +2,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use custom_error::custom_error;
use futures_util::{ use futures_util::{
stream::{SplitSink, SplitStream}, stream::{SplitSink, SplitStream},
StreamExt, StreamExt,
@ -11,7 +12,6 @@ use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
}; };
use crate::errors::GatewayError;
use crate::gateway::GatewayMessage; use crate::gateway::GatewayMessage;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -22,18 +22,22 @@ pub type TungsteniteSink =
SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>; SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>;
pub type TungsteniteStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>; pub type TungsteniteStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
custom_error! {
pub TungsteniteBackendError
FailedToLoadCerts{error: std::io::Error} = "failed to load platform native certs: {error}",
TungsteniteError{error: tungstenite::error::Error} = "encountered a tungstenite error: {error}",
}
impl TungsteniteBackend { impl TungsteniteBackend {
pub async fn connect( pub async fn connect(
websocket_url: &str, websocket_url: &str,
) -> Result<(TungsteniteSink, TungsteniteStream), crate::errors::GatewayError> { ) -> Result<(TungsteniteSink, TungsteniteStream), TungsteniteBackendError> {
let mut roots = rustls::RootCertStore::empty(); let mut roots = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs(); let certs = rustls_native_certs::load_native_certs();
if let Err(e) = certs { if let Err(e) = certs {
log::error!("Failed to load platform native certs! {:?}", e); log::error!("Failed to load platform native certs! {:?}", e);
return Err(GatewayError::CannotConnect { return Err(TungsteniteBackendError::FailedToLoadCerts { error: e });
error: format!("{:?}", e),
});
} }
for cert in certs.unwrap() { for cert in certs.unwrap() {
@ -55,8 +59,8 @@ impl TungsteniteBackend {
{ {
Ok(websocket_stream) => websocket_stream, Ok(websocket_stream) => websocket_stream,
Err(e) => { Err(e) => {
return Err(GatewayError::CannotConnect { return Err(TungsteniteBackendError::TungsteniteError {
error: e.to_string(), error: e,
}) })
} }
}; };

View File

@ -9,7 +9,6 @@ use futures_util::{
use ws_stream_wasm::*; use ws_stream_wasm::*;
use crate::errors::GatewayError;
use crate::gateway::GatewayMessage; use crate::gateway::GatewayMessage;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -22,13 +21,8 @@ pub type WasmStream = SplitStream<WsStream>;
impl WasmBackend { impl WasmBackend {
pub async fn connect( pub async fn connect(
websocket_url: &str, websocket_url: &str,
) -> Result<(WasmSink, WasmStream), crate::errors::GatewayError> { ) -> Result<(WasmSink, WasmStream), ws_stream_wasm::WsErr> {
let (_, websocket_stream) = match WsMeta::connect(websocket_url, None).await { let (_, websocket_stream) = WsMeta::connect(websocket_url, None).await?;
Ok(stream) => Ok(stream),
Err(e) => Err(GatewayError::CannotConnect {
error: e.to_string(),
}),
}?;
Ok(websocket_stream.split()) Ok(websocket_stream.split())
} }

View File

@ -35,7 +35,14 @@ impl Gateway {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
pub async fn spawn(websocket_url: String) -> Result<GatewayHandle, GatewayError> { pub async fn spawn(websocket_url: String) -> Result<GatewayHandle, GatewayError> {
let (websocket_send, mut websocket_receive) = let (websocket_send, mut websocket_receive) =
WebSocketBackend::connect(&websocket_url).await?; match WebSocketBackend::connect(&websocket_url).await {
Ok(streams) => streams,
Err(e) => {
return Err(GatewayError::CannotConnect {
error: format!("{:?}", e),
})
}
};
let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); let shared_websocket_send = Arc::new(Mutex::new(websocket_send));

View File

@ -4,24 +4,7 @@
#[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))] #[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))]
pub mod tungstenite; pub mod tungstenite;
#[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))]
pub use tungstenite::*;
#[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))] #[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))]
pub mod wasm; pub mod wasm;
#[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))]
pub use wasm::*;
#[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))]
pub type Sink = tungstenite::TungsteniteSink;
#[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))]
pub type Stream = tungstenite::TungsteniteStream;
#[cfg(all(not(target_arch = "wasm32"), feature = "voice_gateway"))]
pub type WebSocketBackend = tungstenite::TungsteniteBackend;
#[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))]
pub type Sink = wasm::WasmSink;
#[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))]
pub type Stream = wasm::WasmStream;
#[cfg(all(target_arch = "wasm32", feature = "voice_gateway"))]
pub type WebSocketBackend = wasm::WasmBackend;

View File

@ -2,76 +2,16 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use futures_util::{ use crate::voice::gateway::VoiceGatewayMessage;
stream::{SplitSink, SplitStream},
StreamExt,
};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream,
};
use crate::{errors::VoiceGatewayError, voice::gateway::VoiceGatewayMessage}; impl From<VoiceGatewayMessage> for tokio_tungstenite::tungstenite::Message {
#[derive(Debug, Clone)]
pub struct TungsteniteBackend;
// These could be made into inherent associated types when that's stabilized
pub type TungsteniteSink =
SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, tungstenite::Message>;
pub type TungsteniteStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
impl TungsteniteBackend {
pub async fn connect(
websocket_url: &str,
) -> Result<(TungsteniteSink, TungsteniteStream), crate::errors::VoiceGatewayError> {
let mut roots = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs();
if let Err(e) = certs {
log::error!("Failed to load platform native certs! {:?}", e);
return Err(VoiceGatewayError::CannotConnect {
error: format!("{:?}", e),
});
}
for cert in certs.unwrap() {
roots.add(&rustls::Certificate(cert.0)).unwrap();
}
let (websocket_stream, _) = match connect_async_tls_with_config(
websocket_url,
None,
false,
Some(Connector::Rustls(
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth()
.into(),
)),
)
.await
{
Ok(websocket_stream) => websocket_stream,
Err(e) => {
return Err(VoiceGatewayError::CannotConnect {
error: e.to_string(),
})
}
};
Ok(websocket_stream.split())
}
}
impl From<VoiceGatewayMessage> for tungstenite::Message {
fn from(message: VoiceGatewayMessage) -> Self { fn from(message: VoiceGatewayMessage) -> Self {
Self::Text(message.0) Self::Text(message.0)
} }
} }
impl From<tungstenite::Message> for VoiceGatewayMessage { impl From<tokio_tungstenite::tungstenite::Message> for VoiceGatewayMessage {
fn from(value: tungstenite::Message) -> Self { fn from(value: tokio_tungstenite::tungstenite::Message) -> Self {
Self(value.to_string()) Self(value.to_string())
} }
} }

View File

@ -2,36 +2,9 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this // License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
use futures_util::{ use ws_stream_wasm::WsMessage;
stream::{SplitSink, SplitStream},
StreamExt,
};
use ws_stream_wasm::*;
use crate::errors::VoiceGatewayError;
use crate::voice::gateway::VoiceGatewayMessage; use crate::voice::gateway::VoiceGatewayMessage;
#[derive(Debug, Clone)]
pub struct WasmBackend;
// These could be made into inherent associated types when that's stabilized
pub type WasmSink = SplitSink<WsStream, WsMessage>;
pub type WasmStream = SplitStream<WsStream>;
impl WasmBackend {
pub async fn connect(websocket_url: &str) -> Result<(WasmSink, WasmStream), VoiceGatewayError> {
let (_, websocket_stream) = match WsMeta::connect(websocket_url, None).await {
Ok(stream) => Ok(stream),
Err(e) => Err(VoiceGatewayError::CannotConnect {
error: e.to_string(),
}),
}?;
Ok(websocket_stream.split())
}
}
impl From<VoiceGatewayMessage> for WsMessage { impl From<VoiceGatewayMessage> for WsMessage {
fn from(message: VoiceGatewayMessage) -> Self { fn from(message: VoiceGatewayMessage) -> Self {
Self::Text(message.0) Self::Text(message.0)

View File

@ -11,6 +11,9 @@ use tokio::sync::Mutex;
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt; use futures_util::StreamExt;
use crate::gateway::Sink;
use crate::gateway::Stream;
use crate::gateway::WebSocketBackend;
use crate::{ use crate::{
errors::VoiceGatewayError, errors::VoiceGatewayError,
gateway::GatewayEvent, gateway::GatewayEvent,
@ -21,14 +24,10 @@ use crate::{
VOICE_READY, VOICE_RESUME, VOICE_SELECT_PROTOCOL, VOICE_SESSION_DESCRIPTION, VOICE_READY, VOICE_RESUME, VOICE_SELECT_PROTOCOL, VOICE_SESSION_DESCRIPTION,
VOICE_SESSION_UPDATE, VOICE_SPEAKING, VOICE_SSRC_DEFINITION, VOICE_SESSION_UPDATE, VOICE_SPEAKING, VOICE_SSRC_DEFINITION,
}, },
voice::gateway::{ voice::gateway::{heartbeat::VoiceHeartbeatThreadCommunication, VoiceGatewayMessage},
heartbeat::VoiceHeartbeatThreadCommunication, VoiceGatewayMessage, WebSocketBackend,
},
}; };
use super::{ use super::{events::VoiceEvents, heartbeat::VoiceHeartbeatHandler, VoiceGatewayHandle};
events::VoiceEvents, heartbeat::VoiceHeartbeatHandler, Sink, Stream, VoiceGatewayHandle,
};
#[derive(Debug)] #[derive(Debug)]
pub struct VoiceGateway { pub struct VoiceGateway {
@ -48,7 +47,14 @@ impl VoiceGateway {
trace!("Created voice socket url: {}", processed_url.clone()); trace!("Created voice socket url: {}", processed_url.clone());
let (websocket_send, mut websocket_receive) = let (websocket_send, mut websocket_receive) =
WebSocketBackend::connect(&processed_url).await?; match WebSocketBackend::connect(&processed_url).await {
Ok(streams) => streams,
Err(e) => {
return Err(VoiceGatewayError::CannotConnect {
error: format!("{:?}", e),
})
}
};
let shared_websocket_send = Arc::new(Mutex::new(websocket_send)); let shared_websocket_send = Arc::new(Mutex::new(websocket_send));

View File

@ -11,13 +11,16 @@ use futures_util::SinkExt;
use serde_json::json; use serde_json::json;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::types::{ use crate::{
gateway::Sink,
types::{
SelectProtocol, Speaking, SsrcDefinition, VoiceGatewaySendPayload, VoiceIdentify, SelectProtocol, Speaking, SsrcDefinition, VoiceGatewaySendPayload, VoiceIdentify,
VOICE_BACKEND_VERSION, VOICE_IDENTIFY, VOICE_SELECT_PROTOCOL, VOICE_SPEAKING, VOICE_BACKEND_VERSION, VOICE_IDENTIFY, VOICE_SELECT_PROTOCOL, VOICE_SPEAKING,
VOICE_SSRC_DEFINITION, VOICE_SSRC_DEFINITION,
},
}; };
use super::{events::VoiceEvents, Sink, VoiceGatewayMessage}; use super::{events::VoiceEvents, VoiceGatewayMessage};
/// Represents a handle to a Voice Gateway connection. /// Represents a handle to a Voice Gateway connection.
/// Using this handle you can send Gateway Events directly. /// Using this handle you can send Gateway Events directly.

View File

@ -26,13 +26,11 @@ use tokio::sync::{
use tokio::task; use tokio::task;
use crate::{ use crate::{
gateway::heartbeat::HEARTBEAT_ACK_TIMEOUT, gateway::{heartbeat::HEARTBEAT_ACK_TIMEOUT, Sink},
types::{VoiceGatewaySendPayload, VOICE_HEARTBEAT, VOICE_HEARTBEAT_ACK}, types::{VoiceGatewaySendPayload, VOICE_HEARTBEAT, VOICE_HEARTBEAT_ACK},
voice::gateway::VoiceGatewayMessage, voice::gateway::VoiceGatewayMessage,
}; };
use super::Sink;
/// Handles sending heartbeats to the voice gateway in another thread /// Handles sending heartbeats to the voice gateway in another thread
#[allow(dead_code)] // FIXME: Remove this, once all fields of VoiceHeartbeatHandler are used #[allow(dead_code)] // FIXME: Remove this, once all fields of VoiceHeartbeatHandler are used
#[derive(Debug)] #[derive(Debug)]