217 lines
4.9 KiB
Rust
217 lines
4.9 KiB
Rust
use std::{fmt::Display, str::FromStr};
|
|
|
|
use axum::{extract::State, routing::post, Json, Router};
|
|
use derive_more::derive::Display;
|
|
use http::StatusCode;
|
|
use jsonwebtoken::{encode, EncodingKey, Header};
|
|
use rand::{distributions::Alphanumeric, Rng};
|
|
use serde::{Deserialize, Serialize};
|
|
use thiserror::Error;
|
|
|
|
use crate::{config::Config, domain::api::prelude::*, inbound::http::state::AppState};
|
|
|
|
pub fn routes<S>() -> axum::Router<AppState<S>>
|
|
where
|
|
S: ApiService,
|
|
{
|
|
Router::new().route("/token", post(token))
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum GrantTypeError {
|
|
#[error("The grant type is not valid.")]
|
|
Invalid,
|
|
#[error(transparent)]
|
|
Unknown(#[from] anyhow::Error),
|
|
// to be extended as new error scenarios are introduced
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
|
pub enum GrantType {
|
|
#[serde(rename = "authorization_code")]
|
|
AuthorizationCode,
|
|
}
|
|
|
|
impl Display for GrantType {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(
|
|
f,
|
|
"{}",
|
|
match self {
|
|
GrantType::AuthorizationCode => "authorization_code",
|
|
}
|
|
)
|
|
}
|
|
}
|
|
|
|
impl FromStr for GrantType {
|
|
type Err = GrantTypeError;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
match s {
|
|
"authorization_code" => Ok(Self::AuthorizationCode),
|
|
_ => Err(GrantTypeError::Invalid),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Display, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
|
#[serde(from = "String")]
|
|
pub struct CodeVerifier(String);
|
|
|
|
impl From<String> for CodeVerifier {
|
|
fn from(value: String) -> Self {
|
|
Self(value)
|
|
}
|
|
}
|
|
|
|
impl From<&str> for CodeVerifier {
|
|
fn from(value: &str) -> Self {
|
|
Self(value.to_string())
|
|
}
|
|
}
|
|
|
|
impl CodeVerifier {
|
|
pub fn new() -> Self {
|
|
let token: String = rand::thread_rng()
|
|
.sample_iter(&Alphanumeric)
|
|
.take(120)
|
|
.map(char::from)
|
|
.collect();
|
|
|
|
Self(token)
|
|
}
|
|
}
|
|
|
|
impl Default for CodeVerifier {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AuthorizationCodeRequest {
|
|
grant_type: GrantType,
|
|
code: AuthorizationCode,
|
|
redirect_uri: RedirectUri,
|
|
client_id: uuid::Uuid,
|
|
code_verifier: CodeVerifier,
|
|
}
|
|
|
|
impl AuthorizationCodeRequest {
|
|
pub fn grant_type(&self) -> GrantType {
|
|
self.grant_type.clone()
|
|
}
|
|
|
|
pub fn code(&self) -> AuthorizationCode {
|
|
self.code.clone()
|
|
}
|
|
|
|
pub fn redirect_uri(&self) -> RedirectUri {
|
|
self.redirect_uri.clone()
|
|
}
|
|
|
|
pub fn client_id(&self) -> uuid::Uuid {
|
|
self.client_id
|
|
}
|
|
|
|
pub fn code_verifier(&self) -> CodeVerifier {
|
|
self.code_verifier.clone()
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct VerifyClientAuthorizationRequest {
|
|
user_id: uuid::Uuid,
|
|
client_id: uuid::Uuid,
|
|
}
|
|
|
|
impl VerifyClientAuthorizationRequest {
|
|
pub fn new(user_id: uuid::Uuid, client_id: uuid::Uuid) -> Self {
|
|
Self { client_id, user_id }
|
|
}
|
|
|
|
pub fn user_id(&self) -> uuid::Uuid {
|
|
self.user_id
|
|
}
|
|
|
|
pub fn client_id(&self) -> uuid::Uuid {
|
|
self.client_id
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct TokenClaims<T>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
pub sub: T,
|
|
pub iat: usize,
|
|
pub exp: usize,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct AuthorizationCodeResponse {
|
|
token: String,
|
|
}
|
|
|
|
async fn token<S: ApiService>(
|
|
State(app_state): State<AppState<S>>,
|
|
body: String,
|
|
) -> Result<Json<AuthorizationCodeResponse>, (StatusCode, String)> {
|
|
let request: AuthorizationCodeRequest = serde_qs::from_str(&body).map_err(|e| {
|
|
tracing::error!("{:?}", e);
|
|
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
String::from("Internal Server Error"),
|
|
)
|
|
})?;
|
|
|
|
let app = app_state.api_service();
|
|
let token = app.create_token(request).await.map_err(|e| {
|
|
tracing::error!("{:?}", e);
|
|
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
String::from("Internal Server Error"),
|
|
)
|
|
})?;
|
|
|
|
let token = create_token(token).map_err(|e| {
|
|
tracing::error!("{:?}", e);
|
|
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
String::from("Internal Server Error"),
|
|
)
|
|
})?;
|
|
|
|
Ok(Json(AuthorizationCodeResponse { token }))
|
|
}
|
|
|
|
fn create_token<T>(sub: T) -> Result<String, anyhow::Error>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
let config = Config::from_env()?;
|
|
|
|
let now = time::OffsetDateTime::now_utc();
|
|
let iat = now.unix_timestamp() as usize;
|
|
let exp = (now + time::Duration::days(30)).unix_timestamp() as usize;
|
|
|
|
let claims = TokenClaims {
|
|
sub: serde_qs::to_string(&sub)?,
|
|
iat,
|
|
exp,
|
|
};
|
|
|
|
let token = encode(
|
|
&Header::default(),
|
|
&claims,
|
|
&EncodingKey::from_secret(config.jwt_secret.as_bytes()),
|
|
)?;
|
|
|
|
Ok(token)
|
|
}
|