diff --git a/src/config.rs b/src/config.rs index ac269e6..38f207e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,8 +4,6 @@ use std::env; use std::net::SocketAddr; use std::str::FromStr; -use tracing::error; - #[derive(Debug)] pub struct Config { pub bind_address: SocketAddr, diff --git a/src/main.rs b/src/main.rs index 6bb675b..1cfae13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ // src/main.rs -use std::net::SocketAddr; use std::process::exit; use axum::{Router, routing::get}; diff --git a/src/utils/crypto/crypto.rs b/src/utils/crypto/crypto.rs new file mode 100644 index 0000000..3a79bf7 --- /dev/null +++ b/src/utils/crypto/crypto.rs @@ -0,0 +1,167 @@ +use once_cell::sync::Lazy; +use openssl::error::ErrorStack; +use openssl::pkey::PKey; +use openssl::rsa::{Padding, Rsa}; +use openssl::symm::{Cipher, Crypter, Mode}; +use sha2::{Digest, Sha256}; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +// --- Error Handling --- + +#[derive(Debug)] +pub enum CryptoError { + OpenSsl(ErrorStack), + Io(io::Error), + Hex(hex::FromHexError), + Utf8(std::string::FromUtf8Error), + Custom(String), +} + +impl From for CryptoError { + fn from(err: ErrorStack) -> Self { + CryptoError::OpenSsl(err) + } +} + +impl From for CryptoError { + fn from(err: io::Error) -> Self { + CryptoError::Io(err) + } +} + +impl From for CryptoError { + fn from(err: hex::FromHexError) -> Self { + CryptoError::Hex(err) + } +} + +impl From for CryptoError { + fn from(err: std::string::FromUtf8Error) -> Self { + CryptoError::Utf8(err) + } +} + +impl std::fmt::Display for CryptoError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CryptoError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e), + CryptoError::Io(e) => write!(f, "IO error: {}", e), + CryptoError::Hex(e) => write!(f, "Hex decoding error: {}", e), + CryptoError::Utf8(e) => write!(f, "UTF-8 conversion error: {}", e), + CryptoError::Custom(s) => write!(f, "Crypto error: {}", s), + } + } +} + +impl std::error::Error for CryptoError {} + +// --- KeyPair Structure --- + +#[derive(Debug, Clone)] +pub struct KeyPair { + pub private_key: String, + pub public_key: String, +} + +// --- Crypto Utility --- + +pub struct CryptoUtils; + +static KEY: Lazy<[u8; 32]> = Lazy::new(|| { + let mut hasher = Sha256::new(); + hasher.update(b"z&R*3mN@wS5gY!8c*P#L5bQm&8wT3vNxE!UW4ex7HJKLfghRT"); + hasher.finalize().into() +}); + +const IV: &[u8; 16] = b"6234567890123456"; + +impl CryptoUtils { + /// Encrypts a string using AES-256-CBC. + pub fn encrypt(secret: &str) -> Result { + let cipher = Cipher::aes_256_cbc(); + let data = secret.as_bytes(); + + let mut encrypter = Crypter::new(cipher, Mode::Encrypt, &KEY, Some(IV))?; + encrypter.pad(true); + + let mut encrypted = vec![0; data.len() + cipher.block_size()]; + let count = encrypter.update(data, &mut encrypted)?; + let rest = encrypter.finalize(&mut encrypted[count..])?; + + encrypted.truncate(count + rest); + Ok(hex::encode(encrypted)) + } + + /// Decrypts a string using AES-256-CBC. + pub fn decrypt(encrypted_secret: &str) -> Result { + let cipher = Cipher::aes_256_cbc(); + let data = hex::decode(encrypted_secret)?; + + let mut decrypter = Crypter::new(cipher, Mode::Decrypt, &KEY, Some(IV))?; + decrypter.pad(true); + + let mut decrypted = vec![0; data.len() + cipher.block_size()]; + let count = decrypter.update(&data, &mut decrypted)?; + let rest = decrypter.finalize(&mut decrypted[count..])?; + + decrypted.truncate(count + rest); + Ok(String::from_utf8(decrypted)?) + } + + /// Generates a 4096-bit RSA key pair in PEM format. + pub fn generate_key_pair() -> Result { + let rsa = Rsa::generate(4096)?; + let pkey = PKey::from_rsa(rsa)?; + + let private_key = pkey + .private_key_to_pem_pkcs8()? + .iter() + .map(|&c| c as char) + .collect(); + let public_key = pkey + .public_key_to_pem()? + .iter() + .map(|&c| c as char) + .collect(); + + Ok(KeyPair { + private_key, + public_key, + }) + } + + /// Saves the given KeyPair to files in the specified directory. + pub fn save_keys_to_files(keys: &KeyPair, directory: &Path) -> Result<(), CryptoError> { + fs::create_dir_all(directory)?; + fs::write(directory.join("private.pem"), &keys.private_key)?; + fs::write(directory.join("public.pem"), &keys.public_key)?; + Ok(()) + } + + /// Loads a KeyPair from files in the specified directory. + pub fn load_keys_from_files(directory: &Path) -> Result { + let private_key = fs::read_to_string(directory.join("private.pem"))?; + let public_key = fs::read_to_string(directory.join("public.pem"))?; + Ok(KeyPair { + private_key, + public_key, + }) + } + + /// Initializes RSA key pair. + /// If keys exist in the default 'keys' directory, they are loaded. + /// Otherwise, new keys are generated and saved. + pub fn init_keys() -> Result { + let key_path = PathBuf::from("keys"); + + if key_path.join("private.pem").exists() && key_path.join("public.pem").exists() { + Self::load_keys_from_files(&key_path) + } else { + let keys = Self::generate_key_pair()?; + Self::save_keys_to_files(&keys, &key_path)?; + Ok(keys) + } + } +} diff --git a/src/utils/jwt/jwt.rs b/src/utils/jwt/jwt.rs new file mode 100644 index 0000000..b797d71 --- /dev/null +++ b/src/utils/jwt/jwt.rs @@ -0,0 +1,320 @@ +// Add the following dependencies to your Cargo.toml file: +// openssl = "0.10" +// serde = { version = "1.0", features = ["derive"] } +// serde_json = "1.0" +// chrono = "0.4" +// base64 = "0.21" + +use crate::utils::crypto::crypto::CryptoUtils; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use chrono::{Duration, Utc}; +use openssl::{ + hash::MessageDigest, + pkey::{PKey, Private}, + rsa::Rsa, + sign::{Signer, Verifier}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, from_str, from_value, to_string}; +use std::collections::BTreeMap; +use std::env; + +// --- Error Handling --- + +#[derive(Debug)] +pub enum JWTError { + OpenSsl(openssl::error::ErrorStack), + SerdeJson(serde_json::Error), + Base64(base64::DecodeError), + Crypto(String), + InvalidTokenFormat(String), + Validation(String), + Custom(String), +} + +impl std::fmt::Display for JWTError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + JWTError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e), + JWTError::SerdeJson(e) => write!(f, "JSON serialization error: {}", e), + JWTError::Base64(e) => write!(f, "Base64 decoding error: {}", e), + JWTError::Crypto(s) => write!(f, "Crypto error: {}", s), + JWTError::InvalidTokenFormat(s) => write!(f, "Invalid token format: {}", s), + JWTError::Validation(s) => write!(f, "Token validation failed: {}", s), + JWTError::Custom(s) => write!(f, "JWT error: {}", s), + } + } +} + +impl std::error::Error for JWTError {} + +impl From for JWTError { + fn from(err: openssl::error::ErrorStack) -> JWTError { + JWTError::OpenSsl(err) + } +} + +impl From for JWTError { + fn from(err: serde_json::Error) -> JWTError { + JWTError::SerdeJson(err) + } +} + +impl From for JWTError { + fn from(err: base64::DecodeError) -> JWTError { + JWTError::Base64(err) + } +} + +// --- Structures --- + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct JWTHeader { + alg: String, + typ: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct JWTOptions { + pub algorithm: String, + pub expires_in: i64, // seconds + pub issuer: String, +} + +impl Default for JWTOptions { + fn default() -> Self { + let expires_in_str = env::var("JWT_EXPIRES_IN").unwrap_or_else(|_| "3600".to_string()); + let expires_in = expires_in_str.parse::().unwrap_or(3600); + let issuer = env::var("JWT_ISSUER").unwrap_or_default(); + + JWTOptions { + algorithm: "RS256".to_string(), + expires_in, + issuer, + } + } +} + +pub struct JWTUtils { + payload: Value, + private_key: String, + public_key: String, + options: JWTOptions, +} + +// --- Helper Functions --- + +fn base64url_encode>(input: T) -> String { + URL_SAFE_NO_PAD.encode(input) +} + +fn base64url_decode(input: &str) -> Result, base64::DecodeError> { + URL_SAFE_NO_PAD.decode(input) +} + +// --- Implementation --- + +impl JWTUtils { + pub fn new( + payload: Value, + private_key: Option, + public_key: Option, + options: Option, + ) -> Result { + let opts = options.unwrap_or_default(); + + let (priv_key, pub_key) = match (private_key, public_key) { + (Some(priv_k), Some(pub_k)) => (priv_k, pub_k), + (priv_k, pub_k) => { + let keys = CryptoUtils::load_keys_from_files("keys") + .map_err(|e| JWTError::Crypto(e.to_string()))?; + ( + priv_k.unwrap_or(keys.private_key), + pub_k.unwrap_or(keys.public_key), + ) + } + }; + + Ok(JWTUtils { + payload, + private_key: priv_key, + public_key: pub_key, + options: opts, + }) + } + + /// Create JWT header + fn create_header(&self) -> JWTHeader { + JWTHeader { + alg: self.options.algorithm.clone(), + typ: "JWT".to_string(), + } + } + + /// Process payload with standard claims + fn process_payload(&self) -> Result { + let mut payload_obj = self + .payload + .as_object() + .ok_or_else(|| JWTError::Custom("Payload must be a JSON object".to_string()))? + .clone(); + + let now = Utc::now(); + let iat = now.timestamp(); + let exp = (now + Duration::seconds(self.options.expires_in)).timestamp(); + + payload_obj.insert("iat".to_string(), iat.into()); + payload_obj.insert("exp".to_string(), exp.into()); + payload_obj.insert("iss".to_string(), self.options.issuer.clone().into()); + + Ok(to_string(&payload_obj)?) + } + + /// Sign the JWT components + fn sign(&self, header_base64: &str, payload_base64: &str) -> Result { + let signature_input = format!("{}.{}", header_base64, payload_base64); + let keypair = Rsa::private_key_from_pem(self.private_key.as_bytes())?; + let pkey = PKey::from_rsa(keypair)?; + + let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; + signer.update(signature_input.as_bytes())?; + let signature = signer.sign_to_vec()?; + + Ok(base64url_encode(&signature)) + } + + /// Create complete JWT + pub fn create_token(&self) -> Result { + let header = self.create_header(); + let processed_payload = self.process_payload()?; + + let header_base64 = base64url_encode(to_string(&header)?); + let payload_base64 = base64url_encode(processed_payload); + let signature = self.sign(&header_base64, &payload_base64)?; + + Ok(format!( + "{}.{}.{}", + header_base64, payload_base64, signature + )) + } + + /// Verify JWT token signature + pub fn verify(&self, token: &str) -> bool { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return false; + } + + let signature_input = format!("{}.{}", parts[0], parts[1]); + let signature = match base64url_decode(parts[2]) { + Ok(s) => s, + Err(_) => return false, + }; + + let key = match PKey::public_key_from_pem(self.public_key.as_bytes()) { + Ok(k) => k, + Err(_) => return false, + }; + + let mut verifier = match Verifier::new(MessageDigest::sha256(), &key) { + Ok(v) => v, + Err(_) => return false, + }; + verifier.update(signature_input.as_bytes()).unwrap(); + + verifier.verify(&signature).unwrap_or(false) + } + + /// Decode JWT token without verification + pub fn decode(token: &str) -> Result<(Value, Value), JWTError> { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() < 2 { + return Err(JWTError::InvalidTokenFormat( + "Token must have at least two parts".to_string(), + )); + } + + let header_json = String::from_utf8(base64url_decode(parts[0])?) + .map_err(|_| JWTError::InvalidTokenFormat("Header is not valid UTF-8".to_string()))?; + let payload_json = String::from_utf8(base64url_decode(parts[1])?) + .map_err(|_| JWTError::InvalidTokenFormat("Payload is not valid UTF-8".to_string()))?; + + let header: Value = from_str(&header_json)?; + let payload: Value = from_str(&payload_json)?; + + Ok((header, payload)) + } + + /// Check if token is expired + pub fn is_expired(token: &str) -> bool { + match Self::decode(token) { + Ok((_, payload)) => { + if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) { + let now = Utc::now().timestamp(); + exp < now + } else { + true // No expiration claim, consider it expired for safety + } + } + Err(_) => true, // Invalid token, consider it expired + } + } + + /// A combined decode and verify function + pub fn decode_and_verify(token: &str) -> Result<(Value, Value), JWTError> { + let jwt = JWTUtils::new(Value::Null, None, None, None)?; + + if !jwt.verify(token) { + return Err(JWTError::Validation( + "Signature verification failed".to_string(), + )); + } + + if Self::is_expired(token) { + return Err(JWTError::Validation("Token has expired".to_string())); + } + + Self::decode(token) + } + + /// Refresh a token + pub fn refresh_token(old_token: &str) -> Result { + let (_, payload_val) = Self::decode(old_token)?; + let mut payload_obj = payload_val + .as_object() + .ok_or_else(|| JWTError::Custom("Payload is not an object".to_string()))? + .clone(); + + payload_obj.remove("exp"); + payload_obj.remove("iat"); + + // The keys must be loaded from files for this static method + let keys = CryptoUtils::load_keys_from_files("keys") + .map_err(|e| JWTError::Crypto(e.to_string()))?; + + let jwt = JWTUtils::new( + Value::Object(payload_obj), + Some(keys.private_key), + Some(keys.public_key), + None, + )?; + + jwt.create_token() + } + + /// Validate that specific claims are present in the token + pub fn validate_claims(token: &str, required_claims: &[&str]) -> bool { + match Self::decode(token) { + Ok((_, payload)) => { + if let Some(payload_obj) = payload.as_object() { + required_claims + .iter() + .all(|claim| payload_obj.contains_key(*claim)) + } else { + false + } + } + Err(_) => false, + } + } +}