Files
avam/src/lib/inbound/http/handlers/oauth.rs
2024-10-18 18:20:44 +02:00

217 lines
4.9 KiB
Rust

use std::{fmt::Display, str::FromStr};
use axum::{extract::State, routing::post, Json, Router};
use derive_more::derive::Display;
use http::StatusCode;
use jsonwebtoken::{encode, EncodingKey, Header};
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{config::Config, domain::api::prelude::*, inbound::http::state::AppState};
/// Builds the OAuth router for the HTTP inbound adapter.
///
/// Exposes a single endpoint, `POST /token`, served by the `token` handler.
pub fn routes<S>() -> axum::Router<AppState<S>>
where
    S: ApiService,
{
    let token_handler = post(token::<S>);
    Router::new().route("/token", token_handler)
}
/// Errors produced when turning a string into a [`GrantType`].
#[derive(Debug, Error)]
pub enum GrantTypeError {
    /// The supplied value is not a grant type this server recognises.
    #[error("The grant type is not valid.")]
    Invalid,
    /// Any other failure, forwarded unchanged (message and source) from the
    /// wrapped `anyhow::Error`.
    #[error(transparent)]
    Unknown(#[from] anyhow::Error),
    // More variants are expected as new failure scenarios are introduced.
}
/// OAuth 2.0 `grant_type` values accepted by the token endpoint.
///
/// Only the authorization-code grant is supported. On the wire (serde, and
/// the `Display`/`FromStr` impls) it is spelled `authorization_code`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    /// The `authorization_code` grant.
    AuthorizationCode,
}
impl Display for GrantType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
GrantType::AuthorizationCode => "authorization_code",
}
)
}
}
impl FromStr for GrantType {
    type Err = GrantTypeError;

    /// Parses the exact wire spelling `authorization_code`.
    ///
    /// Matching is case-sensitive. Any other input, including the empty
    /// string, yields [`GrantTypeError::Invalid`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "authorization_code" {
            Ok(Self::AuthorizationCode)
        } else {
            Err(GrantTypeError::Invalid)
        }
    }
}
/// PKCE `code_verifier` (RFC 7636) presented to the token endpoint.
///
/// A thin newtype over `String`. It deserializes from a plain string via
/// `From<String>`. No length or character-set validation is performed here.
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct CodeVerifier(String);
impl From<String> for CodeVerifier {
fn from(value: String) -> Self {
Self(value)
}
}
impl From<&str> for CodeVerifier {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
impl CodeVerifier {
pub fn new() -> Self {
let token: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(120)
.map(char::from)
.collect();
Self(token)
}
}
impl Default for CodeVerifier {
fn default() -> Self {
Self::new()
}
}
/// Body accepted by `POST /token` for the authorization-code grant.
///
/// The handler decodes it from a query-string-encoded body with `serde_qs`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizationCodeRequest {
    /// Requested grant. Only `authorization_code` deserializes successfully.
    grant_type: GrantType,
    /// Authorization code being exchanged for a token.
    code: AuthorizationCode,
    /// Redirect URI supplied with the request.
    redirect_uri: RedirectUri,
    /// Identifier of the requesting client.
    client_id: uuid::Uuid,
    /// PKCE code verifier. Presumably checked against the earlier challenge
    /// by the API service; confirm in `ApiService::create_token`.
    code_verifier: CodeVerifier,
}
impl AuthorizationCodeRequest {
    /// Returns an owned copy of the requested grant type.
    pub fn grant_type(&self) -> GrantType {
        Clone::clone(&self.grant_type)
    }

    /// Returns an owned copy of the authorization code.
    pub fn code(&self) -> AuthorizationCode {
        Clone::clone(&self.code)
    }

    /// Returns an owned copy of the redirect URI.
    pub fn redirect_uri(&self) -> RedirectUri {
        Clone::clone(&self.redirect_uri)
    }

    /// Returns the client identifier (`Uuid` is `Copy`).
    pub fn client_id(&self) -> uuid::Uuid {
        self.client_id
    }

    /// Returns an owned copy of the PKCE code verifier.
    pub fn code_verifier(&self) -> CodeVerifier {
        Clone::clone(&self.code_verifier)
    }
}
/// Pairs a user with a client for a client-authorization check.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VerifyClientAuthorizationRequest {
    /// Identifier of the user.
    user_id: uuid::Uuid,
    /// Identifier of the OAuth client.
    client_id: uuid::Uuid,
}
impl VerifyClientAuthorizationRequest {
    /// Bundles the user and client identifiers into a request.
    pub fn new(user_id: uuid::Uuid, client_id: uuid::Uuid) -> Self {
        Self { user_id, client_id }
    }

    /// Identifier of the user.
    pub fn user_id(&self) -> uuid::Uuid {
        let Self { user_id, .. } = *self;
        user_id
    }

    /// Identifier of the OAuth client.
    pub fn client_id(&self) -> uuid::Uuid {
        let Self { client_id, .. } = *self;
        client_id
    }
}
/// Claims embedded in the JWTs issued by the token endpoint.
///
/// `iat` and `exp` are whole seconds since the Unix epoch (JWT "NumericDate").
///
/// No bound is placed on `T` in the type definition. Each derived impl
/// carries the bound it needs (e.g. `Serialize` requires `T: Serialize`), so
/// the struct can be named generically without repeating bounds everywhere.
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenClaims<T> {
    /// Subject of the token (`sub` claim).
    pub sub: T,
    /// Issued-at time (`iat` claim), in seconds since the Unix epoch.
    pub iat: usize,
    /// Expiration time (`exp` claim), in seconds since the Unix epoch.
    pub exp: usize,
}
/// JSON body returned by `POST /token` on success.
///
/// NOTE(review): the field is serialized as `token` rather than the
/// RFC 6749 `access_token`. Clients of this endpoint must expect that name.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AuthorizationCodeResponse {
    /// Signed JWT issued to the client.
    token: String,
}
async fn token<S: ApiService>(
State(app_state): State<AppState<S>>,
body: String,
) -> Result<Json<AuthorizationCodeResponse>, (StatusCode, String)> {
let request: AuthorizationCodeRequest = serde_qs::from_str(&body).map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
let app = app_state.api_service();
let token = app.create_token(request).await.map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
let token = create_token(token).map_err(|e| {
tracing::error!("{:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
String::from("Internal Server Error"),
)
})?;
Ok(Json(AuthorizationCodeResponse { token }))
}
/// Signs a JWT whose `sub` claim is `sub` encoded with `serde_qs`.
///
/// The token expires 30 days after issue. It is signed with the default
/// [`Header`] (HS256 in `jsonwebtoken`) using `jwt_secret` from
/// `Config::from_env()`.
///
/// # Errors
///
/// Returns an error if:
/// * the configuration cannot be loaded;
/// * `sub` cannot be query-string encoded;
/// * the current time cannot be represented as a JWT timestamp (e.g. the
///   system clock reads before the Unix epoch);
/// * signing fails.
fn create_token<T>(sub: T) -> Result<String, anyhow::Error>
where
    T: Serialize,
{
    let config = Config::from_env()?;

    // `i64 as usize` would silently wrap a negative timestamp into an
    // enormous value (and truncate on 32-bit targets), producing a bogus
    // `iat`/`exp`. Convert with a checked cast instead.
    let to_numeric_date = |ts: i64| {
        usize::try_from(ts)
            .map_err(|_| anyhow::anyhow!("timestamp {ts} cannot be encoded as a JWT NumericDate"))
    };

    let now = time::OffsetDateTime::now_utc();
    let iat = to_numeric_date(now.unix_timestamp())?;
    let exp = to_numeric_date((now + time::Duration::days(30)).unix_timestamp())?;

    let claims = TokenClaims {
        sub: serde_qs::to_string(&sub)?,
        iat,
        exp,
    };

    let token = encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(config.jwt_secret.as_bytes()),
    )?;
    Ok(token)
}