From 98c842fd2504641d120972ad635fe04aada0e47c Mon Sep 17 00:00:00 2001 From: bitfl0wer Date: Mon, 26 Aug 2024 12:53:32 +0200 Subject: [PATCH] Match scheme for "ws" or "wss" and choose whether to connect with TLS connector for tungstenite --- src/gateway/backends/tungstenite.rs | 88 ++++++++++++++++++----------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/src/gateway/backends/tungstenite.rs b/src/gateway/backends/tungstenite.rs index 34dc825..4464f8d 100644 --- a/src/gateway/backends/tungstenite.rs +++ b/src/gateway/backends/tungstenite.rs @@ -9,8 +9,10 @@ use futures_util::{ }; use tokio::net::TcpStream; use tokio_tungstenite::{ - connect_async_tls_with_config, tungstenite, Connector, MaybeTlsStream, WebSocketStream, + connect_async_tls_with_config, connect_async_with_config, tungstenite, Connector, + MaybeTlsStream, WebSocketStream, }; +use url::Url; use crate::gateway::{GatewayMessage, RawGatewayMessage}; @@ -32,38 +34,60 @@ impl TungsteniteBackend { pub async fn connect( websocket_url: &str, ) -> Result<(TungsteniteSink, TungsteniteStream), TungsteniteBackendError> { - let certs = webpki_roots::TLS_SERVER_ROOTS; - let roots = rustls::RootCertStore { - roots: certs - .iter() - .map(|cert| { - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - cert.subject.to_vec(), - cert.subject_public_key_info.to_vec(), - cert.name_constraints.as_ref().map(|der| der.to_vec()), - ) - }) - .collect(), - }; - let (websocket_stream, _) = match connect_async_tls_with_config( - websocket_url, - None, - false, - Some(Connector::Rustls( - rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(roots) - .with_no_client_auth() - .into(), - )), - ) - .await - { - Ok(websocket_stream) => websocket_stream, - Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }), - }; + let websocket_url_parsed = + Url::parse(websocket_url).map_err(|_| TungsteniteBackendError::TungsteniteError { + error: tungstenite::error::Error::Url( + tungstenite::error::UrlError::UnsupportedUrlScheme, + ), + })?; + if websocket_url_parsed.scheme() == "ws" { + let (websocket_stream, _) = + match connect_async_with_config(websocket_url, None, false).await { + Ok(websocket_stream) => websocket_stream, + Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }), + }; - Ok(websocket_stream.split()) + Ok(websocket_stream.split()) + } else if websocket_url_parsed.scheme() == "wss" { + let certs = webpki_roots::TLS_SERVER_ROOTS; + let roots = rustls::RootCertStore { + roots: certs + .iter() + .map(|cert| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + cert.subject.to_vec(), + cert.subject_public_key_info.to_vec(), + cert.name_constraints.as_ref().map(|der| der.to_vec()), + ) + }) + .collect(), + }; + let (websocket_stream, _) = match connect_async_tls_with_config( + websocket_url, + None, + false, + Some(Connector::Rustls( + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth() + .into(), + )), + ) + .await + { + Ok(websocket_stream) => websocket_stream, + Err(e) => return Err(TungsteniteBackendError::TungsteniteError { error: e }), + }; + + Ok(websocket_stream.split()) + } else { + Err(TungsteniteBackendError::TungsteniteError { + error: tungstenite::error::Error::Url( + tungstenite::error::UrlError::UnsupportedUrlScheme, + ), + }) + } } }