diff --git a/Cargo.lock b/Cargo.lock index 42b8ae2..3984a61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2277,6 +2277,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd34f36fe4c5ba9654417139a9b3a20d2e1de6012ee678ad14d240c22c78d8d6" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_spanned" version = "0.6.5" @@ -3093,6 +3104,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "serde_qs", "tera", "time", "tokio 1.37.0", diff --git a/Cargo.toml b/Cargo.toml index 8e7d47a..e52d312 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,3 +37,4 @@ reqwest = { version = "0.10", default-features = false, features = [ "rustls-tls", "blocking", ] } +serde_qs = "0.13.0" diff --git a/src/main.rs b/src/main.rs index d1487df..11924da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,10 @@ mod icon; use std::{collections::HashMap, fs, io::ErrorKind, path::PathBuf, str::FromStr}; use icon::extract_icon; -use mlua::{Lua, Table, Value as LuaValue}; +use mlua::{ + Lua, Table, + Value::{self as LuaValue}, +}; use once_cell::sync::{Lazy, OnceCell}; use axum::{ @@ -24,6 +27,8 @@ use tower_http::services::ServeDir; pub type AError = Box; pub type AResult = Result; +pub static CWD: Lazy = Lazy::new(|| std::env::current_dir().unwrap()); + pub static TERA: Lazy = Lazy::new(|| { let mut tera = Tera::default(); tera.add_raw_template( @@ -174,11 +179,30 @@ async fn main() -> AResult<()> { } async fn file_handler(request: Request) -> impl IntoResponse { - let base_dir = BASE_DIR.get_or_init(|| { + std::env::set_current_dir(&*CWD).unwrap(); + let base_dir: &PathBuf = BASE_DIR.get_or_init(|| { PathBuf::from_str(&std::env::var("BASE_DIR").unwrap_or(String::from("./public"))).unwrap() }); - if !request.uri().path().ends_with(".lua") { + let base_dir = base_dir.canonicalize().map_err(|e| { + dbg!(e); + (StatusCode::NOT_FOUND, "404: Not Found".to_string()) + })?; + + let mut request_uri_path = request.uri().path().to_owned(); + + if request_uri_path.ends_with("/") { + let mut a = base_dir.clone(); + a.push(".".to_owned() + &request_uri_path); + a.push("init.lua"); + if let Ok(b) = a.canonicalize() { + if b.exists() { + request_uri_path += "init.lua"; + } + } + } + + if !request_uri_path.ends_with(".lua") { return ServeDir::new(base_dir) .fallback(get(handler)) .try_call(request) @@ -202,7 +226,7 @@ async fn file_handler(request: Request) -> impl IntoResponse { }) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let fetch_b64 = lua @@ -216,7 +240,7 @@ async fn file_handler(request: Request) -> impl IntoResponse { }) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let dbg = lua @@ -226,27 +250,27 @@ async fn file_handler(request: Request) -> impl IntoResponse { }) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; globals.set("fetch", fetch_json).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; globals.set("fetch_b64", fetch_b64).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; globals.set("dbg", dbg).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let request_table = lua.create_table().map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let headers_table = lua @@ -258,37 +282,59 @@ async fn file_handler(request: Request) -> impl IntoResponse { ) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; + let query = request.uri().query().unwrap_or(""); + + let query: HashMap = serde_qs::from_str(query) + .map_err(|e| { + println!("{:#?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "500: Unable to parse query string".to_string(), + ) + }) + .unwrap_or(HashMap::new()); + + let query_table = lua.create_table_from(query).map_err(|e| { + eprintln!("Lua Error: {:?}", e); + render_lua_error(e) + })?; + request_table.set("headers", headers_table).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; request_table .set("uri", request.uri().to_string()) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; request_table .set("method", request.method().to_string()) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; + request_table.set("query", query_table).map_err(|e| { + eprintln!("Lua Error: {:?}", e); + render_lua_error(e) + })?; + // Inject functions to change headers and such globals.set("request", request_table).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let response_table = lua.create_table().map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; response_table @@ -296,22 +342,22 @@ async fn file_handler(request: Request) -> impl IntoResponse { "headers", lua.create_table().map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?, ) .map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; globals.set("response", response_table).map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; let mut path = base_dir.clone(); - let uri = urlencoding::decode(&request.uri().path()[1..]).map_err(|e| { + let uri = urlencoding::decode(&request_uri_path[1..]).map_err(|e| { println!("{:?}", e); (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) })?; @@ -327,7 +373,7 @@ async fn file_handler(request: Request) -> impl IntoResponse { (StatusCode::NOT_FOUND, "404: Not Found".to_string()) })?; - if !full_path.starts_with(full_base_path) { + if !full_path.starts_with(&full_base_path) { return Err((StatusCode::BAD_REQUEST, "400: Bad Request".to_string())); } @@ -336,16 +382,25 @@ async fn file_handler(request: Request) -> impl IntoResponse { (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) })?; + // ------------------------------------------------------------------------ + + // TODO: FIX THIS, THIS IS DISGUSTING + std::env::set_current_dir(full_base_path).unwrap(); + let script = lua.load(script).eval::().map_err(|e| { + std::env::set_current_dir(&*CWD).unwrap(); eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; + std::env::set_current_dir(&*CWD).unwrap(); let result = lua.load("return response").eval::().map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; + // ------------------------------------------------------------------------ + let script = if let LuaValue::String(script) = script { script.to_string_lossy().to_string() } else { @@ -366,7 +421,7 @@ async fn file_handler(request: Request) -> impl IntoResponse { for pair in pairs { let (k, v) = pair.map_err(|e| { eprintln!("Lua Error: {:?}", e); - to_lua_error(e) + render_lua_error(e) })?; response.headers_mut().insert( HeaderName::from_str(&k).map_err(|e| { @@ -384,7 +439,7 @@ async fn file_handler(request: Request) -> impl IntoResponse { Ok(response) } -fn to_lua_error(e: mlua::Error) -> (StatusCode, String) { +fn render_lua_error(e: mlua::Error) -> (StatusCode, String) { let e = e .to_string() .split(':')