mod samplers; mod invoking_llms; mod db; mod config; use axum::{ extract::{Form, Query, State}, response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, Router, }; use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use serde::Deserialize; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; use std::sync::OnceLock; use tera::{Tera}; use tower_http::services::ServeDir; use samplers::random_junk::random_string; pub use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default}; pub use db::{DbResult, connect_db, init_database}; pub struct AppStateInner { pub config: GeneralServiceConfig, pub db: tokio_postgres::Client, } pub type AppState = std::sync::Arc; pub async fn init_app_state() -> DbResult { let config = load_config_default()?; let db = connect_db(&config).await?; Ok(std::sync::Arc::new(AppStateInner { config, db })) } static TEMPLATES: OnceLock = OnceLock::new(); fn templates() -> &'static Tera { TEMPLATES.get_or_init(|| Tera::new("src/pages/**/*.html").expect("load templates")) } struct AuthenticatedUserId { id: i32, name: String, } enum PasscodeAuthenticationResult{ WrongPassword, Ok(AuthenticatedUserId) } enum WebsiteAuthenticationResult { NoCookie, SomeCookie(PasscodeAuthenticationResult), } async fn website_authentication_with_passcode( state: &AppStateInner, passcode: &str) -> PasscodeAuthenticationResult { let row = state.db.query_opt( "SELECT id, name FROM public.person WHERE passcode = $1 LIMIT 1", &[&passcode], ).await; match row { Ok(Some(row)) => PasscodeAuthenticationResult::Ok(AuthenticatedUserId { id: row.get(0), name: row.get(1), }), Ok(None) => PasscodeAuthenticationResult::WrongPassword, Err(err) => { eprintln!("auth query failed: {err}"); PasscodeAuthenticationResult::WrongPassword } } } async fn website_authentication_with_cookie( state: &AppStateInner, jar: &CookieJar, ) -> WebsiteAuthenticationResult { let auth_cookie = jar.get("auth"); match auth_cookie { None => WebsiteAuthenticationResult::NoCookie, Some(x) => { let passcode = x.value(); WebsiteAuthenticationResult::SomeCookie(website_authentication_with_passcode(state, passcode).await) }, } } fn axum_handler_with_auth( handler: H, state: AppState, ) -> impl Fn(CookieJar, T) -> std::pin::Pin + Send>> + Clone + Send + 'static where H: for<'a> Fn(T, &'a AppStateInner, AuthenticatedUserId) -> std::pin::Pin + Send + 'a>> + Clone + Send + 'static, Res: IntoResponse + 'static, T: Send + 'static, { move |jar: CookieJar, args: T| { let state = state.clone(); let handler = handler.clone(); Box::pin(async move { let res = website_authentication_with_cookie(state.as_ref(), &jar).await; match res { WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::Ok(user)) => { handler(args, state.as_ref(), user).await.into_response() } WebsiteAuthenticationResult::NoCookie => { Redirect::to("/login").into_response() } WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::WrongPassword) => { Redirect::to("/login?error=cookie").into_response() } } }) } } #[derive(Deserialize)] struct LoginPageQuery { error: Option, } #[derive(Deserialize)] struct LoginPageForm { passcode: String, csrf_token: String, } async fn login_get( State(state): State, jar: CookieJar, Query(query): Query, ) -> (CookieJar, Html) { let cur_auth: WebsiteAuthenticationResult = website_authentication_with_cookie(state.as_ref(), &jar).await; let csrf = random_string(32); let jar = jar.add( Cookie::build(("csrf", csrf.clone())) .path("/") .http_only(true) .same_site(SameSite::Strict) .build(), ); let error = match query.error.as_deref() { Some("cookie") => Some("Incorrect session cookie"), Some("password") => Some("Invalid passcode"), Some("csrf") => Some("Implicit log in attempt aborted!"), _ => None, }; let mut ctx = tera::Context::new(); ctx.insert("csrf", &csrf); if let Some(msg) = error { ctx.insert("error", msg); } if let WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::Ok(cur_user)) = cur_auth { ctx.insert("logged_in_cur_user", &cur_user.name); } let body = templates() .render("login.html", &ctx) .expect("render index"); (jar, Html(body)) } async fn login_post( State(state): State, jar: CookieJar, Form(form): Form ) -> impl IntoResponse { let csrf_ok = jar .get("csrf") .map(|cookie| cookie.value()) == Some(form.csrf_token.as_str()); if !csrf_ok { return Redirect::to("/login?error=csrf").into_response(); } let res = website_authentication_with_passcode(&state, &form.passcode); match res.await { PasscodeAuthenticationResult::Ok(_) => { let jar = jar.add( Cookie::build(("auth", form.passcode)) .path("/") .http_only(true) .same_site(SameSite::Strict) .build(), ); (jar, Redirect::to("/")).into_response() } PasscodeAuthenticationResult::WrongPassword => Redirect::to("/login?error=password").into_response() } } fn index( _: (), _state: &AppStateInner, _user: AuthenticatedUserId, ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { let body = templates() .render("index.html", &tera::Context::new()) .expect("render index"); Html(body) }) } fn welcome( _: (), _state: &AppStateInner, _user: AuthenticatedUserId, ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { let body = templates() .render("welcome.html", &tera::Context::new()) .expect("render welcome"); Html(body) }) } pub async fn run_server() -> Result<(), Box> { let state = init_app_state().await.expect("lol"); let app = Router::new() .route("/login", get(login_get)) .route("/login", post(login_post)) .route("/welcome", get(axum_handler_with_auth(welcome, state.clone()))) .route("/", get(axum_handler_with_auth(index, state.clone()))) .nest_service("/static", ServeDir::new("src/static")) .with_state(state); let addr: SocketAddr = "127.0.0.1:3000".parse().expect("valid socket addr"); println!("listening on http://{addr}"); let listener = tokio::net::TcpListener::bind(addr) .await .expect("bind failed"); axum::serve(listener, app).await?; Ok(()) } pub async fn init_db() -> Result<(), Box> { let config = load_config_default()?; init_database(&config).await?; Ok(()) }