websocket
This commit is contained in:
@@ -30,7 +30,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let api_service = api::Service::new(postgres.clone(), dangerous_lettre);
|
||||
|
||||
let app_state = AppState::new(api_service).await;
|
||||
let app_state = AppState::new(api_service, config.clone()).await;
|
||||
|
||||
HttpServer::new(app_state, postgres.pool())
|
||||
.await?
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use crate::{
|
||||
domain::api::models::oauth::*, inbound::http::handlers::oauth::AuthorizationCodeRequest,
|
||||
domain::api::models::oauth::*,
|
||||
inbound::http::handlers::oauth::{AuthorizationCodeRequest, VerifyClientAuthorizationRequest},
|
||||
};
|
||||
|
||||
use super::super::models::user::*;
|
||||
@@ -56,5 +57,13 @@ pub trait ApiService: Clone + Send + Sync + 'static {
|
||||
fn create_token(
|
||||
&self,
|
||||
req: AuthorizationCodeRequest,
|
||||
) -> impl Future<Output = Result<Option<TokenSubject>, TokenError>> + Send;
|
||||
) -> impl Future<Output = Result<TokenSubject, TokenError>> + Send;
|
||||
|
||||
/// ---
|
||||
/// WS
|
||||
/// ---
|
||||
fn verify_client_authorization(
|
||||
&self,
|
||||
req: VerifyClientAuthorizationRequest,
|
||||
) -> impl Future<Output = Result<User, anyhow::Error>> + Send;
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ use axum_session::SessionAnySession;
|
||||
|
||||
use crate::inbound::http::handlers::oauth::AuthorizationCodeRequest;
|
||||
use crate::inbound::http::handlers::oauth::GrantType;
|
||||
use crate::inbound::http::handlers::oauth::VerifyClientAuthorizationRequest;
|
||||
|
||||
use super::models::oauth::Client;
|
||||
use super::models::oauth::*;
|
||||
@@ -228,7 +229,7 @@ where
|
||||
async fn create_token(
|
||||
&self,
|
||||
req: AuthorizationCodeRequest,
|
||||
) -> Result<Option<TokenSubject>, TokenError> {
|
||||
) -> Result<TokenSubject, TokenError> {
|
||||
if req.grant_type() != GrantType::AuthorizationCode {
|
||||
return Err(TokenError::InvalidRequest);
|
||||
}
|
||||
@@ -265,6 +266,24 @@ where
|
||||
|
||||
let _ = self.repo.delete_token(req.code()).await;
|
||||
|
||||
Ok(Some(token))
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
async fn verify_client_authorization(
|
||||
&self,
|
||||
req: VerifyClientAuthorizationRequest,
|
||||
) -> Result<User, anyhow::Error> {
|
||||
let user_id = req.user_id();
|
||||
let client_id = req.client_id();
|
||||
|
||||
if !self.repo.is_authorized_client(user_id, client_id).await? {
|
||||
return Err(anyhow::anyhow!("Unauthorized"));
|
||||
}
|
||||
|
||||
let Some(user) = self.repo.find_user_by_id(user_id).await? else {
|
||||
return Err(anyhow::anyhow!("Unauthorized"));
|
||||
};
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
@@ -5,14 +5,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::ConnectInfo,
|
||||
routing::{get, post},
|
||||
Extension,
|
||||
extract::ConnectInfo, routing::{any, get, post}, Extension
|
||||
};
|
||||
use axum_session::{SessionAnyPool, SessionConfig, SessionLayer, SessionStore};
|
||||
use axum_session_sqlx::SessionPgPool;
|
||||
use handlers::{
|
||||
fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account
|
||||
fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account, websocket::ws_handler
|
||||
};
|
||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||
use state::AppState;
|
||||
@@ -57,6 +55,7 @@ impl HttpServer {
|
||||
);
|
||||
|
||||
let router = axum::Router::new()
|
||||
.route("/ws", any(ws_handler))
|
||||
.nest("/oauth2", oauth::routes())
|
||||
.route("/auth/activate/:token", get(activate_account))
|
||||
.route("/api/*fn_name", post(server_fn_handler))
|
||||
@@ -87,3 +86,4 @@ impl HttpServer {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,3 +2,4 @@ pub mod fileserv;
|
||||
pub mod leptos;
|
||||
pub mod oauth;
|
||||
pub mod user;
|
||||
pub mod websocket;
|
||||
|
@@ -120,6 +120,26 @@ impl AuthorizationCodeRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct VerifyClientAuthorizationRequest {
|
||||
user_id: uuid::Uuid,
|
||||
client_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
impl VerifyClientAuthorizationRequest {
|
||||
pub fn new(user_id: uuid::Uuid, client_id: uuid::Uuid) -> Self {
|
||||
Self { client_id, user_id }
|
||||
}
|
||||
|
||||
pub fn user_id(&self) -> uuid::Uuid {
|
||||
self.user_id
|
||||
}
|
||||
|
||||
pub fn client_id(&self) -> uuid::Uuid {
|
||||
self.client_id
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TokenClaims<T>
|
||||
where
|
||||
@@ -180,7 +200,11 @@ where
|
||||
let iat = now.unix_timestamp() as usize;
|
||||
let exp = (now + time::Duration::days(30)).unix_timestamp() as usize;
|
||||
|
||||
let claims = TokenClaims { sub, iat, exp };
|
||||
let claims = TokenClaims {
|
||||
sub: serde_qs::to_string(&sub)?,
|
||||
iat,
|
||||
exp,
|
||||
};
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
|
225
src/lib/inbound/http/handlers/websocket.rs
Normal file
225
src/lib/inbound/http/handlers/websocket.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use std::{borrow::Cow, net::SocketAddr, ops::ControlFlow};
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{CloseFrame, Message, WebSocket},
|
||||
ConnectInfo, State, WebSocketUpgrade,
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use axum_extra::{
|
||||
headers::{self, authorization::Bearer},
|
||||
TypedHeader,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use http::StatusCode;
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
|
||||
use crate::{
|
||||
domain::api::{
|
||||
ports::ApiService,
|
||||
prelude::{TokenSubject, User},
|
||||
},
|
||||
inbound::http::{
|
||||
handlers::oauth::{TokenClaims, VerifyClientAuthorizationRequest},
|
||||
state::AppState,
|
||||
},
|
||||
};
|
||||
|
||||
pub async fn ws_handler<S: ApiService>(
|
||||
State(app_state): State<AppState<S>>,
|
||||
ws: WebSocketUpgrade,
|
||||
auth_token: Option<TypedHeader<headers::Authorization<Bearer>>>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let auth_token = match auth_token {
|
||||
Some(TypedHeader(token)) => Some(token.token().to_string()),
|
||||
None => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
|
||||
let Some(auth_token) = auth_token else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
let jwt_secret = &app_state.config().jwt_secret;
|
||||
|
||||
let claims = decode::<TokenClaims<String>>(
|
||||
&auth_token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_ref()),
|
||||
&Validation::new(Algorithm::HS256),
|
||||
)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Unable to decode token: {}\n{:?}", auth_token, e);
|
||||
StatusCode::UNAUTHORIZED
|
||||
})?
|
||||
.claims;
|
||||
|
||||
let token_subject: TokenSubject = serde_qs::from_str(&claims.sub).map_err(|e| {
|
||||
tracing::error!("Unable to parse Token Subject: {}\n{:?}", &claims.sub, e);
|
||||
StatusCode::BAD_REQUEST
|
||||
})?;
|
||||
|
||||
let user_id = token_subject.user_id();
|
||||
let client_id = token_subject.client_id();
|
||||
|
||||
let app = app_state.api_service();
|
||||
let Ok(user) = app
|
||||
.verify_client_authorization(VerifyClientAuthorizationRequest::new(user_id, client_id))
|
||||
.await
|
||||
else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_socket(socket, user, addr)))
|
||||
}
|
||||
|
||||
/// Actual websocket statemachine (one will be spawned per connection)
|
||||
async fn handle_socket(mut socket: WebSocket, user: User, who: SocketAddr) {
|
||||
// send a ping (unsupported by some browsers) just to kick things off and get a response
|
||||
if socket
|
||||
.send(Message::Text(format!("Hello {}!", user.email())))
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
tracing::debug!("Pinged {who}...");
|
||||
} else {
|
||||
tracing::debug!("Could not send ping {who}!");
|
||||
// no Error here since the only thing we can do is to close the connection.
|
||||
// If we can not send messages, there is no way to salvage the statemachine anyway.
|
||||
return;
|
||||
}
|
||||
|
||||
// receive single message from a client (we can either receive or send with socket).
|
||||
// this will likely be the Pong for our Ping or a hello message from client.
|
||||
// waiting for message from a client will block this task, but will not block other client's
|
||||
// connections.
|
||||
if let Some(msg) = socket.recv().await {
|
||||
if let Ok(msg) = msg {
|
||||
if process_message(msg, who).is_break() {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("client {who} abruptly disconnected");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Since each client gets individual statemachine, we can pause handling
|
||||
// when necessary to wait for some external event (in this case illustrated by sleeping).
|
||||
// Waiting for this client to finish getting its greetings does not prevent other clients from
|
||||
// connecting to server and receiving their greetings.
|
||||
for i in 1..5 {
|
||||
if socket
|
||||
.send(Message::Text(format!("Hi {i} times!")))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
tracing::debug!("client {who} abruptly disconnected");
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
// By splitting socket we can send and receive at the same time. In this example we will send
|
||||
// unsolicited messages to client based on some sort of server's internal event (i.e .timer).
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Spawn a task that will push several messages to the client (does not matter what client does)
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
let n_msg = 20;
|
||||
for i in 0..n_msg {
|
||||
// In case of any websocket error, we exit.
|
||||
if sender
|
||||
.send(Message::Text(format!("Server message {i} ...")))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return i;
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
|
||||
}
|
||||
|
||||
tracing::debug!("Sending close to {who}...");
|
||||
if let Err(e) = sender
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: axum::extract::ws::close_code::NORMAL,
|
||||
reason: Cow::from("Goodbye"),
|
||||
})))
|
||||
.await
|
||||
{
|
||||
tracing::debug!("Could not send Close due to {e}, probably it is ok?");
|
||||
}
|
||||
n_msg
|
||||
});
|
||||
|
||||
// This second task will receive messages from client and print them on server console
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
let mut cnt = 0;
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
cnt += 1;
|
||||
// print message and break if instructed to do so
|
||||
if process_message(msg, who).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
cnt
|
||||
});
|
||||
|
||||
// If any one of the tasks exit, abort the other.
|
||||
tokio::select! {
|
||||
rv_a = (&mut send_task) => {
|
||||
match rv_a {
|
||||
Ok(a) => tracing::debug!("{a} messages sent to {who}"),
|
||||
Err(a) => tracing::debug!("Error sending messages {a:?}")
|
||||
}
|
||||
recv_task.abort();
|
||||
},
|
||||
rv_b = (&mut recv_task) => {
|
||||
match rv_b {
|
||||
Ok(b) => tracing::debug!("Received {b} messages"),
|
||||
Err(b) => tracing::debug!("Error receiving messages {b:?}")
|
||||
}
|
||||
send_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
// returning from the handler closes the websocket connection
|
||||
tracing::debug!("Websocket context {who} destroyed");
|
||||
}
|
||||
|
||||
/// helper to print contents of messages to stdout. Has special treatment for Close.
|
||||
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
|
||||
match msg {
|
||||
Message::Text(t) => {
|
||||
tracing::debug!(">>> {who} sent str: {t:?}");
|
||||
}
|
||||
Message::Binary(d) => {
|
||||
tracing::debug!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
|
||||
}
|
||||
Message::Close(c) => {
|
||||
if let Some(cf) = c {
|
||||
tracing::debug!(
|
||||
">>> {} sent close with code {} and reason `{}`",
|
||||
who,
|
||||
cf.code,
|
||||
cf.reason
|
||||
);
|
||||
} else {
|
||||
tracing::debug!(">>> {who} somehow sent close message without CloseFrame");
|
||||
}
|
||||
return ControlFlow::Break(());
|
||||
}
|
||||
|
||||
Message::Pong(v) => {
|
||||
tracing::debug!(">>> {who} sent pong with {v:?}");
|
||||
}
|
||||
// You should never need to manually handle Message::Ping, as axum's websocket library
|
||||
// will do so for you automagically by replying with Pong and copying the v according to
|
||||
// spec. But if you need the contents of the pings you can see them here.
|
||||
Message::Ping(v) => {
|
||||
tracing::debug!(">>> {who} sent ping with {v:?}");
|
||||
}
|
||||
}
|
||||
ControlFlow::Continue(())
|
||||
}
|
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use axum::extract::FromRef;
|
||||
use leptos::get_configuration;
|
||||
|
||||
use crate::domain::api::ports::ApiService;
|
||||
use crate::{config::Config, domain::api::ports::ApiService};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// The global application state shared between all request handlers.
|
||||
@@ -12,6 +12,7 @@ where
|
||||
S: ApiService,
|
||||
{
|
||||
pub leptos_options: leptos::LeptosOptions,
|
||||
config: Arc<Config>,
|
||||
api_service: Arc<S>,
|
||||
}
|
||||
|
||||
@@ -19,13 +20,18 @@ impl<S> AppState<S>
|
||||
where
|
||||
S: ApiService,
|
||||
{
|
||||
pub async fn new(api_service: S) -> Self {
|
||||
pub async fn new(api_service: S, config: Config) -> Self {
|
||||
Self {
|
||||
config: Arc::new(config),
|
||||
leptos_options: get_configuration(None).await.unwrap().leptos_options,
|
||||
api_service: Arc::new(api_service),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config(&self) -> Arc<Config> {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
pub fn api_service(&self) -> Arc<S> {
|
||||
self.api_service.clone()
|
||||
}
|
||||
|
Reference in New Issue
Block a user