websocket

This commit is contained in:
2024-10-18 18:20:44 +02:00
parent ea88c755b5
commit 5e651b382d
19 changed files with 654 additions and 131 deletions

102
Cargo.lock generated
View File

@@ -375,12 +375,14 @@ dependencies = [
"anyhow",
"argon2",
"axum",
"axum-extra",
"axum-macros",
"axum_session",
"axum_session_sqlx",
"base64 0.22.1",
"derive_more",
"dotenvy",
"futures",
"http 1.1.0",
"jsonwebtoken",
"leptos",
@@ -418,6 +420,7 @@ dependencies = [
"derive_more",
"directories",
"dotenvy",
"futures-util",
"image",
"interprocess",
"open",
@@ -430,6 +433,7 @@ dependencies = [
"thiserror",
"time",
"tokio",
"tokio-tungstenite",
"toml",
"tracing",
"tracing-subscriber",
@@ -469,6 +473,7 @@ checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [
"async-trait",
"axum-core",
"base64 0.22.1",
"bytes",
"futures-util",
"http 1.1.0",
@@ -488,8 +493,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper 1.0.1",
"tokio",
"tokio-tungstenite",
"tower 0.5.1",
"tower-layer",
"tower-service",
@@ -517,6 +524,29 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-extra"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73c3220b188aea709cf1b6c5f9b01c3bd936bb08bd2b5184a12b35ac8131b1f9"
dependencies = [
"axum",
"axum-core",
"bytes",
"futures-util",
"headers",
"http 1.1.0",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"serde",
"tower 0.5.1",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-macros"
version = "0.4.2"
@@ -1338,6 +1368,12 @@ dependencies = [
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]]
name = "der"
version = "0.7.9"
@@ -2319,6 +2355,30 @@ dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "headers"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
dependencies = [
"base64 0.21.7",
"bytes",
"headers-core",
"http 1.1.0",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
@@ -5962,6 +6022,22 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
dependencies = [
"futures-util",
"log",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tungstenite",
"webpki-roots",
]
[[package]]
name = "tokio-util"
version = "0.7.12"
@@ -6207,6 +6283,26 @@ version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5902c5d130972a0000f60860bfbf46f7ca3db5391eddfedd1b8728bd9dc96c0e"
[[package]]
name = "tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.1.0",
"httparse",
"log",
"rand",
"rustls",
"rustls-pki-types",
"sha1",
"thiserror",
"utf-8",
]
[[package]]
name = "typed-builder"
version = "0.18.2"
@@ -6370,6 +6466,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf16_iter"
version = "1.0.5"

View File

@@ -29,22 +29,27 @@ hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
ssr = [
"dep:argon2",
"dep:dotenvy",
"dep:rand",
"dep:sha256",
"dep:jsonwebtoken",
"dep:tokio",
"dep:time",
"dep:tracing-subscriber",
"dep:leptos_axum",
"dep:lettre",
"dep:tera",
"dep:sqlx",
"dep:axum",
"dep:axum-extra",
"dep:axum-macros",
"dep:axum_session",
"dep:axum_session_sqlx",
"dep:dotenvy",
"dep:futures",
"dep:jsonwebtoken",
"dep:leptos_axum",
"dep:lettre",
"dep:rand",
"dep:sha256",
"dep:sqlx",
"dep:tokio",
"dep:time",
"dep:tracing-subscriber",
"dep:tera",
"dep:tower",
"dep:tower-http",
"dep:tower-layer",
@@ -60,6 +65,7 @@ anyhow = { version = "1.0.89", optional = false }
argon2 = { version = "0.5.3", optional = true }
derive_more = { version = "1.0.0", features = ["full"], optional = false }
dotenvy = { version = "0.15.7", optional = true }
futures = { version = "0.3.31", optional = true }
rand = { version = "0.8.5", optional = true }
serde = { version = "1.0.210", features = ["std", "derive"], optional = false }
thiserror = { version = "1.0.64", optional = false }
@@ -102,7 +108,8 @@ sqlx = { version = "0.8.2", default-features = false, features = [
], optional = true }
# Web
axum = { version = "0.7.7", optional = true }
axum = { version = "0.7.7", optional = true, features = ["ws"] }
axum-extra = { version = "0.9.4", optional = true, features = ["typed-header"] }
axum-macros = { version = "0.4.2", optional = true }
axum_session = { version = "0.14.0", optional = true }
axum_session_sqlx = { version = "0.3.0", optional = true }

View File

@@ -12,9 +12,12 @@ ctrlc = "3.4.5"
derive_more = { version = "1.0", features = ["full"] }
directories = "5.0"
dotenvy = "0.15.7"
futures-util = { version = "0.3.31", default-features = false, features = [
"sink",
"std",
] }
image = "0.25"
interprocess = { version = "2.2.1", features = ["tokio"] }
tauri-winrt-notification = "0.6.0"
open = "5.3.0"
rand = "0.8.5"
reqwest = { version = "0.12.8", default-features = false, features = [
@@ -24,9 +27,13 @@ reqwest = { version = "0.12.8", default-features = false, features = [
serde = { version = "1", features = ["derive"] }
serde_qs = "0.13.0"
sha256 = "1.5.0"
tauri-winrt-notification = "0.6.0"
thiserror = { version = "1.0" }
time = "0.3.36"
tokio = { version = "1.40.0", features = ["full"] }
tokio-tungstenite = { version = "0.24.0", features = [
"rustls-tls-webpki-roots",
] }
toml = "0.8"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["time"] }

View File

@@ -56,10 +56,7 @@ impl App {
return Ok(());
}
oauth::open_browser(
c.code_verifier().unwrap(),
c.code_challenge_method().unwrap(),
)?;
oauth::open_browser(c.clone())?;
Ok(())
})?;
@@ -70,7 +67,7 @@ impl App {
c.set_token(None)?;
c.set_open_browser(false)?;
let _ = s.send(Event::Ready { config: c.clone() });
let _ = s.send(Event::Logout);
Ok(())
})?;
@@ -108,7 +105,7 @@ impl ApplicationHandler for App {
if let Ok(event) = self.receiver.try_recv() {
match event {
Event::Ready { .. } => {
Event::Logout | Event::Ready => {
self.tray_icon
.set_text(self.items.get("login").unwrap(), "Login")
.unwrap();
@@ -117,7 +114,7 @@ impl ApplicationHandler for App {
.set_enabled(self.items.get("forget").unwrap(), false)
.unwrap();
}
Event::TokenReceived { .. } => {
Event::TokenReceived { .. } | Event::Connected => {
self.tray_icon
.set_text(self.items.get("login").unwrap(), "Open Avam")
.unwrap();
@@ -138,9 +135,7 @@ impl ApplicationHandler for App {
fn resumed(&mut self, _: &winit::event_loop::ActiveEventLoop) {
let _ = self.tray_icon.build();
let _ = self.sender.send(Event::Ready {
config: self.config.clone(),
});
let _ = self.sender.send(Event::Ready);
if !self.config.toast_shown() {
let _ = Toast::new(crate::AVAM_APP_ID)

97
avam-client/src/client.rs Normal file
View File

@@ -0,0 +1,97 @@
use std::{borrow::Cow, time::Duration};
use futures_util::{SinkExt, StreamExt};
use reqwest::StatusCode;
use tokio::{
sync::broadcast::{Receiver, Sender},
time::sleep,
};
use tokio_tungstenite::{
connect_async,
tungstenite::{
self,
protocol::{frame::coding::CloseCode, CloseFrame},
ClientRequestBuilder,
},
};
use crate::{state_machine::Event, BASE_URL};
pub async fn start(
event_sender: Sender<Event>,
mut event_receiver: Receiver<Event>,
) -> Result<(), anyhow::Error> {
let mut writer = None;
let uri: tungstenite::http::Uri = format!("{}/ws", BASE_URL.replace("https", "wss")).parse()?;
loop {
if let Ok(event) = &event_receiver.try_recv() {
match event {
Event::TokenReceived { token } => {
let builder = ClientRequestBuilder::new(uri.clone())
.with_header("Authorization", format!("Bearer {}", token));
let (socket, response) = connect_async(builder).await?;
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
tracing::error!("{:#?}", response);
continue;
}
let (write, mut read) = socket.split();
writer = Some(write);
tokio::spawn(async move {
let message = match read.next().await {
Some(data) => match data {
Ok(message) => message,
Err(e) => {
tracing::error!("{:?}", e);
return;
}
},
None => return,
};
let data = message.to_text();
tracing::debug!("{:?}", data);
});
tracing::info!("Connected");
let _ = event_sender.send(Event::Connected);
}
Event::Logout => {
if let Some(mut write) = writer {
write
.send(tungstenite::Message::Close(Some(CloseFrame {
code: CloseCode::Normal,
reason: Cow::from("User Logout"),
})))
.await?;
writer = None;
tracing::debug!("Disconnected");
event_sender.send(Event::Disconnected)?;
}
}
Event::Quit => {
tracing::info!("Shutting down Client");
if let Some(mut write) = writer {
write
.send(tungstenite::Message::Close(Some(CloseFrame {
code: CloseCode::Normal,
reason: Cow::from("Application Shutdown"),
})))
.await?;
tracing::debug!("Disconnected");
}
break;
}
_ => {}
}
}
sleep(Duration::from_millis(100)).await;
}
tracing::info!("Client Shutdown");
Ok(())
}

View File

@@ -101,10 +101,6 @@ impl Config {
pub fn code_verifier(&self) -> Option<CodeVerifier> {
self.code_verifier.read().unwrap().clone()
}
pub fn code_challenge_method(&self) -> Option<CodeChallengeMethod> {
self.code_challenge_method.read().unwrap().clone()
}
}
impl Config {

View File

@@ -2,6 +2,7 @@
#![allow(clippy::needless_return)]
mod app;
mod client;
mod config;
mod dirs;
mod icon;
@@ -18,7 +19,8 @@ use oauth::{start_code_listener, start_code_to_token};
use pipe::Pipe;
use state_machine::Event;
use tokio::task::JoinSet;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use tracing::Level;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
pub static AVAM_APP_ID: &str = "AvamToast-ECEB71694A5E6105";
pub static BASE_URL: &str = "https://avam.avii.nl";
@@ -42,7 +44,8 @@ async fn main() -> Result<(), anyhow::Error> {
init_logging()?;
let (event_sender, event_receiver) = tokio::sync::broadcast::channel(1);
// let (socket_sender, socket_receiver) = tokio::sync::broadcast::channel(1);
let (event_sender, event_receiver) = tokio::sync::broadcast::channel(10);
let args = Arguments::parse();
if handle_single_instance(&args).await? {
@@ -81,7 +84,7 @@ async fn main() -> Result<(), anyhow::Error> {
// // Start the code listener
let receiver = event_receiver.resubscribe();
let (pipe_sender, pipe_receiver) = tokio::sync::broadcast::channel(100);
let (pipe_sender, pipe_receiver) = tokio::sync::broadcast::channel(10);
futures.spawn(start_code_listener(pipe_sender, receiver));
// Start token listener
@@ -90,6 +93,20 @@ async fn main() -> Result<(), anyhow::Error> {
let receiver = event_receiver.resubscribe();
futures.spawn(start_code_to_token(c, pipe_receiver, sender, receiver));
// Start the websocket client
// The socket client will just sit there until TokenReceivedEvent comes in to authenticate with the socket server
// The server needs to not accept any messages until the authentication is verified
let sender = event_sender.clone();
let receiver = event_receiver.resubscribe();
futures.spawn(client::start(sender, receiver));
// We need 2 way channels (2 channels, both with tx/rx) to send data from the socket to simconnect and back
// Start the simconnect listener
// The simconnect sends data to the webscoket
// It also receives data from the websocket to do things like set plane id and fuel and such things
// If possible even position
// Start the Tray Icon
let c = config.clone();
let sender = event_sender.clone();
@@ -230,7 +247,9 @@ fn init_logging() -> Result<(), anyhow::Error> {
#[cfg(not(debug_assertions))]
let file = File::options().append(true).open(&log_file)?;
let fmt = tracing_subscriber::fmt::layer();
let fmt = tracing_subscriber::fmt::layer().with_filter(tracing_subscriber::filter::filter_fn(
|metadata| metadata.level() < &Level::TRACE,
));
#[cfg(not(debug_assertions))]
let fmt = fmt.with_ansi(false).with_writer(Arc::new(file));

View File

@@ -1,4 +1,5 @@
use std::time::Duration;
use thiserror::Error;
use tokio::{
sync::broadcast::{Receiver, Sender},
@@ -6,13 +7,30 @@ use tokio::{
};
use crate::{
config::Config, models::*, pipe::Pipe, state_machine::Event, BASE_URL, CLIENT_ID, REDIRECT_URI,
config::{Config, ConfigError},
models::*,
pipe::Pipe,
state_machine::Event,
BASE_URL, CLIENT_ID, REDIRECT_URI,
};
pub fn open_browser(
code_verifier: CodeVerifier,
code_challenge_method: CodeChallengeMethod,
) -> Result<(), anyhow::Error> {
#[derive(Debug, Error)]
pub enum OpenBrowserError {
#[error(transparent)]
SerdeQs(#[from] serde_qs::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
Config(#[from] ConfigError),
}
pub fn open_browser(config: Config) -> Result<(), OpenBrowserError> {
let code_verifier = CodeVerifier::new();
let code_challenge_method = CodeChallengeMethod::Sha256;
config.set_code_verifier(Some(code_verifier.clone()))?;
config.set_code_challenge_method(Some(code_challenge_method.clone()))?;
let code_challenge = match code_challenge_method {
CodeChallengeMethod::Plain => {
use base64::prelude::*;
@@ -84,8 +102,11 @@ pub async fn start_code_to_token(
.await?;
let response: AuthorizationCodeResponse = response.json().await?;
let token = response.token();
event_sender.send(Event::TokenReceived { token: response.token() })?;
config.set_token(Some(token.clone()))?;
event_sender.send(Event::TokenReceived { token })?;
}
}
}

View File

@@ -0,0 +1,29 @@
pub struct Client {
// whatever we need
}
impl Client {
pub fn new() -> Self {
Self {
// websocket receiver
// websocket sender
// simconnect client handle
}
}
pub fn run() -> Result<(), anyhow::Error> {
loop {
tokio::select! {
// we can either get a message from the websocket to pass to simconnect
// we can get a message from simconnect to pass to the websocket
// or we get a quit event from the event channel
}
}
}
}
pub async fn start() -> Result<(), anyhow::Error> {
Client::new().run().await?
}

View File

@@ -1,4 +1,9 @@
use tokio::sync::broadcast::{Receiver, Sender};
use std::time::Duration;
use tokio::{
sync::broadcast::{Receiver, Sender},
time::sleep,
};
use crate::{
config::Config,
@@ -6,117 +11,75 @@ use crate::{
oauth,
};
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum State {
Init,
AppStart {
config: Config,
},
Authenticate {
open_browser: bool,
code_verifier: CodeVerifier,
code_challenge_method: CodeChallengeMethod,
},
Connect {
token: String,
},
AppStart,
Authenticate,
Connect { token: String },
WaitForSim,
InSim,
}
#[derive(Debug, Clone, PartialEq)]
pub enum Event {
Ready {
config: Config,
},
StartAuthenticate {
open_browser: bool,
code_verifier: CodeVerifier,
code_challenge_method: CodeChallengeMethod,
}, // should not be string
TokenReceived {
token: String,
}, // AppStart and Authenticate can fire off TokenReceived to transition into Connect
Ready,
StartAuthenticate, // should not be string
TokenReceived { token: String }, // AppStart and Authenticate can fire off TokenReceived to transition into Connect
Connected, // Once connected to the socket, and properly authenticated, fire off Connected to transition to WaitForSim
Disconnected, // If for whatever reason we're disconnected from the backend, we need to transition back to Connect
SimConnected, // SimConnect is connected, we're in the world and ready to send data, transition to Running
SimDisconnected, // SimConnect is disconnected, we've finished the flight and exited back to the menu, transition back to WaitForSim
Logout,
Quit,
}
impl State {
pub async fn next(self, event: Event) -> State {
match (self, event) {
match (self.clone(), event.clone()) {
// (Current State, SomeEvent) => NextState
(_, Event::Ready { config }) => State::AppStart { config },
(
State::AppStart { .. },
Event::StartAuthenticate {
open_browser,
code_verifier,
code_challenge_method,
},
) => Self::Authenticate {
open_browser,
code_verifier,
code_challenge_method,
}, // Goto Authenticate
(_, Event::Ready) => State::AppStart,
(_, Event::Logout) => State::AppStart,
(_, Event::StartAuthenticate) => Self::Authenticate, // Goto Authenticate
(State::AppStart { .. }, Event::TokenReceived { token }) => State::Connect { token },
(State::Authenticate { .. }, Event::TokenReceived { token }) => {
State::Connect { token }
}
(_, Event::TokenReceived { token }) => State::Connect { token },
(State::Connect { .. }, Event::Connected) => todo!(), // Goto WaitForSim
(_, Event::Connected) => State::WaitForSim, // Goto WaitForSim
(State::WaitForSim, Event::SimConnected) => todo!(), // Goto InSim
(_, Event::SimConnected) => todo!(), // Goto InSim
(_, Event::Disconnected) => todo!(), // Goto Connect
(State::InSim, Event::SimDisconnected) => todo!(), // Goto WaitForSim
(_, Event::Disconnected) => State::AppStart, // Goto Connect
(_, Event::SimDisconnected) => State::WaitForSim, // Goto WaitForSim
(_, Event::Quit) => todo!(), // All events can go into quit, to shutdown the application
_ => panic!("Invalid state transition"),
}
}
pub async fn run(&self, signal: Sender<Event>) -> Result<(), anyhow::Error> {
pub async fn run(&self, signal: Sender<Event>, config: Config) -> Result<(), anyhow::Error> {
match self {
State::Init => Ok(()),
State::AppStart { config } => {
State::AppStart => {
if let Some(token) = config.token() {
signal.send(Event::TokenReceived {
token: token.to_string(),
})?;
} else {
let open_browser = config.open_browser();
let code_verifier = CodeVerifier::new();
let code_challenge_method = CodeChallengeMethod::Sha256;
config.set_code_verifier(Some(code_verifier.clone()))?;
config.set_code_challenge_method(Some(code_challenge_method.clone()))?;
signal.send(Event::StartAuthenticate {
open_browser,
code_verifier,
code_challenge_method,
})?;
signal.send(Event::StartAuthenticate)?;
}
Ok(())
}
State::Authenticate {
open_browser,
code_verifier,
code_challenge_method,
} => {
if *open_browser {
oauth::open_browser(code_verifier.clone(), code_challenge_method.clone())?;
State::Authenticate => {
if config.open_browser() {
oauth::open_browser(config.clone())?;
}
Ok(())
}
State::Connect { .. } => Ok(()),
State::WaitForSim => Ok(()),
State::WaitForSim => {
tracing::info!("Waiting for sim!");
Ok(())
}
State::InSim => Ok(()),
}
}
@@ -129,25 +92,20 @@ pub async fn start(
) -> Result<(), anyhow::Error> {
let mut state = State::Init;
state.run(event_sender.clone()).await?;
state.run(event_sender.clone(), config.clone()).await?;
loop {
if let Ok(event) = event_receiver.recv().await {
if let Ok(event) = event_receiver.try_recv() {
state = state.next(event.clone()).await;
state.run(event_sender.clone(), config.clone()).await?;
if event == Event::Quit {
tracing::info!("Shutting down State Machine");
break;
}
state = state.next(event).await;
// before run
if let State::Connect { token } = &state {
// before run Connect, save the given token in config
config.set_token(Some(token.clone()))?;
}
state.run(event_sender.clone()).await?;
}
sleep(Duration::from_millis(100)).await;
}
tracing::info!("State Machine Shutdown");

View File

@@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
let api_service = api::Service::new(postgres.clone(), dangerous_lettre);
let app_state = AppState::new(api_service).await;
let app_state = AppState::new(api_service, config.clone()).await;
HttpServer::new(app_state, postgres.pool())
.await?

View File

@@ -1,5 +1,6 @@
use crate::{
domain::api::models::oauth::*, inbound::http::handlers::oauth::AuthorizationCodeRequest,
domain::api::models::oauth::*,
inbound::http::handlers::oauth::{AuthorizationCodeRequest, VerifyClientAuthorizationRequest},
};
use super::super::models::user::*;
@@ -56,5 +57,13 @@ pub trait ApiService: Clone + Send + Sync + 'static {
fn create_token(
&self,
req: AuthorizationCodeRequest,
) -> impl Future<Output = Result<Option<TokenSubject>, TokenError>> + Send;
) -> impl Future<Output = Result<TokenSubject, TokenError>> + Send;
/// ---
/// WS
/// ---
fn verify_client_authorization(
&self,
req: VerifyClientAuthorizationRequest,
) -> impl Future<Output = Result<User, anyhow::Error>> + Send;
}

View File

@@ -6,6 +6,7 @@ use axum_session::SessionAnySession;
use crate::inbound::http::handlers::oauth::AuthorizationCodeRequest;
use crate::inbound::http::handlers::oauth::GrantType;
use crate::inbound::http::handlers::oauth::VerifyClientAuthorizationRequest;
use super::models::oauth::Client;
use super::models::oauth::*;
@@ -228,7 +229,7 @@ where
async fn create_token(
&self,
req: AuthorizationCodeRequest,
) -> Result<Option<TokenSubject>, TokenError> {
) -> Result<TokenSubject, TokenError> {
if req.grant_type() != GrantType::AuthorizationCode {
return Err(TokenError::InvalidRequest);
}
@@ -265,6 +266,24 @@ where
let _ = self.repo.delete_token(req.code()).await;
Ok(Some(token))
Ok(token)
}
async fn verify_client_authorization(
&self,
req: VerifyClientAuthorizationRequest,
) -> Result<User, anyhow::Error> {
let user_id = req.user_id();
let client_id = req.client_id();
if !self.repo.is_authorized_client(user_id, client_id).await? {
return Err(anyhow::anyhow!("Unauthorized"));
}
let Some(user) = self.repo.find_user_by_id(user_id).await? else {
return Err(anyhow::anyhow!("Unauthorized"));
};
Ok(user)
}
}

View File

@@ -5,14 +5,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use anyhow::Context;
use axum::{
extract::ConnectInfo,
routing::{get, post},
Extension,
extract::ConnectInfo, routing::{any, get, post}, Extension
};
use axum_session::{SessionAnyPool, SessionConfig, SessionLayer, SessionStore};
use axum_session_sqlx::SessionPgPool;
use handlers::{
fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account
fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account, websocket::ws_handler
};
use leptos_axum::{generate_route_list, LeptosRoutes};
use state::AppState;
@@ -57,6 +55,7 @@ impl HttpServer {
);
let router = axum::Router::new()
.route("/ws", any(ws_handler))
.nest("/oauth2", oauth::routes())
.route("/auth/activate/:token", get(activate_account))
.route("/api/*fn_name", post(server_fn_handler))
@@ -87,3 +86,4 @@ impl HttpServer {
Ok(())
}
}

View File

@@ -2,3 +2,4 @@ pub mod fileserv;
pub mod leptos;
pub mod oauth;
pub mod user;
pub mod websocket;

View File

@@ -120,6 +120,26 @@ impl AuthorizationCodeRequest {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VerifyClientAuthorizationRequest {
user_id: uuid::Uuid,
client_id: uuid::Uuid,
}
impl VerifyClientAuthorizationRequest {
pub fn new(user_id: uuid::Uuid, client_id: uuid::Uuid) -> Self {
Self { client_id, user_id }
}
pub fn user_id(&self) -> uuid::Uuid {
self.user_id
}
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenClaims<T>
where
@@ -180,7 +200,11 @@ where
let iat = now.unix_timestamp() as usize;
let exp = (now + time::Duration::days(30)).unix_timestamp() as usize;
let claims = TokenClaims { sub, iat, exp };
let claims = TokenClaims {
sub: serde_qs::to_string(&sub)?,
iat,
exp,
};
let token = encode(
&Header::default(),

View File

@@ -0,0 +1,225 @@
use std::{borrow::Cow, net::SocketAddr, ops::ControlFlow};
use axum::{
extract::{
ws::{CloseFrame, Message, WebSocket},
ConnectInfo, State, WebSocketUpgrade,
},
response::IntoResponse,
};
use axum_extra::{
headers::{self, authorization::Bearer},
TypedHeader,
};
use futures::{SinkExt, StreamExt};
use http::StatusCode;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use crate::{
domain::api::{
ports::ApiService,
prelude::{TokenSubject, User},
},
inbound::http::{
handlers::oauth::{TokenClaims, VerifyClientAuthorizationRequest},
state::AppState,
},
};
pub async fn ws_handler<S: ApiService>(
State(app_state): State<AppState<S>>,
ws: WebSocketUpgrade,
auth_token: Option<TypedHeader<headers::Authorization<Bearer>>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> Result<impl IntoResponse, StatusCode> {
let auth_token = match auth_token {
Some(TypedHeader(token)) => Some(token.token().to_string()),
None => return Err(StatusCode::UNAUTHORIZED),
};
let Some(auth_token) = auth_token else {
return Err(StatusCode::UNAUTHORIZED);
};
let jwt_secret = &app_state.config().jwt_secret;
let claims = decode::<TokenClaims<String>>(
&auth_token,
&DecodingKey::from_secret(jwt_secret.as_ref()),
&Validation::new(Algorithm::HS256),
)
.map_err(|e| {
tracing::error!("Unable to decode token: {}\n{:?}", auth_token, e);
StatusCode::UNAUTHORIZED
})?
.claims;
let token_subject: TokenSubject = serde_qs::from_str(&claims.sub).map_err(|e| {
tracing::error!("Unable to parse Token Subject: {}\n{:?}", &claims.sub, e);
StatusCode::BAD_REQUEST
})?;
let user_id = token_subject.user_id();
let client_id = token_subject.client_id();
let app = app_state.api_service();
let Ok(user) = app
.verify_client_authorization(VerifyClientAuthorizationRequest::new(user_id, client_id))
.await
else {
return Err(StatusCode::UNAUTHORIZED);
};
Ok(ws.on_upgrade(move |socket| handle_socket(socket, user, addr)))
}
/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, user: User, who: SocketAddr) {
// send a ping (unsupported by some browsers) just to kick things off and get a response
if socket
.send(Message::Text(format!("Hello {}!", user.email())))
.await
.is_ok()
{
tracing::debug!("Pinged {who}...");
} else {
tracing::debug!("Could not send ping {who}!");
// no Error here since the only thing we can do is to close the connection.
// If we can not send messages, there is no way to salvage the statemachine anyway.
return;
}
// receive single message from a client (we can either receive or send with socket).
// this will likely be the Pong for our Ping or a hello message from client.
// waiting for message from a client will block this task, but will not block other client's
// connections.
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
if process_message(msg, who).is_break() {
return;
}
} else {
tracing::debug!("client {who} abruptly disconnected");
return;
}
}
// Since each client gets individual statemachine, we can pause handling
// when necessary to wait for some external event (in this case illustrated by sleeping).
// Waiting for this client to finish getting its greetings does not prevent other clients from
// connecting to server and receiving their greetings.
for i in 1..5 {
if socket
.send(Message::Text(format!("Hi {i} times!")))
.await
.is_err()
{
tracing::debug!("client {who} abruptly disconnected");
return;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
// By splitting socket we can send and receive at the same time. In this example we will send
// unsolicited messages to client based on some sort of server's internal event (i.e .timer).
let (mut sender, mut receiver) = socket.split();
// Spawn a task that will push several messages to the client (does not matter what client does)
let mut send_task = tokio::spawn(async move {
let n_msg = 20;
for i in 0..n_msg {
// In case of any websocket error, we exit.
if sender
.send(Message::Text(format!("Server message {i} ...")))
.await
.is_err()
{
return i;
}
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
}
tracing::debug!("Sending close to {who}...");
if let Err(e) = sender
.send(Message::Close(Some(CloseFrame {
code: axum::extract::ws::close_code::NORMAL,
reason: Cow::from("Goodbye"),
})))
.await
{
tracing::debug!("Could not send Close due to {e}, probably it is ok?");
}
n_msg
});
// This second task will receive messages from client and print them on server console
let mut recv_task = tokio::spawn(async move {
let mut cnt = 0;
while let Some(Ok(msg)) = receiver.next().await {
cnt += 1;
// print message and break if instructed to do so
if process_message(msg, who).is_break() {
break;
}
}
cnt
});
// If any one of the tasks exit, abort the other.
tokio::select! {
rv_a = (&mut send_task) => {
match rv_a {
Ok(a) => tracing::debug!("{a} messages sent to {who}"),
Err(a) => tracing::debug!("Error sending messages {a:?}")
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(b) => tracing::debug!("Received {b} messages"),
Err(b) => tracing::debug!("Error receiving messages {b:?}")
}
send_task.abort();
}
}
// returning from the handler closes the websocket connection
tracing::debug!("Websocket context {who} destroyed");
}
/// helper to print contents of messages to stdout. Has special treatment for Close.
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
match msg {
Message::Text(t) => {
tracing::debug!(">>> {who} sent str: {t:?}");
}
Message::Binary(d) => {
tracing::debug!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
}
Message::Close(c) => {
if let Some(cf) = c {
tracing::debug!(
">>> {} sent close with code {} and reason `{}`",
who,
cf.code,
cf.reason
);
} else {
tracing::debug!(">>> {who} somehow sent close message without CloseFrame");
}
return ControlFlow::Break(());
}
Message::Pong(v) => {
tracing::debug!(">>> {who} sent pong with {v:?}");
}
// You should never need to manually handle Message::Ping, as axum's websocket library
// will do so for you automagically by replying with Pong and copying the v according to
// spec. But if you need the contents of the pings you can see them here.
Message::Ping(v) => {
tracing::debug!(">>> {who} sent ping with {v:?}");
}
}
ControlFlow::Continue(())
}

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use axum::extract::FromRef;
use leptos::get_configuration;
use crate::domain::api::ports::ApiService;
use crate::{config::Config, domain::api::ports::ApiService};
#[derive(Debug, Clone)]
/// The global application state shared between all request handlers.
@@ -12,6 +12,7 @@ where
S: ApiService,
{
pub leptos_options: leptos::LeptosOptions,
config: Arc<Config>,
api_service: Arc<S>,
}
@@ -19,13 +20,18 @@ impl<S> AppState<S>
where
S: ApiService,
{
pub async fn new(api_service: S) -> Self {
pub async fn new(api_service: S, config: Config) -> Self {
Self {
config: Arc::new(config),
leptos_options: get_configuration(None).await.unwrap().leptos_options,
api_service: Arc::new(api_service),
}
}
pub fn config(&self) -> Arc<Config> {
self.config.clone()
}
pub fn api_service(&self) -> Arc<S> {
self.api_service.clone()
}

View File

@@ -1444,10 +1444,18 @@ html {
margin-top: auto;
}
.block {
display: block;
}
.flex {
display: flex;
}
.contents {
display: contents;
}
.hidden {
display: none;
}