avam-client and oauth2

This commit is contained in:
2024-10-17 00:56:02 +02:00
parent bfc5cbf624
commit f93eb3c429
50 changed files with 5674 additions and 277 deletions

View File

@@ -1,4 +1,5 @@
#[cfg(feature = "ssr")]
#[allow(clippy::needless_return)]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use avam::config::Config;
@@ -31,13 +32,14 @@ async fn main() -> anyhow::Result<()> {
let app_state = AppState::new(api_service).await;
let http_server = HttpServer::new(app_state, postgres.pool()).await?;
http_server.run().await
HttpServer::new(app_state, postgres.pool())
.await?
.run()
.await
}
#[cfg(not(feature = "ssr"))]
pub fn main() {
println!("Do run this main?!");
// no client-side main function
// unless we want this to work with e.g., Trunk for a purely client-side app
// see lib.rs for hydration function instead

View File

@@ -8,6 +8,7 @@ const SMTP_PORT: &str = "SMTP_PORT";
const SMTP_USERNAME: &str = "SMTP_USERNAME";
const SMTP_PASSWORD: &str = "SMTP_PASSWORD";
const SMTP_SENDER: &str = "SMTP_SENDER";
const JWT_SECRET: &str = "JWT_SECRET";
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
@@ -18,6 +19,8 @@ pub struct Config {
pub smtp_username: String,
pub smtp_password: String,
pub smtp_sender: String,
pub jwt_secret: String,
}
impl Config {
@@ -28,6 +31,7 @@ impl Config {
let smtp_username = load_env(SMTP_USERNAME)?;
let smtp_password = load_env(SMTP_PASSWORD)?;
let smtp_sender = load_env(SMTP_SENDER)?;
let jwt_secret = load_env(JWT_SECRET)?;
Ok(Config {
database_url,
@@ -36,6 +40,7 @@ impl Config {
smtp_username,
smtp_password,
smtp_sender,
jwt_secret,
})
}
}

View File

@@ -18,6 +18,7 @@ pub mod prelude {
// But so far, this is the only thing that actually works
pub type AppService = std::sync::Arc<Service<Postgres, DangerousLettre>>;
pub use super::models::oauth::*;
pub use super::models::user::*;
pub use super::ports::*;
pub use super::service::*;
@@ -32,6 +33,7 @@ pub mod prelude {
#[cfg(not(feature = "ssr"))]
pub mod prelude {
pub use super::models::oauth::*;
pub use super::models::user::*;
pub use crate::domain::leptos::flashbag::Alert;
pub use crate::domain::leptos::flashbag::Flash;

View File

@@ -1 +1,3 @@
pub mod oauth;
pub mod pilot;
pub mod user;

View File

@@ -0,0 +1,377 @@
use std::{fmt::Display, str::FromStr};
use derive_more::derive::{Display, From};
#[cfg(feature = "ssr")]
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::user::User;
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Client {
id: uuid::Uuid,
user_id: uuid::Uuid,
name: ClientName,
secret: ClientSecret,
redirect_uri: RedirectUri,
}
impl Client {
pub fn new(
id: uuid::Uuid,
user_id: uuid::Uuid,
name: ClientName,
secret: ClientSecret,
redirect_uri: RedirectUri,
) -> Self {
Self {
id,
user_id,
name,
secret,
redirect_uri,
}
}
pub fn id(&self) -> uuid::Uuid {
self.id
}
pub fn name(&self) -> &ClientName {
&self.name
}
pub fn redirect_uri(&self) -> &RedirectUri {
&self.redirect_uri
}
}
#[derive(Clone, Debug, Display, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ClientName(String);
impl ClientName {
pub fn new(name: &str) -> Self {
Self(name.to_string())
}
}
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct ClientSecret(String);
impl From<String> for ClientSecret {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for ClientSecret {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
#[cfg(feature = "ssr")]
impl ClientSecret {
pub fn new() -> Self {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
Self(token)
}
}
#[cfg(feature = "ssr")]
impl Default for ClientSecret {
fn default() -> Self {
Self::new()
}
}
#[derive(
Clone, Debug, Display, From, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
pub struct RedirectUri(String);
impl RedirectUri {
pub fn new(uri: &str) -> Self {
Self(uri.to_string())
}
}
#[cfg(feature = "ssr")]
#[derive(Debug, Error)]
pub enum CreateAuthorizationCodeError {
#[error(transparent)]
Unknown(#[from] anyhow::Error),
// to be extended as new error scenarios are introduced
}
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct AuthorizationCode(String);
impl From<String> for AuthorizationCode {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for AuthorizationCode {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
#[cfg(feature = "ssr")]
impl AuthorizationCode {
pub fn new() -> Self {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
Self(token)
}
}
#[cfg(feature = "ssr")]
impl Default for AuthorizationCode {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AuthorizedClient {
client: Client,
user: User,
}
#[derive(Debug, Error)]
pub enum CodeChallengeMethodError {
#[error("Code challenge method is not valid.")]
Invalid,
#[error(transparent)]
Unknown(#[from] anyhow::Error),
// to be extended as new error scenarios are introduced
}
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum CodeChallengeMethod {
#[default]
#[serde(rename = "plain")]
Plain,
#[serde(rename = "S256")]
Sha256,
}
impl Display for CodeChallengeMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
CodeChallengeMethod::Plain => "plain",
CodeChallengeMethod::Sha256 => "S256",
}
)
}
}
impl FromStr for CodeChallengeMethod {
type Err = crate::domain::api::models::oauth::CodeChallengeMethodError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"plain" => Ok(Self::Plain),
"S256" => Ok(Self::Sha256),
_ => Err(crate::domain::api::models::oauth::CodeChallengeMethodError::Invalid),
}
}
}
impl TryFrom<String> for CodeChallengeMethod {
type Error = crate::domain::api::models::oauth::CodeChallengeMethodError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::from_str(&value)
}
}
#[derive(Debug, Error)]
pub enum ResponseTypeError {
#[error("The response type is not valid.")]
Invalid,
#[error(transparent)]
Unknown(#[from] anyhow::Error),
// to be extended as new error scenarios are introduced
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum ResponseType {
#[serde(rename = "code")]
Code,
}
impl Display for ResponseType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ResponseType::Code => "code",
}
)
}
}
impl FromStr for ResponseType {
type Err = crate::domain::api::models::oauth::ResponseTypeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"code" => Ok(Self::Code),
_ => Err(crate::domain::api::models::oauth::ResponseTypeError::Invalid),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizeRequest {
client_id: uuid::Uuid,
response_type: ResponseType, // Make type (enum:code,)
state: Option<String>, // random string for CSRF protection
code_challenge: String, // pkce
code_challenge_method: Option<CodeChallengeMethod>, // Make type (enum:sha256,) hashing algo
redirect_uri: RedirectUri, // Make type
scope: Option<String>, // space seperated string with permissions
}
impl AuthorizeRequest {
pub fn new(
client_id: uuid::Uuid,
response_type: ResponseType,
state: Option<String>,
code_challenge: String,
code_challenge_method: Option<CodeChallengeMethod>,
redirect_uri: RedirectUri,
scope: Option<String>,
) -> Self {
Self {
client_id,
response_type,
state,
code_challenge,
code_challenge_method,
redirect_uri,
scope,
}
}
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
pub fn response_type(&self) -> ResponseType {
self.response_type.clone()
}
pub fn state(&self) -> Option<String> {
self.state.clone()
}
pub fn code_challenge(&self) -> String {
self.code_challenge.clone()
}
pub fn code_challenge_method(&self) -> Option<CodeChallengeMethod> {
self.code_challenge_method.clone()
}
pub fn redirect_uri(&self) -> RedirectUri {
self.redirect_uri.clone()
}
pub fn scope(&self) -> Option<String> {
self.scope.clone()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizationResponse {
code: AuthorizationCode,
state: Option<String>,
}
impl AuthorizationResponse {
pub fn new(code: AuthorizationCode, state: Option<String>) -> Self {
Self { code, state }
}
pub fn code(&self) -> AuthorizationCode {
self.code.clone()
}
pub fn state(&self) -> Option<String> {
self.state.clone()
}
}
#[derive(Debug, Error)]
pub enum TokenError {
#[error("Invalid Token Request")]
InvalidRequest,
#[error(transparent)]
Unknown(#[from] anyhow::Error),
// to be extended as new error scenarios are introduced
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TokenSubject {
#[serde(rename = "a")]
user_id: uuid::Uuid,
#[serde(rename = "b")]
client_id: uuid::Uuid,
#[serde(skip)]
code_challenge: String,
#[serde(skip)]
code_challenge_method: CodeChallengeMethod,
}
impl TokenSubject {
pub fn new(
user_id: uuid::Uuid,
client_id: uuid::Uuid,
code_challenge: String,
code_challenge_method: CodeChallengeMethod,
) -> Self {
Self {
user_id,
client_id,
code_challenge,
code_challenge_method,
}
}
pub fn user_id(&self) -> uuid::Uuid {
self.user_id
}
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
pub fn code_challenge(&self) -> String {
self.code_challenge.clone()
}
pub fn code_challenge_method(&self) -> CodeChallengeMethod {
self.code_challenge_method.clone()
}
}

View File

View File

@@ -40,8 +40,8 @@ impl User {
}
}
pub fn id(&self) -> &uuid::Uuid {
&self.id
pub fn id(&self) -> uuid::Uuid {
self.id
}
pub fn email(&self) -> &EmailAddress {
@@ -103,6 +103,13 @@ impl ActivationToken {
}
}
#[cfg(feature = "ssr")]
impl Default for ActivationToken {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "ssr")]
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)]
#[serde(from = "String")]
@@ -135,6 +142,13 @@ impl PasswordResetToken {
}
}
#[cfg(feature = "ssr")]
impl Default for PasswordResetToken {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "ssr")]
#[derive(Debug, Error)]
pub enum ResetPasswordError {
@@ -410,6 +424,13 @@ impl UpdateUserRequest {
}
}
#[cfg(feature = "ssr")]
impl Default for UpdateUserRequest {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "ssr")]
#[derive(Debug, Error)]
pub enum UpdateUserError {

View File

@@ -1,133 +1,9 @@
/*
Module `ports` specifies the API by which external modules interact with the user domain.
mod api_service;
mod oauth_repository;
mod user_notifier;
mod user_repository;
All traits are bounded by `Send + Sync + 'static`, since their implementations must be shareable
between request-handling threads.
Trait methods are explicitly asynchronous, including `Send` bounds on response types,
since the application is expected to always run in a multithreaded environment.
*/
use std::future::Future;
use super::models::user::*;
/// `ApiService` is the public API for the user domain.
///
/// External modules must conform to this contract the domain is not concerned with the
/// implementation details or underlying technology of any external code.
pub trait ApiService: Clone + Send + Sync + 'static {
/// Asynchronously create a new [User].
///
/// # Errors
///
/// - [CreateUserError::Duplicate] if an [User] with the same [EmailAddress] already exists.
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
// fn activate_user
fn get_user_session(
&self,
session: &axum_session::SessionAnySession, // TODO: Get rid of this and make cleaner
) -> impl Future<Output = Option<User>> + Send;
fn activate_user_account(
&self,
token: ActivationToken,
) -> impl Future<Output = Result<User, ActivateUserError>> + Send;
fn user_login(
&self,
req: UserLoginRequest,
) -> impl Future<Output = Result<User, UserLoginError>> + Send;
fn forgot_password(&self, email: &EmailAddress) -> impl Future<Output = ()> + Send;
fn reset_password(
&self,
token: &PasswordResetToken,
password: &Password,
) -> impl Future<Output = Result<User, ResetPasswordError>> + Send;
fn find_user_by_password_reset_token(
&self,
token: &PasswordResetToken,
) -> impl Future<Output = Option<User>> + Send;
// These shouldnt be here, _why_ are they here, and implement that here instead
// fn find_user_by_email(self, email: EmailAddress) -> impl Future<Output = Option<User>> + Send;
// fn find_user_by_id(&self, user_id: uuid::Uuid) -> impl Future<Output = Option<User>> + Send;
}
pub trait UserRepository: Clone + Send + Sync + 'static {
// Create
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
fn create_activation_token(
&self,
ent: &User,
) -> impl Future<Output = Result<ActivationToken, anyhow::Error>> + Send;
fn create_password_reset_token(
&self,
ent: &User,
) -> impl Future<Output = Result<PasswordResetToken, anyhow::Error>> + Send;
// Read
fn all_users(&self) -> impl Future<Output = Vec<User>> + Send;
fn find_user_by_id(
&self,
id: uuid::Uuid,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_email(
&self,
email: &EmailAddress,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_activation_token(
&self,
token: &ActivationToken,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_password_reset_token(
&self,
token: &PasswordResetToken,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
// // Update
fn update_user(
&self,
ent: &User,
req: UpdateUserRequest,
) -> impl Future<Output = Result<(User, User), UpdateUserError>> + Send;
// Delete
// fn delete_user(&self, ent: User) -> impl Future<Output = Result<User, DeleteUserError>> + Send;
fn delete_activation_token_for_user(
&self,
ent: &User,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
fn delete_password_reset_tokens_for_user(
&self,
ent: &User,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
}
pub trait UserNotifier: Clone + Send + Sync + 'static {
fn user_created(&self, user: &User, token: &ActivationToken)
-> impl Future<Output = ()> + Send;
fn forgot_password(
&self,
user: &User,
token: &PasswordResetToken,
) -> impl Future<Output = ()> + Send;
}
pub use api_service::ApiService;
pub use oauth_repository::OAuthRepository;
pub use user_notifier::UserNotifier;
pub use user_repository::UserRepository;

View File

@@ -0,0 +1,60 @@
use crate::{
domain::api::models::oauth::*, inbound::http::handlers::oauth::AuthorizationCodeRequest,
};
use super::super::models::user::*;
use std::future::Future;
pub trait ApiService: Clone + Send + Sync + 'static {
// ---
// USER
// ---
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
fn get_user_session(
&self,
session: &axum_session::SessionAnySession, // TODO: Get rid of this and make cleaner
) -> impl Future<Output = Option<User>> + Send;
fn activate_user_account(
&self,
token: ActivationToken,
) -> impl Future<Output = Result<User, ActivateUserError>> + Send;
fn user_login(
&self,
req: UserLoginRequest,
) -> impl Future<Output = Result<User, UserLoginError>> + Send;
fn forgot_password(&self, email: &EmailAddress) -> impl Future<Output = ()> + Send;
fn reset_password(
&self,
token: &PasswordResetToken,
password: &Password,
) -> impl Future<Output = Result<User, ResetPasswordError>> + Send;
fn find_user_by_password_reset_token(
&self,
token: &PasswordResetToken,
) -> impl Future<Output = Option<User>> + Send;
// ---
// OAUTH
// ---
fn find_client_by_id(&self, id: uuid::Uuid) -> impl Future<Output = Option<Client>> + Send;
fn generate_authorization_code(
&self,
user: &User,
req: AuthorizeRequest,
) -> impl Future<Output = Result<AuthorizationResponse, anyhow::Error>> + Send;
fn create_token(
&self,
req: AuthorizationCodeRequest,
) -> impl Future<Output = Result<Option<TokenSubject>, TokenError>> + Send;
}

View File

@@ -0,0 +1,33 @@
use super::super::models::oauth::*;
use std::future::Future;
pub trait OAuthRepository: Clone + Send + Sync + 'static {
fn find_client_by_id(
&self,
id: uuid::Uuid,
) -> impl Future<Output = Result<Option<Client>, anyhow::Error>> + Send;
fn create_authorization_code(
&self,
user_id: uuid::Uuid,
client_id: uuid::Uuid,
code_challenge: String,
code_challenge_method: CodeChallengeMethod,
) -> impl Future<Output = Result<AuthorizationCode, anyhow::Error>> + Send;
fn is_authorized_client(
&self,
user_id: uuid::Uuid,
client_id: uuid::Uuid,
) -> impl Future<Output = Result<bool, anyhow::Error>> + Send;
fn get_token_subject(
&self,
code: AuthorizationCode,
) -> impl Future<Output = Result<Option<TokenSubject>, anyhow::Error>> + Send;
fn delete_token(
&self,
code: AuthorizationCode,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
}

View File

@@ -0,0 +1,12 @@
use super::super::models::user::*;
use std::future::Future;
pub trait UserNotifier: Clone + Send + Sync + 'static {
fn user_created(&self, user: &User, token: &ActivationToken)
-> impl Future<Output = ()> + Send;
fn forgot_password(
&self,
user: &User,
token: &PasswordResetToken,
) -> impl Future<Output = ()> + Send;
}

View File

@@ -0,0 +1,62 @@
use super::super::models::user::*;
use std::future::Future;
pub trait UserRepository: Clone + Send + Sync + 'static {
// Create
fn create_user(
&self,
req: CreateUserRequest,
) -> impl Future<Output = Result<User, CreateUserError>> + Send;
fn create_activation_token(
&self,
ent: &User,
) -> impl Future<Output = Result<ActivationToken, anyhow::Error>> + Send;
fn create_password_reset_token(
&self,
ent: &User,
) -> impl Future<Output = Result<PasswordResetToken, anyhow::Error>> + Send;
// Read
fn all_users(&self) -> impl Future<Output = Vec<User>> + Send;
fn find_user_by_id(
&self,
id: uuid::Uuid,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_email(
&self,
email: &EmailAddress,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_activation_token(
&self,
token: &ActivationToken,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
fn find_user_by_password_reset_token(
&self,
token: &PasswordResetToken,
) -> impl Future<Output = Result<Option<User>, anyhow::Error>> + Send;
// // Update
fn update_user(
&self,
ent: &User,
req: UpdateUserRequest,
) -> impl Future<Output = Result<(User, User), UpdateUserError>> + Send;
// Delete
// fn delete_user(&self, ent: User) -> impl Future<Output = Result<User, DeleteUserError>> + Send;
fn delete_activation_token_for_user(
&self,
ent: &User,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
fn delete_password_reset_tokens_for_user(
&self,
ent: &User,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send;
}

View File

@@ -4,17 +4,22 @@
*/
use axum_session::SessionAnySession;
use crate::inbound::http::handlers::oauth::AuthorizationCodeRequest;
use crate::inbound::http::handlers::oauth::GrantType;
use super::models::oauth::Client;
use super::models::oauth::*;
use super::models::user::*;
use super::ports::{ApiService, UserNotifier, UserRepository};
use super::ports::{ApiService, OAuthRepository, UserNotifier, UserRepository};
pub trait Repository = UserRepository;
// pub trait Repository = UserRepository + OAuthRepository;
pub trait Email = UserNotifier;
#[derive(Debug, Clone)]
pub struct Service<R, N>
where
R: Repository,
R: UserRepository + OAuthRepository,
N: Email,
{
repo: R,
@@ -23,7 +28,7 @@ where
impl<R, N> Service<R, N>
where
R: Repository,
R: UserRepository + OAuthRepository,
N: Email,
{
pub fn new(repo: R, notifier: N) -> Self {
@@ -33,12 +38,13 @@ where
impl<R, N> ApiService for Service<R, N>
where
R: UserRepository,
R: UserRepository + OAuthRepository,
N: Email,
{
async fn create_user(&self, req: CreateUserRequest) -> Result<User, CreateUserError> {
let result = self.repo.create_user(req).await;
#[allow(clippy::question_mark)]
if result.is_err() {
// something went wrong, log the error
// but keep passing on the result to the requester (http server)
@@ -56,9 +62,7 @@ where
}
async fn get_user_session(&self, session: &SessionAnySession) -> Option<User> {
let Some(user_id) = session.get("user") else {
return None;
};
let user_id = session.get("user")?;
self.repo.find_user_by_id(user_id).await.unwrap_or(None)
}
@@ -69,7 +73,7 @@ where
) -> Result<User, ActivateUserError> {
let user = match self.repo.find_user_by_activation_token(&token).await {
Ok(u) => u,
Err(e) => return Err(ActivateUserError::Unknown(e.into())),
Err(e) => return Err(ActivateUserError::Unknown(e)),
};
let Some(user) = user else {
@@ -98,7 +102,7 @@ where
Ok(u) => u,
Err(e) => {
tracing::error!("{:#?}", e);
return Err(UserLoginError::Unknown(e.into()));
return Err(UserLoginError::Unknown(e));
}
};
@@ -162,7 +166,7 @@ where
Ok(u) => u,
Err(e) => {
tracing::error!("{:#?}", e);
return Err(ResetPasswordError::Unknown(e.into()));
return Err(ResetPasswordError::Unknown(e));
}
};
@@ -190,4 +194,77 @@ where
.await
.unwrap_or(None)
}
async fn find_client_by_id(&self, id: uuid::Uuid) -> Option<Client> {
self.repo.find_client_by_id(id).await.ok().flatten()
}
async fn generate_authorization_code(
&self,
user: &User,
req: AuthorizeRequest,
) -> Result<AuthorizationResponse, anyhow::Error> {
let Some(client) = self.repo.find_client_by_id(req.client_id()).await? else {
return Err(anyhow::anyhow!("Client not found"));
};
if client.redirect_uri() != &req.redirect_uri() {
return Err(anyhow::anyhow!("Invalid redirect uri"));
}
let code = self
.repo
.create_authorization_code(
user.id(),
client.id(),
req.code_challenge(),
req.code_challenge_method().unwrap_or_default(),
)
.await?;
Ok(AuthorizationResponse::new(code, req.state()))
}
async fn create_token(
&self,
req: AuthorizationCodeRequest,
) -> Result<Option<TokenSubject>, TokenError> {
if req.grant_type() != GrantType::AuthorizationCode {
return Err(TokenError::InvalidRequest);
}
let code = req.code();
let Some(token) = self.repo.get_token_subject(code).await? else {
return Err(TokenError::InvalidRequest);
};
let code_verifier = req.code_verifier();
let code_challenge = match token.code_challenge_method() {
CodeChallengeMethod::Plain => {
use base64::prelude::*;
BASE64_URL_SAFE_NO_PAD.encode(code_verifier.to_string())
}
CodeChallengeMethod::Sha256 => {
use base64::prelude::*;
BASE64_URL_SAFE_NO_PAD.encode(sha256::digest(code_verifier.to_string()))
}
};
if token.code_challenge() != code_challenge {
return Err(TokenError::InvalidRequest);
}
let Some(client) = self.repo.find_client_by_id(token.client_id()).await? else {
return Err(TokenError::InvalidRequest); // no such client
};
if &req.redirect_uri() != client.redirect_uri() {
return Err(TokenError::InvalidRequest); // invalid redirect uri
}
let _ = self.repo.delete_token(req.code()).await;
Ok(Some(token))
}
}

View File

@@ -11,16 +11,23 @@ use pages::{
},
dashboard::DashboardPage,
error::{AppError, ErrorTemplate},
oauth2::authorize::AuthorizePage,
};
use crate::domain::api::prelude::User;
#[component]
pub fn App() -> impl IntoView {
provide_meta_context();
let trigger_user = create_rw_signal(true);
let trigger_update = create_rw_signal(false);
let trigger_direct = create_rw_signal(None::<String>);
let user = create_resource(trigger_user, move |_| async move {
super::check_user().await.unwrap()
let user_signal = create_rw_signal(None::<User>);
let user_resource = create_local_resource(trigger_update, move |_| async move {
let user = super::check_user().await.unwrap();
user_signal.set(user);
});
view! {
@@ -47,17 +54,17 @@ pub fn App() -> impl IntoView {
}>
<main class="h-screen overflow-auto dark:base-100 dark:text-white">
<Routes>
<Route path="/auth" view=move || {
<Route path="auth" view=move || {
view! {
<Suspense>
<Show when=move || user().is_some_and(|u| u.is_some())>
<Redirect path="/" />
<Show when=move || user_resource().is_some_and(|_|user_signal().is_some())>
<Redirect path={ trigger_direct().unwrap_or(String::from("/")) } />
</Show>
</Suspense>
<Outlet />
}
}>
<Route path="login" view=move || view! { <LoginPage user_signal=trigger_user /> } />
<Route path="login" view=move || view! { <LoginPage trigger_signal=trigger_update direct_signal=trigger_direct /> } />
<Route path="register" view=RegisterPage />
<Route path="forgot" view=ForgotPage />
<Route path="reset/:token" view=ResetPage />
@@ -66,8 +73,11 @@ pub fn App() -> impl IntoView {
<Route path="" view=move || {
view! {
<Suspense>
<Show when=move || user().is_some_and(|u| u.is_none())>
<Redirect path="/auth/login" />
<Show when=move || user_resource().is_some_and(|_|user_signal().is_none())>
<Redirect path={
use base64::prelude::*;
format!("/auth/login?c={}", BASE64_URL_SAFE_NO_PAD.encode(format!("{}{}", (leptos_router::use_location().pathname)(), (leptos_router::use_location().query)().to_query_string())))
} />
</Show>
</Suspense>
<Outlet />
@@ -75,13 +85,22 @@ pub fn App() -> impl IntoView {
}>
<Route path="" view=move || view! {
<Suspense>
<Show when=move || user().is_some_and(|u| u.is_some())>
<DashboardPage user={ user().unwrap().unwrap() } />
<Show when=move || user_signal().is_some()>
<DashboardPage user={ user_signal().unwrap() } />
</Show>
</Suspense>
} />
</Route> // dashboard
<Route path="/auth/logout" view=move || view! { <LogoutPage user_signal=trigger_user /> } />
<Route path="oauth2/authorize" view=move || view! {
<Suspense>
<Show when=move || user_signal().is_some()>
<AuthorizePage user={ user_signal } />
</Show>
</Suspense>
} />
</Route> // Logged in
<Route path="auth/logout" view=move || view! { <LogoutPage trigger_signal=trigger_update user_signal /> } />
</Routes>
</main>
</Router>

View File

@@ -1,10 +1,12 @@
use leptos::*;
use leptos_router::*;
use crate::domain::leptos::app::components::alert::Alert;
use crate::domain::api::prelude::*;
use super::auth_base::AuthBase;
use crate::domain::leptos::app::components::alert::Alert as AlertView;
#[server]
async fn login_action(email: String, password: String) -> Result<(), ServerFnError<String>> {
use crate::domain::api::prelude::*;
@@ -26,13 +28,27 @@ async fn login_action(email: String, password: String) -> Result<(), ServerFnErr
/// Renders the home page of your application.
#[component]
pub fn LoginPage(user_signal: RwSignal<bool>) -> impl IntoView {
pub fn LoginPage(trigger_signal: RwSignal<bool>, direct_signal: RwSignal<Option<String>>) -> impl IntoView {
let submit = Action::<LoginAction, _>::server();
let response = submit.value().read_only();
let callback = move || {
let query = use_query_map();
let Some(c) = query.with(|q| q.get("c").cloned()) else {
return Some("/".to_string());
};
use base64::prelude::*;
let dec = BASE64_URL_SAFE_NO_PAD.decode(c).unwrap_or_default();
let c = String::from_utf8_lossy(&dec);
Some(c.to_string())
};
create_effect(move |_| {
if response.get().is_some() {
user_signal.set(!user_signal.get_untracked());
trigger_signal.set(!trigger_signal.get_untracked());
direct_signal.set(callback());
}
});
@@ -50,25 +66,22 @@ pub fn LoginPage(user_signal: RwSignal<bool>) -> impl IntoView {
<Suspense>
<Show when=move || response.get().is_some_and(|e| e.is_err())>
<div role="alert" class="alert alert-error my-2">
<i class="fas fa-exclamation-circle"></i>
<span>
<AlertView alert={Alert::Error}>
{ move || if let Some(Err(e)) = response.get() {
{format!("{}", e)}.into_view()
} else {
().into_view()
}}
</span>
</div>
</AlertView>
</Show>
</Suspense>
<Suspense>
<Show when=move || flash().is_some_and(|u| u.is_some())>
<Alert alert={ flash().unwrap().unwrap().alert() }>
<AlertView alert={ flash().unwrap().unwrap().alert() }>
{ flash().unwrap().unwrap().message() }
</Alert>
</AlertView>
</Show>
</Suspense>

View File

@@ -1,5 +1,7 @@
use leptos::*;
use leptos_router::Redirect;
use leptos_router::{use_query_map, Redirect};
use crate::domain::api::prelude::User;
#[server]
async fn logout_action() -> Result<(), ServerFnError<String>> {
@@ -12,22 +14,31 @@ async fn logout_action() -> Result<(), ServerFnError<String>> {
}
#[component]
pub fn LogoutPage(user_signal: RwSignal<bool>) -> impl IntoView {
let submit = Action::<LogoutAction, _>::server();
let response = submit.value().read_only();
pub fn LogoutPage(
trigger_signal: RwSignal<bool>,
user_signal: RwSignal<Option<User>>,
) -> impl IntoView {
let direct_signal = create_rw_signal(None::<String>);
create_effect(move |_| {
if response.get().is_some() {
user_signal.set(!user_signal.get_untracked());
}
});
create_local_resource(
|| (),
move |_| async move {
logout_action().await.unwrap();
trigger_signal.set(!trigger_signal.get_untracked());
submit.dispatch(LogoutAction {});
let query = use_query_map();
if let Some(c) = query.with_untracked(|q| q.get("c").cloned()) {
direct_signal.set(Some(c));
} else {
direct_signal.set(Some("Lw".to_string()));
}
},
);
view! {
<Suspense>
<Show when=move || response().is_some()>
<Redirect path="/" />
<Show when=move || (direct_signal().is_some() && user_signal().is_none())>
<Redirect path={ format!("/auth/login?c={}", direct_signal().unwrap()) } />
</Show>
</Suspense>
}

View File

@@ -19,9 +19,7 @@ async fn reset_action(token: String, password: String, confirm_password: String)
app.reset_password(&token.into(), &password).await.map_err(|e| format!("{}", e))?;
let flash = FlashMessage::new("login",
format!(
"Your password has been reset."
)).with_alert(Alert::Success);
"Your password has been reset.".to_string()).with_alert(Alert::Success);
flashbag.set(flash);

View File

@@ -1,3 +1,4 @@
pub mod auth;
pub mod dashboard;
pub mod error;
pub mod oauth2;

View File

@@ -0,0 +1,279 @@
use std::str::FromStr;
use std::sync::Arc;
use leptos::*;
use leptos_router::*;
use serde::{Deserialize, Serialize};
use crate::domain::api::prelude::*;
use crate::domain::leptos::app::components::alert::Alert as AlertView;
use crate::domain::leptos::app::pages::auth::auth_base::AuthBase;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AuthorizationResponse {
code: String,
state: String,
}
#[server]
async fn authorize_action(form: AuthorizeQuery) -> Result<(), ServerFnError<String>> {
use crate::domain::api::prelude::*;
use crate::domain::leptos::check_user;
let app = use_context::<AppService>().unwrap();
let Some(client) = app.find_client_by_id(form.client_id()).await else {
return Err(ServerFnError::WrappedServerError(
"Invalid Client ID".to_string(),
));
};
let Some(user) = check_user().await? else {
return Err(ServerFnError::WrappedServerError("No user".to_string()));
};
let response: AuthorizationResponse = app
.generate_authorization_code(&user, form.into())
.await
.map_err(|e| format!("{}", e))?;
let qs = serde_qs::to_string(&response).map_err(|e| format!("{}", e))?;
leptos_axum::redirect(&format!("{}?{}", client.redirect_uri(), qs));
Ok(())
}
impl IntoAttribute for CodeChallengeMethod {
fn into_attribute(self) -> Attribute {
Attribute::String(self.to_string().into())
}
fn into_attribute_boxed(self: Box<Self>) -> Attribute {
Attribute::String(self.to_string().into())
}
}
// Make the types and validators for specific shit like response type and code challenge method
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct AuthorizeQuery {
client_id: uuid::Uuid,
response_type: ResponseType, // Make type (enum:code,)
state: Option<String>, // random string for CSRF protection
code_challenge: String, // pkce
code_challenge_method: Option<CodeChallengeMethod>, // Make type (enum:sha256,) hashing algo
redirect_uri: RedirectUri, // Make type
scope: Option<String>, // space seperated string with permissions
}
impl From<AuthorizeQuery> for AuthorizeRequest {
fn from(value: AuthorizeQuery) -> Self {
Self::new(
value.client_id(),
value.response_type(),
value.state(),
value.code_challenge(),
value.code_challenge_method(),
value.redirect_uri(),
value.scope(),
)
}
}
impl AuthorizeQuery {
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
pub fn response_type(&self) -> ResponseType {
self.response_type.clone()
}
pub fn state(&self) -> Option<String> {
self.state.clone()
}
pub fn code_challenge(&self) -> String {
self.code_challenge.clone()
}
pub fn code_challenge_method(&self) -> Option<CodeChallengeMethod> {
self.code_challenge_method.clone()
}
pub fn redirect_uri(&self) -> RedirectUri {
self.redirect_uri.clone()
}
pub fn scope(&self) -> Option<String> {
self.scope.clone()
}
}
impl Params for AuthorizeQuery {
fn from_map(map: &ParamsMap) -> Result<Self, ParamsError> {
let client_id: uuid::Uuid = match map.get("client_id") {
Some(c) => uuid::Uuid::from_str(c).map_err(|e| ParamsError::Params(Arc::new(e)))?,
None => return Err(ParamsError::MissingParam("client_id".to_string())),
};
let response_type: ResponseType = match map.get("response_type") {
Some(c) => ResponseType::from_str(c).map_err(|e| ParamsError::Params(Arc::new(e)))?,
None => return Err(ParamsError::MissingParam("response_type".to_string())),
};
let state = map.get("state").cloned();
let Some(code_challenge) = map.get("code_challenge").cloned() else {
return Err(ParamsError::MissingParam("code_challenge".to_string()));
};
let code_challenge_method = map
.get("code_challenge_method")
.map(|c| CodeChallengeMethod::from_str(c).map_err(|e| ParamsError::Params(Arc::new(e))))
.transpose()?;
let redirect_uri: RedirectUri = match map.get("redirect_uri") {
Some(c) => RedirectUri::new(c),
None => return Err(ParamsError::MissingParam("redirect_uri".to_string())),
};
let scope = map.get("scope").cloned();
Ok(Self {
client_id,
response_type,
state,
code_challenge,
code_challenge_method,
redirect_uri,
scope,
})
}
}
#[server]
async fn get_client(client_id: uuid::Uuid) -> Result<Client, ServerFnError<String>> {
use crate::domain::api::prelude::*;
let app = use_context::<AppService>().unwrap();
let Some(client) = app.find_client_by_id(client_id).await else {
return Err(ServerFnError::WrappedServerError(
"Invalid Client ID".to_string(),
));
};
Ok(client)
}
/// Renders the home page of your application.
#[component]
pub fn AuthorizePage(user: RwSignal<Option<User>>) -> impl IntoView {
let submit = Action::<AuthorizeAction, _>::server();
let response = submit.value().read_only();
let query = use_query::<AuthorizeQuery>();
let client_signal = create_rw_signal(None::<Client>);
let client_resource = create_local_resource(
|| (),
move |_| async move {
let query = query.get_untracked().map_err(|e| format!("{}", e))?;
let client = get_client(query.client_id())
.await
.map_err(|e| format!("{}", e))?;
if client.redirect_uri() != &query.redirect_uri() {
return Err("Invalid redirect uri".to_string());
}
client_signal.set(Some(client));
Ok::<(), String>(())
},
);
view! {
<AuthBase>
<Suspense>
<Show when=move || response.get().is_some_and(|e| e.is_err())>
<AlertView alert={Alert::Error}>
{ move || if let Some(Err(e)) = response.get() {
{format!("{}", e)}.into_view()
} else {
().into_view()
}}
</AlertView>
</Show>
</Suspense>
<Show when=move || query.get().is_err()>
<AlertView alert={Alert::Error}>
{ move || if let Err(e) = query.get().map_err(|e| { leptos::logging::warn!("{:#?}", e); "Invalid parameters".to_string() }) {
{e.to_string()}.into_view()
} else {
().into_view()
}}
</AlertView>
</Show>
<Suspense>
<Show when=move || client_resource.get().is_some_and(|e| e.is_err())>
<AlertView alert={Alert::Error}>
{ move || if let Some(Err(e)) = client_resource.get() {
{e.to_string()}.into_view()
} else {
().into_view()
}}
</AlertView>
</Show>
</Suspense>
<Show when=move || query.get().is_ok() && client_resource.get().is_some_and(|e| e.is_ok())>
<ActionForm action=submit class="w-full">
<div class="flex flex-col gap-2">
<div class="text-center text-sm">
"Signed in as "{move || user.get().unwrap().email().to_string()}
</div>
<Show when=move || response().is_some_and(|e| e.is_ok())>
<AlertView alert={Alert::Success}>
"You can now close this window."
</AlertView>
</Show>
<Show when=move || response().is_none()>
<div class="text-center text-sm">
<strong>{move || client_signal().unwrap().name().to_string() }</strong>" is requesting access to your account"
</div>
<div>
<input type="hidden" name="form[client_id]" value=move || query().unwrap().client_id().to_string() />
<input type="hidden" name="form[response_type]" value=move || query().unwrap().response_type().to_string() />
<input type="hidden" name="form[state]" value=move || query().unwrap().state() />
<input type="hidden" name="form[code_challenge]" value=move || query().unwrap().code_challenge().to_string() />
<input type="hidden" name="form[code_challenge_method]" value=move || query().unwrap().code_challenge_method() />
<input type="hidden" name="form[redirect_uri]" value=move || query().unwrap().redirect_uri().to_string() />
<input type="hidden" name="form[scope]" value=move || query().unwrap().scope() />
<input type="submit" value="Authorize" class="btn btn-primary btn-block" />
</div>
<div class="text-center text-sm">
"Not "{move || user.get().unwrap().email().to_string()}"? "<a href={move || {
use base64::prelude::*;
format!("/auth/logout?c={}", BASE64_URL_SAFE_NO_PAD.encode(format!("{}{}", (leptos_router::use_location().pathname)(), (leptos_router::use_location().query)().to_query_string())))
}} class="link">"Logout"</a>"!"
</div>
</Show>
</div>
</ActionForm>
</Show>
</AuthBase>
}
}

View File

@@ -0,0 +1 @@
pub mod authorize;

View File

@@ -46,7 +46,7 @@ where
pub fn new(name: &str, message: S) -> Self {
Self {
name: name.to_string(),
message: message.into(),
message,
alert: Alert::None,
}
}
@@ -82,14 +82,12 @@ impl FlashBag {
pub fn get(&self, flash_name: &str) -> Option<Flash> {
let name = format!("__flash:{}__", flash_name);
let Some(message) = self.session.get(&name) else {
return None;
};
let message = self.session.get(&name)?;
let alert_name = format!("__flash_alert:{}__", flash_name);
let alert = self.session.get(&alert_name).unwrap_or(Alert::None);
self.clear(&flash_name);
self.clear(flash_name);
Some(Flash {
name,

View File

@@ -12,9 +12,7 @@ use axum::{
use axum_session::{SessionAnyPool, SessionConfig, SessionLayer, SessionStore};
use axum_session_sqlx::SessionPgPool;
use handlers::{
fileserv::file_and_error_handler,
leptos::{leptos_routes_handler, server_fn_handler},
user::activate_account,
fileserv::file_and_error_handler, leptos::{leptos_routes_handler, server_fn_handler}, oauth, user::activate_account
};
use leptos_axum::{generate_route_list, LeptosRoutes};
use state::AppState;
@@ -59,6 +57,7 @@ impl HttpServer {
);
let router = axum::Router::new()
.nest("/oauth2", oauth::routes())
.route("/auth/activate/:token", get(activate_account))
.route("/api/*fn_name", post(server_fn_handler))
.leptos_routes_with_handler(generate_route_list(App), get(leptos_routes_handler))

View File

@@ -1,3 +1,4 @@
pub mod fileserv;
pub mod leptos;
pub mod oauth;
pub mod user;

View File

@@ -0,0 +1,192 @@
use std::{fmt::Display, str::FromStr};
use axum::{extract::State, routing::post, Json, Router};
use derive_more::derive::Display;
use http::StatusCode;
use jsonwebtoken::{encode, EncodingKey, Header};
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{config::Config, domain::api::prelude::*, inbound::http::state::AppState};
pub fn routes<S>() -> axum::Router<AppState<S>>
where
S: ApiService,
{
Router::new().route("/token", post(token))
}
#[derive(Debug, Error)]
pub enum GrantTypeError {
#[error("The grant type is not valid.")]
Invalid,
#[error(transparent)]
Unknown(#[from] anyhow::Error),
// to be extended as new error scenarios are introduced
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum GrantType {
#[serde(rename = "authorization_code")]
AuthorizationCode,
}
impl Display for GrantType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
GrantType::AuthorizationCode => "authorization_code",
}
)
}
}
impl FromStr for GrantType {
type Err = GrantTypeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"authorization_code" => Ok(Self::AuthorizationCode),
_ => Err(GrantTypeError::Invalid),
}
}
}
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct CodeVerifier(String);
impl From<String> for CodeVerifier {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for CodeVerifier {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
impl CodeVerifier {
pub fn new() -> Self {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(120)
.map(char::from)
.collect();
Self(token)
}
}
impl Default for CodeVerifier {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizationCodeRequest {
grant_type: GrantType,
code: AuthorizationCode,
redirect_uri: RedirectUri,
client_id: uuid::Uuid,
code_verifier: CodeVerifier,
}
impl AuthorizationCodeRequest {
pub fn grant_type(&self) -> GrantType {
self.grant_type.clone()
}
pub fn code(&self) -> AuthorizationCode {
self.code.clone()
}
pub fn redirect_uri(&self) -> RedirectUri {
self.redirect_uri.clone()
}
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
pub fn code_verifier(&self) -> CodeVerifier {
self.code_verifier.clone()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenClaims<T>
where
T: Serialize,
{
pub sub: T,
pub iat: usize,
pub exp: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizationCodeResponse {
token: String,
}
async fn token<S: ApiService>(
State(app_state): State<AppState<S>>,
body: String,
) -> Result<Json<AuthorizationCodeResponse>, (StatusCode, String)> {
let request: AuthorizationCodeRequest = serde_qs::from_str(&body).map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
let app = app_state.api_service();
let token = app.create_token(request).await.map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
let token = create_token(token).map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
Ok(Json(AuthorizationCodeResponse { token }))
}
fn create_token<T>(sub: T) -> Result<String, anyhow::Error>
where
T: Serialize,
{
let config = Config::from_env()?;
let now = time::OffsetDateTime::now_utc();
let iat = now.unix_timestamp() as usize;
let exp = (now + time::Duration::days(30)).unix_timestamp() as usize;
let claims = TokenClaims { sub, iat, exp };
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(config.jwt_secret.as_bytes()),
)?;
Ok(token)
}

View File

@@ -1,7 +1,9 @@
#![feature(trait_alias)]
pub static BASE_URL: &str = "https://avam.avii.nl";
pub static PROJECT_NAME: &str = "Avii's Virtual Airline Manager";
pub static COPYRIGHT: &str = "Avii's Virtual Airline Manager © 2024";
pub mod domain;
#[cfg(feature = "ssr")]

View File

@@ -3,6 +3,7 @@ use lettre::AsyncTransport;
use crate::domain::api::ports::UserNotifier;
use crate::domain::api::models::user::*;
use crate::BASE_URL;
use super::DangerousLettre;
@@ -10,7 +11,7 @@ impl UserNotifier for DangerousLettre {
async fn user_created(&self, user: &User, token: &ActivationToken) {
let mut context = tera::Context::new();
let url = format!("http://127.0.0.1:3000/auth/activate/{}", token); // Move base url to env
let url = format!("{}/auth/activate/{}", BASE_URL, token); // Move base url to env
context.insert("activate_url", &url);
@@ -28,7 +29,7 @@ impl UserNotifier for DangerousLettre {
async fn forgot_password(&self, user: &User, token: &PasswordResetToken) {
let mut context = tera::Context::new();
let url = format!("http://127.0.0.1:3000/auth/reset/{}", token); // Move base url to env
let url = format!("{}/auth/reset/{}", BASE_URL, token); // Move base url to env
context.insert("reset_url", &url);

View File

@@ -1,3 +1,4 @@
pub mod oauth_repository;
pub mod user_repository;
use std::str::FromStr;

View File

@@ -0,0 +1,179 @@
use anyhow::Context;
use sqlx::Executor;
// use sqlx::QueryBuilder;
use sqlx::Row;
use crate::domain::api::models::oauth::*;
use crate::domain::api::ports::OAuthRepository;
use super::Postgres;
impl OAuthRepository for Postgres {
async fn find_client_by_id(&self, id: uuid::Uuid) -> Result<Option<Client>, anyhow::Error> {
let query = sqlx::query(
r#"
SELECT
id,
user_id,
name,
secret,
redirect_uri
FROM clients WHERE id = $1
"#,
)
.bind(id);
let row = self
.pool
.fetch_optional(query)
.await
.context("failed to execute SQL transaction")?;
let Some(row) = row else {
return Ok(None);
};
let id = row.get("id");
let user_id = row.get("user_id");
let name = ClientName::new(row.get("name"));
let secret = ClientSecret::from(row.get::<&str, &str>("secret"));
let redirect_uri = RedirectUri::new(row.get("redirect_uri"));
Ok(Some(Client::new(id, user_id, name, secret, redirect_uri)))
}
async fn create_authorization_code(
&self,
user_id: uuid::Uuid,
client_id: uuid::Uuid,
code_challenge: String,
code_challenge_method: CodeChallengeMethod,
) -> Result<AuthorizationCode, anyhow::Error> {
let mut tx = self
.pool
.begin()
.await
.context("Failed to start sql transaction")?;
let code = AuthorizationCode::new();
let query = sqlx::query(
r#"
INSERT INTO authorization_code (code, user_id, client_id, code_challenge, code_challenge_method)
VALUES ($1, $2, $3, $4, $5)
"#,
)
.bind(code.to_string())
.bind(user_id)
.bind(client_id)
.bind(code_challenge)
.bind(code_challenge_method.to_string());
tx.execute(query)
.await
.map_err(|e| CreateAuthorizationCodeError::Unknown(e.into()))?;
let query = sqlx::query(
r#"
INSERT INTO authorized_clients (user_id, client_id)
VALUES ($1, $2) ON CONFLICT DO NOTHING
"#,
)
.bind(user_id)
.bind(client_id);
tx.execute(query)
.await
.map_err(|e| CreateAuthorizationCodeError::Unknown(e.into()))?;
tx.commit()
.await
.context("failed to commit SQL transaction")?;
Ok(code)
}
async fn is_authorized_client(
&self,
user_id: uuid::Uuid,
client_id: uuid::Uuid,
) -> Result<bool, anyhow::Error> {
let query = sqlx::query(
r#"
SELECT
*
FROM authorized_clients WHERE user_id = $1 AND client_id = $2
"#,
)
.bind(user_id)
.bind(client_id);
Ok(self
.pool
.fetch_optional(query)
.await
.context("failed to execute SQL transaction")?
.is_some())
}
async fn get_token_subject(
&self,
code: AuthorizationCode,
) -> Result<Option<TokenSubject>, anyhow::Error> {
let query = sqlx::query(
r#"
SELECT
user_id,
client_id,
code_challenge,
code_challenge_method
FROM authorization_code WHERE code = $1
"#,
)
.bind(code.to_string());
let Some(row) = self
.pool
.fetch_optional(query)
.await
.context("failed to execute SQL transaction")?
else {
return Ok(None);
};
let user_id: uuid::Uuid = row.get("user_id");
let client_id: uuid::Uuid = row.get("client_id");
let code_challenge: String = row.get("code_challenge");
let code_challenge_method: CodeChallengeMethod = row
.get::<String, &str>("code_challenge_method")
.try_into()?;
Ok(Some(TokenSubject::new(
user_id,
client_id,
code_challenge,
code_challenge_method,
)))
}
async fn delete_token(&self, code: AuthorizationCode) -> Result<(), anyhow::Error> {
let mut tx = self
.pool
.begin()
.await
.context("Failed to start sql transaction")?;
let query =
sqlx::query("DELETE FROM authorization_code WHERE code = $1").bind(code.to_string());
tx.execute(query)
.await
.context("failed to execute SQL transaction")?;
tx.commit()
.await
.context("failed to commit SQL transaction")?;
Ok(())
}
}

View File

@@ -216,7 +216,7 @@ impl UserRepository for Postgres {
let id = row.get("user_id");
Ok(self.find_user_by_id(id).await?)
self.find_user_by_id(id).await
}
async fn find_user_by_password_reset_token(
@@ -244,7 +244,7 @@ impl UserRepository for Postgres {
let id = row.get("user_id");
Ok(self.find_user_by_id(id).await?)
self.find_user_by_id(id).await
}
async fn update_user(
@@ -295,10 +295,10 @@ impl UserRepository for Postgres {
Ok((
ent.clone(),
User::new(
ent.id().clone(),
ent.id(),
new_email.clone(),
new_password.clone(),
new_verified.clone(),
*new_verified,
),
))
}