diff --git a/Cargo.lock b/Cargo.lock index e15700b..b30d3d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -375,12 +375,14 @@ dependencies = [ "anyhow", "argon2", "axum", + "axum-extra", "axum-macros", "axum_session", "axum_session_sqlx", "base64 0.22.1", "derive_more", "dotenvy", + "futures", "http 1.1.0", "jsonwebtoken", "leptos", @@ -418,6 +420,7 @@ dependencies = [ "derive_more", "directories", "dotenvy", + "futures-util", "image", "interprocess", "open", @@ -430,6 +433,7 @@ dependencies = [ "thiserror", "time", "tokio", + "tokio-tungstenite", "toml", "tracing", "tracing-subscriber", @@ -469,6 +473,7 @@ checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" dependencies = [ "async-trait", "axum-core", + "base64 0.22.1", "bytes", "futures-util", "http 1.1.0", @@ -488,8 +493,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper 1.0.1", "tokio", + "tokio-tungstenite", "tower 0.5.1", "tower-layer", "tower-service", @@ -517,6 +524,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c3220b188aea709cf1b6c5f9b01c3bd936bb08bd2b5184a12b35ac8131b1f9" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-util", + "headers", + "http 1.1.0", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower 0.5.1", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-macros" version = "0.4.2" @@ -1338,6 +1368,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.7.9" @@ -2319,6 +2355,30 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 1.1.0", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.1.0", +] + [[package]] name = "heck" version = "0.4.1" @@ -5962,6 +6022,22 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tungstenite", + "webpki-roots", +] + [[package]] name = "tokio-util" version = "0.7.12" @@ -6207,6 +6283,26 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5902c5d130972a0000f60860bfbf46f7ca3db5391eddfedd1b8728bd9dc96c0e" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "rustls", + "rustls-pki-types", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typed-builder" version = "0.18.2" @@ -6370,6 +6466,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/Cargo.toml b/Cargo.toml index 31757fd..d11a26a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,22 +29,27 @@ hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"] ssr = [ "dep:argon2", - "dep:dotenvy", - "dep:rand", - "dep:sha256", - "dep:jsonwebtoken", - "dep:tokio", - "dep:time", - "dep:tracing-subscriber", - "dep:leptos_axum", - "dep:lettre", - "dep:tera", - "dep:sqlx", + "dep:axum", + "dep:axum-extra", "dep:axum-macros", "dep:axum_session", "dep:axum_session_sqlx", + "dep:dotenvy", + "dep:futures", + "dep:jsonwebtoken", + "dep:leptos_axum", + "dep:lettre", + + "dep:rand", + "dep:sha256", + "dep:sqlx", + "dep:tokio", + "dep:time", + "dep:tracing-subscriber", + "dep:tera", + "dep:tower", "dep:tower-http", "dep:tower-layer", @@ -60,6 +65,7 @@ anyhow = { version = "1.0.89", optional = false } argon2 = { version = "0.5.3", optional = true } derive_more = { version = "1.0.0", features = ["full"], optional = false } dotenvy = { version = "0.15.7", optional = true } +futures = { version = "0.3.31", optional = true } rand = { version = "0.8.5", optional = true } serde = { version = "1.0.210", features = ["std", "derive"], optional = false } thiserror = { version = "1.0.64", optional = false } @@ -102,7 +108,8 @@ sqlx = { version = "0.8.2", default-features = false, features = [ ], optional = true } # Web -axum = { version = "0.7.7", optional = true } +axum = { version = "0.7.7", optional = true, features = ["ws"] } +axum-extra = { version = "0.9.4", optional = true, features = ["typed-header"] } axum-macros = { version = "0.4.2", optional = true } axum_session = { version = "0.14.0", optional = true } axum_session_sqlx = { version = "0.3.0", optional = true } diff --git a/avam-client/Cargo.toml b/avam-client/Cargo.toml index 5d00c85..fd29340 100644 --- a/avam-client/Cargo.toml +++ b/avam-client/Cargo.toml @@ -12,9 +12,12 @@ ctrlc = "3.4.5" derive_more = { version = "1.0", features = ["full"] } directories = "5.0" dotenvy = "0.15.7" +futures-util = { version = "0.3.31", default-features = false, features = [ + "sink", + "std", +] } image = "0.25" interprocess = { version = "2.2.1", features = ["tokio"] } -tauri-winrt-notification = "0.6.0" open = "5.3.0" rand = "0.8.5" reqwest = { version = "0.12.8", default-features = false, features = [ @@ -24,9 +27,13 @@ reqwest = { version = "0.12.8", default-features = false, features = [ serde = { version = "1", features = ["derive"] } serde_qs = "0.13.0" sha256 = "1.5.0" +tauri-winrt-notification = "0.6.0" thiserror = { version = "1.0" } time = "0.3.36" tokio = { version = "1.40.0", features = ["full"] } +tokio-tungstenite = { version = "0.24.0", features = [ + "rustls-tls-webpki-roots", +] } toml = "0.8" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["time"] } diff --git a/avam-client/src/app.rs b/avam-client/src/app.rs index 2c44376..2adc5ed 100644 --- a/avam-client/src/app.rs +++ b/avam-client/src/app.rs @@ -56,10 +56,7 @@ impl App { return Ok(()); } - oauth::open_browser( - c.code_verifier().unwrap(), - c.code_challenge_method().unwrap(), - )?; + oauth::open_browser(c.clone())?; Ok(()) })?; @@ -70,7 +67,7 @@ impl App { c.set_token(None)?; c.set_open_browser(false)?; - let _ = s.send(Event::Ready { config: c.clone() }); + let _ = s.send(Event::Logout); Ok(()) })?; @@ -108,7 +105,7 @@ impl ApplicationHandler for App { if let Ok(event) = self.receiver.try_recv() { match event { - Event::Ready { .. } => { + Event::Logout | Event::Ready => { self.tray_icon .set_text(self.items.get("login").unwrap(), "Login") .unwrap(); @@ -117,7 +114,7 @@ impl ApplicationHandler for App { .set_enabled(self.items.get("forget").unwrap(), false) .unwrap(); } - Event::TokenReceived { .. } => { + Event::TokenReceived { .. } | Event::Connected => { self.tray_icon .set_text(self.items.get("login").unwrap(), "Open Avam") .unwrap(); @@ -138,9 +135,7 @@ impl ApplicationHandler for App { fn resumed(&mut self, _: &winit::event_loop::ActiveEventLoop) { let _ = self.tray_icon.build(); - let _ = self.sender.send(Event::Ready { - config: self.config.clone(), - }); + let _ = self.sender.send(Event::Ready); if !self.config.toast_shown() { let _ = Toast::new(crate::AVAM_APP_ID) diff --git a/avam-client/src/client.rs b/avam-client/src/client.rs new file mode 100644 index 0000000..0bbc278 --- /dev/null +++ b/avam-client/src/client.rs @@ -0,0 +1,97 @@ +use std::{borrow::Cow, time::Duration}; + +use futures_util::{SinkExt, StreamExt}; +use reqwest::StatusCode; +use tokio::{ + sync::broadcast::{Receiver, Sender}, + time::sleep, +}; +use tokio_tungstenite::{ + connect_async, + tungstenite::{ + self, + protocol::{frame::coding::CloseCode, CloseFrame}, + ClientRequestBuilder, + }, +}; + +use crate::{state_machine::Event, BASE_URL}; + +pub async fn start( + event_sender: Sender, + mut event_receiver: Receiver, +) -> Result<(), anyhow::Error> { + let mut writer = None; + + let uri: tungstenite::http::Uri = format!("{}/ws", BASE_URL.replace("https", "wss")).parse()?; + + loop { + if let Ok(event) = &event_receiver.try_recv() { + match event { + Event::TokenReceived { token } => { + let builder = ClientRequestBuilder::new(uri.clone()) + .with_header("Authorization", format!("Bearer {}", token)); + + let (socket, response) = connect_async(builder).await?; + + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + tracing::error!("{:#?}", response); + continue; + } + + let (write, mut read) = socket.split(); + writer = Some(write); + + tokio::spawn(async move { + let message = match read.next().await { + Some(data) => match data { + Ok(message) => message, + Err(e) => { + tracing::error!("{:?}", e); + return; + } + }, + None => return, + }; + + let data = message.to_text(); + tracing::debug!("{:?}", data); + }); + + tracing::info!("Connected"); + let _ = event_sender.send(Event::Connected); + } + Event::Logout => { + if let Some(mut write) = writer { + write + .send(tungstenite::Message::Close(Some(CloseFrame { + code: CloseCode::Normal, + reason: Cow::from("User Logout"), + }))) + .await?; + writer = None; + tracing::debug!("Disconnected"); + event_sender.send(Event::Disconnected)?; + } + } + Event::Quit => { + tracing::info!("Shutting down Client"); + if let Some(mut write) = writer { + write + .send(tungstenite::Message::Close(Some(CloseFrame { + code: CloseCode::Normal, + reason: Cow::from("Application Shutdown"), + }))) + .await?; + tracing::debug!("Disconnected"); + } + break; + } + _ => {} + } + } + sleep(Duration::from_millis(100)).await; + } + tracing::info!("Client Shutdown"); + Ok(()) +} diff --git a/avam-client/src/config.rs b/avam-client/src/config.rs index 6ca8178..984efe7 100644 --- a/avam-client/src/config.rs +++ b/avam-client/src/config.rs @@ -101,10 +101,6 @@ impl Config { pub fn code_verifier(&self) -> Option { self.code_verifier.read().unwrap().clone() } - - pub fn code_challenge_method(&self) -> Option { - self.code_challenge_method.read().unwrap().clone() - } } impl Config { diff --git a/avam-client/src/main.rs b/avam-client/src/main.rs index e4e4e9c..d869fc0 100644 --- a/avam-client/src/main.rs +++ b/avam-client/src/main.rs @@ -2,6 +2,7 @@ #![allow(clippy::needless_return)] mod app; +mod client; mod config; mod dirs; mod icon; @@ -18,7 +19,8 @@ use oauth::{start_code_listener, start_code_to_token}; use pipe::Pipe; use state_machine::Event; use tokio::task::JoinSet; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use tracing::Level; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; pub static AVAM_APP_ID: &str = "AvamToast-ECEB71694A5E6105"; pub static BASE_URL: &str = "https://avam.avii.nl"; @@ -42,7 +44,8 @@ async fn main() -> Result<(), anyhow::Error> { init_logging()?; - let (event_sender, event_receiver) = tokio::sync::broadcast::channel(1); + // let (socket_sender, socket_receiver) = tokio::sync::broadcast::channel(1); + let (event_sender, event_receiver) = tokio::sync::broadcast::channel(10); let args = Arguments::parse(); if handle_single_instance(&args).await? { @@ -81,7 +84,7 @@ async fn main() -> Result<(), anyhow::Error> { // // Start the code listener let receiver = event_receiver.resubscribe(); - let (pipe_sender, pipe_receiver) = tokio::sync::broadcast::channel(100); + let (pipe_sender, pipe_receiver) = tokio::sync::broadcast::channel(10); futures.spawn(start_code_listener(pipe_sender, receiver)); // Start token listener @@ -90,6 +93,20 @@ async fn main() -> Result<(), anyhow::Error> { let receiver = event_receiver.resubscribe(); futures.spawn(start_code_to_token(c, pipe_receiver, sender, receiver)); + // Start the websocket client + // The socket client will just sit there until TokenReceivedEvent comes in to authenticate with the socket server + // The server needs to not accept any messages until the authentication is verified + let sender = event_sender.clone(); + let receiver = event_receiver.resubscribe(); + futures.spawn(client::start(sender, receiver)); + + // We need 2 way channels (2 channels, both with tx/rx) to send data from the socket to simconnect and back + + // Start the simconnect listener + // The simconnect sends data to the webscoket + // It also receives data from the websocket to do things like set plane id and fuel and such things + // If possible even position + // Start the Tray Icon let c = config.clone(); let sender = event_sender.clone(); @@ -230,7 +247,9 @@ fn init_logging() -> Result<(), anyhow::Error> { #[cfg(not(debug_assertions))] let file = File::options().append(true).open(&log_file)?; - let fmt = tracing_subscriber::fmt::layer(); + let fmt = tracing_subscriber::fmt::layer().with_filter(tracing_subscriber::filter::filter_fn( + |metadata| metadata.level() < &Level::TRACE, + )); #[cfg(not(debug_assertions))] let fmt = fmt.with_ansi(false).with_writer(Arc::new(file)); diff --git a/avam-client/src/oauth.rs b/avam-client/src/oauth.rs index d310f6b..e1ec849 100644 --- a/avam-client/src/oauth.rs +++ b/avam-client/src/oauth.rs @@ -1,4 +1,5 @@ use std::time::Duration; +use thiserror::Error; use tokio::{ sync::broadcast::{Receiver, Sender}, @@ -6,13 +7,30 @@ use tokio::{ }; use crate::{ - config::Config, models::*, pipe::Pipe, state_machine::Event, BASE_URL, CLIENT_ID, REDIRECT_URI, + config::{Config, ConfigError}, + models::*, + pipe::Pipe, + state_machine::Event, + BASE_URL, CLIENT_ID, REDIRECT_URI, }; -pub fn open_browser( - code_verifier: CodeVerifier, - code_challenge_method: CodeChallengeMethod, -) -> Result<(), anyhow::Error> { +#[derive(Debug, Error)] +pub enum OpenBrowserError { + #[error(transparent)] + SerdeQs(#[from] serde_qs::Error), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Config(#[from] ConfigError), +} + +pub fn open_browser(config: Config) -> Result<(), OpenBrowserError> { + let code_verifier = CodeVerifier::new(); + let code_challenge_method = CodeChallengeMethod::Sha256; + + config.set_code_verifier(Some(code_verifier.clone()))?; + config.set_code_challenge_method(Some(code_challenge_method.clone()))?; + let code_challenge = match code_challenge_method { CodeChallengeMethod::Plain => { use base64::prelude::*; @@ -84,8 +102,11 @@ pub async fn start_code_to_token( .await?; let response: AuthorizationCodeResponse = response.json().await?; + let token = response.token(); - event_sender.send(Event::TokenReceived { token: response.token() })?; + config.set_token(Some(token.clone()))?; + + event_sender.send(Event::TokenReceived { token })?; } } } diff --git a/avam-client/src/simconnect.rs b/avam-client/src/simconnect.rs new file mode 100644 index 0000000..4515a03 --- /dev/null +++ b/avam-client/src/simconnect.rs @@ -0,0 +1,29 @@ +pub struct Client { + // whatever we need +} + +impl Client { + pub fn new() -> Self { + Self { + // websocket receiver + // websocket sender + // simconnect client handle + } + } + + pub fn run() -> Result<(), anyhow::Error> { + loop { + tokio::select! { + // we can either get a message from the websocket to pass to simconnect + + // we can get a message from simconnect to pass to the websocket + + // or we get a quit event from the event channel + } + } + } +} + +pub async fn start() -> Result<(), anyhow::Error> { + Client::new().run().await? +} diff --git a/avam-client/src/state_machine.rs b/avam-client/src/state_machine.rs index c9212f8..493769d 100644 --- a/avam-client/src/state_machine.rs +++ b/avam-client/src/state_machine.rs @@ -1,4 +1,9 @@ -use tokio::sync::broadcast::{Receiver, Sender}; +use std::time::Duration; + +use tokio::{ + sync::broadcast::{Receiver, Sender}, + time::sleep, +}; use crate::{ config::Config, @@ -6,117 +11,75 @@ use crate::{ oauth, }; -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum State { Init, - AppStart { - config: Config, - }, - Authenticate { - open_browser: bool, - code_verifier: CodeVerifier, - code_challenge_method: CodeChallengeMethod, - }, - Connect { - token: String, - }, + AppStart, + Authenticate, + Connect { token: String }, WaitForSim, InSim, } #[derive(Debug, Clone, PartialEq)] pub enum Event { - Ready { - config: Config, - }, - StartAuthenticate { - open_browser: bool, - code_verifier: CodeVerifier, - code_challenge_method: CodeChallengeMethod, - }, // should not be string - TokenReceived { - token: String, - }, // AppStart and Authenticate can fire off TokenReceived to transition into Connect + Ready, + StartAuthenticate, // should not be string + TokenReceived { token: String }, // AppStart and Authenticate can fire off TokenReceived to transition into Connect Connected, // Once connected to the socket, and properly authenticated, fire off Connected to transition to WaitForSim Disconnected, // If for whatever reason we're disconnected from the backend, we need to transition back to Connect SimConnected, // SimConnect is connected, we're in the world and ready to send data, transition to Running SimDisconnected, // SimConnect is disconnected, we've finished the flight and exited back to the menu, transition back to WaitForSim + Logout, Quit, } impl State { pub async fn next(self, event: Event) -> State { - match (self, event) { + match (self.clone(), event.clone()) { // (Current State, SomeEvent) => NextState - (_, Event::Ready { config }) => State::AppStart { config }, - ( - State::AppStart { .. }, - Event::StartAuthenticate { - open_browser, - code_verifier, - code_challenge_method, - }, - ) => Self::Authenticate { - open_browser, - code_verifier, - code_challenge_method, - }, // Goto Authenticate + (_, Event::Ready) => State::AppStart, + (_, Event::Logout) => State::AppStart, + (_, Event::StartAuthenticate) => Self::Authenticate, // Goto Authenticate - (State::AppStart { .. }, Event::TokenReceived { token }) => State::Connect { token }, - (State::Authenticate { .. }, Event::TokenReceived { token }) => { - State::Connect { token } - } + (_, Event::TokenReceived { token }) => State::Connect { token }, - (State::Connect { .. }, Event::Connected) => todo!(), // Goto WaitForSim + (_, Event::Connected) => State::WaitForSim, // Goto WaitForSim - (State::WaitForSim, Event::SimConnected) => todo!(), // Goto InSim + (_, Event::SimConnected) => todo!(), // Goto InSim - (_, Event::Disconnected) => todo!(), // Goto Connect - (State::InSim, Event::SimDisconnected) => todo!(), // Goto WaitForSim + (_, Event::Disconnected) => State::AppStart, // Goto Connect + (_, Event::SimDisconnected) => State::WaitForSim, // Goto WaitForSim (_, Event::Quit) => todo!(), // All events can go into quit, to shutdown the application - - _ => panic!("Invalid state transition"), } } - pub async fn run(&self, signal: Sender) -> Result<(), anyhow::Error> { + pub async fn run(&self, signal: Sender, config: Config) -> Result<(), anyhow::Error> { match self { State::Init => Ok(()), - State::AppStart { config } => { + State::AppStart => { if let Some(token) = config.token() { signal.send(Event::TokenReceived { token: token.to_string(), })?; } else { - let open_browser = config.open_browser(); - let code_verifier = CodeVerifier::new(); - let code_challenge_method = CodeChallengeMethod::Sha256; - - config.set_code_verifier(Some(code_verifier.clone()))?; - config.set_code_challenge_method(Some(code_challenge_method.clone()))?; - - signal.send(Event::StartAuthenticate { - open_browser, - code_verifier, - code_challenge_method, - })?; + signal.send(Event::StartAuthenticate)?; } Ok(()) } - State::Authenticate { - open_browser, - code_verifier, - code_challenge_method, - } => { - if *open_browser { - oauth::open_browser(code_verifier.clone(), code_challenge_method.clone())?; + State::Authenticate => { + if config.open_browser() { + oauth::open_browser(config.clone())?; } Ok(()) } State::Connect { .. } => Ok(()), - State::WaitForSim => Ok(()), + State::WaitForSim => { + tracing::info!("Waiting for sim!"); + Ok(()) + } State::InSim => Ok(()), } } @@ -129,25 +92,20 @@ pub async fn start( ) -> Result<(), anyhow::Error> { let mut state = State::Init; - state.run(event_sender.clone()).await?; + state.run(event_sender.clone(), config.clone()).await?; loop { - if let Ok(event) = event_receiver.recv().await { + if let Ok(event) = event_receiver.try_recv() { + state = state.next(event.clone()).await; + + state.run(event_sender.clone(), config.clone()).await?; + if event == Event::Quit { tracing::info!("Shutting down State Machine"); break; } - - state = state.next(event).await; - - // before run - if let State::Connect { token } = &state { - // before run Connect, save the given token in config - config.set_token(Some(token.clone()))?; - } - - state.run(event_sender.clone()).await?; } + sleep(Duration::from_millis(100)).await; } tracing::info!("State Machine Shutdown"); diff --git a/src/bin/server/main.rs b/src/bin/server/main.rs index 3072a2f..651c269 100644 --- a/src/bin/server/main.rs +++ b/src/bin/server/main.rs @@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> { let api_service = api::Service::new(postgres.clone(), dangerous_lettre); - let app_state = AppState::new(api_service).await; + let app_state = AppState::new(api_service, config.clone()).await; HttpServer::new(app_state, postgres.pool()) .await? diff --git a/src/lib/domain/api/ports/api_service.rs b/src/lib/domain/api/ports/api_service.rs index 66cdf1f..9c961f5 100644 --- a/src/lib/domain/api/ports/api_service.rs +++ b/src/lib/domain/api/ports/api_service.rs @@ -1,5 +1,6 @@ use crate::{ - domain::api::models::oauth::*, inbound::http::handlers::oauth::AuthorizationCodeRequest, + domain::api::models::oauth::*, + inbound::http::handlers::oauth::{AuthorizationCodeRequest, VerifyClientAuthorizationRequest}, }; use super::super::models::user::*; @@ -56,5 +57,13 @@ pub trait ApiService: Clone + Send + Sync + 'static { fn create_token( &self, req: AuthorizationCodeRequest, - ) -> impl Future, TokenError>> + Send; + ) -> impl Future> + Send; + + /// --- + /// WS + /// --- + fn verify_client_authorization( + &self, + req: VerifyClientAuthorizationRequest, + ) -> impl Future> + Send; } diff --git a/src/lib/domain/api/service.rs b/src/lib/domain/api/service.rs index 77dadbb..12dc36e 100644 --- a/src/lib/domain/api/service.rs +++ b/src/lib/domain/api/service.rs @@ -6,6 +6,7 @@ use axum_session::SessionAnySession; use crate::inbound::http::handlers::oauth::AuthorizationCodeRequest; use crate::inbound::http::handlers::oauth::GrantType; +use crate::inbound::http::handlers::oauth::VerifyClientAuthorizationRequest; use super::models::oauth::Client; use super::models::oauth::*; @@ -228,7 +229,7 @@ where async fn create_token( &self, req: AuthorizationCodeRequest, - ) -> Result, TokenError> { + ) -> Result { if req.grant_type() != GrantType::AuthorizationCode { return Err(TokenError::InvalidRequest); } @@ -265,6 +266,24 @@ where let _ = self.repo.delete_token(req.code()).await; - Ok(Some(token)) + Ok(token) + } + + async fn verify_client_authorization( + &self, + req: VerifyClientAuthorizationRequest, + ) -> Result { + let user_id = req.user_id(); + let client_id = req.client_id(); + + if !self.repo.is_authorized_client(user_id, client_id).await? { + return Err(anyhow::anyhow!("Unauthorized")); + } + + let Some(user) = self.repo.find_user_by_id(user_id).await? else { + return Err(anyhow::anyhow!("Unauthorized")); + }; + + Ok(user) } } diff --git a/src/lib/inbound/http.rs b/src/lib/inbound/http.rs index 7b172e3..e01fa0e 100644 --- a/src/lib/inbound/http.rs +++ b/src/lib/inbound/http.rs @@ -5,14 +5,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use anyhow::Context; use axum::{ - extract::ConnectInfo, - routing::{get, post}, - Extension, + extract::ConnectInfo, routing::{any, get, post}, Extension }; use axum_session::{SessionAnyPool, SessionConfig, SessionLayer, SessionStore}; use axum_session_sqlx::SessionPgPool; use handlers::{ - fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account + fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account, websocket::ws_handler }; use leptos_axum::{generate_route_list, LeptosRoutes}; use state::AppState; @@ -57,6 +55,7 @@ impl HttpServer { ); let router = axum::Router::new() + .route("/ws", any(ws_handler)) .nest("/oauth2", oauth::routes()) .route("/auth/activate/:token", get(activate_account)) .route("/api/*fn_name", post(server_fn_handler)) @@ -87,3 +86,4 @@ impl HttpServer { Ok(()) } } + diff --git a/src/lib/inbound/http/handlers/mod.rs b/src/lib/inbound/http/handlers/mod.rs index b2a2711..b48a235 100644 --- a/src/lib/inbound/http/handlers/mod.rs +++ b/src/lib/inbound/http/handlers/mod.rs @@ -2,3 +2,4 @@ pub mod fileserv; pub mod leptos; pub mod oauth; pub mod user; +pub mod websocket; diff --git a/src/lib/inbound/http/handlers/oauth.rs b/src/lib/inbound/http/handlers/oauth.rs index 60d3cc4..b26498d 100644 --- a/src/lib/inbound/http/handlers/oauth.rs +++ b/src/lib/inbound/http/handlers/oauth.rs @@ -120,6 +120,26 @@ impl AuthorizationCodeRequest { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VerifyClientAuthorizationRequest { + user_id: uuid::Uuid, + client_id: uuid::Uuid, +} + +impl VerifyClientAuthorizationRequest { + pub fn new(user_id: uuid::Uuid, client_id: uuid::Uuid) -> Self { + Self { client_id, user_id } + } + + pub fn user_id(&self) -> uuid::Uuid { + self.user_id + } + + pub fn client_id(&self) -> uuid::Uuid { + self.client_id + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct TokenClaims where @@ -180,7 +200,11 @@ where let iat = now.unix_timestamp() as usize; let exp = (now + time::Duration::days(30)).unix_timestamp() as usize; - let claims = TokenClaims { sub, iat, exp }; + let claims = TokenClaims { + sub: serde_qs::to_string(&sub)?, + iat, + exp, + }; let token = encode( &Header::default(), diff --git a/src/lib/inbound/http/handlers/websocket.rs b/src/lib/inbound/http/handlers/websocket.rs new file mode 100644 index 0000000..155ad33 --- /dev/null +++ b/src/lib/inbound/http/handlers/websocket.rs @@ -0,0 +1,225 @@ +use std::{borrow::Cow, net::SocketAddr, ops::ControlFlow}; + +use axum::{ + extract::{ + ws::{CloseFrame, Message, WebSocket}, + ConnectInfo, State, WebSocketUpgrade, + }, + response::IntoResponse, +}; +use axum_extra::{ + headers::{self, authorization::Bearer}, + TypedHeader, +}; +use futures::{SinkExt, StreamExt}; +use http::StatusCode; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; + +use crate::{ + domain::api::{ + ports::ApiService, + prelude::{TokenSubject, User}, + }, + inbound::http::{ + handlers::oauth::{TokenClaims, VerifyClientAuthorizationRequest}, + state::AppState, + }, +}; + +pub async fn ws_handler( + State(app_state): State>, + ws: WebSocketUpgrade, + auth_token: Option>>, + ConnectInfo(addr): ConnectInfo, +) -> Result { + let auth_token = match auth_token { + Some(TypedHeader(token)) => Some(token.token().to_string()), + None => return Err(StatusCode::UNAUTHORIZED), + }; + + let Some(auth_token) = auth_token else { + return Err(StatusCode::UNAUTHORIZED); + }; + + let jwt_secret = &app_state.config().jwt_secret; + + let claims = decode::>( + &auth_token, + &DecodingKey::from_secret(jwt_secret.as_ref()), + &Validation::new(Algorithm::HS256), + ) + .map_err(|e| { + tracing::error!("Unable to decode token: {}\n{:?}", auth_token, e); + StatusCode::UNAUTHORIZED + })? + .claims; + + let token_subject: TokenSubject = serde_qs::from_str(&claims.sub).map_err(|e| { + tracing::error!("Unable to parse Token Subject: {}\n{:?}", &claims.sub, e); + StatusCode::BAD_REQUEST + })?; + + let user_id = token_subject.user_id(); + let client_id = token_subject.client_id(); + + let app = app_state.api_service(); + let Ok(user) = app + .verify_client_authorization(VerifyClientAuthorizationRequest::new(user_id, client_id)) + .await + else { + return Err(StatusCode::UNAUTHORIZED); + }; + + Ok(ws.on_upgrade(move |socket| handle_socket(socket, user, addr))) +} + +/// Actual websocket statemachine (one will be spawned per connection) +async fn handle_socket(mut socket: WebSocket, user: User, who: SocketAddr) { + // send a ping (unsupported by some browsers) just to kick things off and get a response + if socket + .send(Message::Text(format!("Hello {}!", user.email()))) + .await + .is_ok() + { + tracing::debug!("Pinged {who}..."); + } else { + tracing::debug!("Could not send ping {who}!"); + // no Error here since the only thing we can do is to close the connection. + // If we can not send messages, there is no way to salvage the statemachine anyway. + return; + } + + // receive single message from a client (we can either receive or send with socket). + // this will likely be the Pong for our Ping or a hello message from client. + // waiting for message from a client will block this task, but will not block other client's + // connections. + if let Some(msg) = socket.recv().await { + if let Ok(msg) = msg { + if process_message(msg, who).is_break() { + return; + } + } else { + tracing::debug!("client {who} abruptly disconnected"); + return; + } + } + + // Since each client gets individual statemachine, we can pause handling + // when necessary to wait for some external event (in this case illustrated by sleeping). + // Waiting for this client to finish getting its greetings does not prevent other clients from + // connecting to server and receiving their greetings. + for i in 1..5 { + if socket + .send(Message::Text(format!("Hi {i} times!"))) + .await + .is_err() + { + tracing::debug!("client {who} abruptly disconnected"); + return; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // By splitting socket we can send and receive at the same time. In this example we will send + // unsolicited messages to client based on some sort of server's internal event (i.e .timer). + let (mut sender, mut receiver) = socket.split(); + + // Spawn a task that will push several messages to the client (does not matter what client does) + let mut send_task = tokio::spawn(async move { + let n_msg = 20; + for i in 0..n_msg { + // In case of any websocket error, we exit. + if sender + .send(Message::Text(format!("Server message {i} ..."))) + .await + .is_err() + { + return i; + } + + tokio::time::sleep(std::time::Duration::from_millis(300)).await; + } + + tracing::debug!("Sending close to {who}..."); + if let Err(e) = sender + .send(Message::Close(Some(CloseFrame { + code: axum::extract::ws::close_code::NORMAL, + reason: Cow::from("Goodbye"), + }))) + .await + { + tracing::debug!("Could not send Close due to {e}, probably it is ok?"); + } + n_msg + }); + + // This second task will receive messages from client and print them on server console + let mut recv_task = tokio::spawn(async move { + let mut cnt = 0; + while let Some(Ok(msg)) = receiver.next().await { + cnt += 1; + // print message and break if instructed to do so + if process_message(msg, who).is_break() { + break; + } + } + cnt + }); + + // If any one of the tasks exit, abort the other. + tokio::select! { + rv_a = (&mut send_task) => { + match rv_a { + Ok(a) => tracing::debug!("{a} messages sent to {who}"), + Err(a) => tracing::debug!("Error sending messages {a:?}") + } + recv_task.abort(); + }, + rv_b = (&mut recv_task) => { + match rv_b { + Ok(b) => tracing::debug!("Received {b} messages"), + Err(b) => tracing::debug!("Error receiving messages {b:?}") + } + send_task.abort(); + } + } + + // returning from the handler closes the websocket connection + tracing::debug!("Websocket context {who} destroyed"); +} + +/// helper to print contents of messages to stdout. Has special treatment for Close. +fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { + match msg { + Message::Text(t) => { + tracing::debug!(">>> {who} sent str: {t:?}"); + } + Message::Binary(d) => { + tracing::debug!(">>> {} sent {} bytes: {:?}", who, d.len(), d); + } + Message::Close(c) => { + if let Some(cf) = c { + tracing::debug!( + ">>> {} sent close with code {} and reason `{}`", + who, + cf.code, + cf.reason + ); + } else { + tracing::debug!(">>> {who} somehow sent close message without CloseFrame"); + } + return ControlFlow::Break(()); + } + + Message::Pong(v) => { + tracing::debug!(">>> {who} sent pong with {v:?}"); + } + // You should never need to manually handle Message::Ping, as axum's websocket library + // will do so for you automagically by replying with Pong and copying the v according to + // spec. But if you need the contents of the pings you can see them here. + Message::Ping(v) => { + tracing::debug!(">>> {who} sent ping with {v:?}"); + } + } + ControlFlow::Continue(()) +} diff --git a/src/lib/inbound/http/state.rs b/src/lib/inbound/http/state.rs index 0c830f5..6eeee95 100644 --- a/src/lib/inbound/http/state.rs +++ b/src/lib/inbound/http/state.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use axum::extract::FromRef; use leptos::get_configuration; -use crate::domain::api::ports::ApiService; +use crate::{config::Config, domain::api::ports::ApiService}; #[derive(Debug, Clone)] /// The global application state shared between all request handlers. @@ -12,6 +12,7 @@ where S: ApiService, { pub leptos_options: leptos::LeptosOptions, + config: Arc, api_service: Arc, } @@ -19,13 +20,18 @@ impl AppState where S: ApiService, { - pub async fn new(api_service: S) -> Self { + pub async fn new(api_service: S, config: Config) -> Self { Self { + config: Arc::new(config), leptos_options: get_configuration(None).await.unwrap().leptos_options, api_service: Arc::new(api_service), } } + pub fn config(&self) -> Arc { + self.config.clone() + } + pub fn api_service(&self) -> Arc { self.api_service.clone() } diff --git a/style/main.scss b/style/main.scss index 45912ee..77b4412 100644 --- a/style/main.scss +++ b/style/main.scss @@ -1444,10 +1444,18 @@ html { margin-top: auto; } +.block { + display: block; +} + .flex { display: flex; } +.contents { + display: contents; +} + .hidden { display: none; }