// Add the following dependencies to your Cargo.toml file: // openssl = "0.10" // serde = { version = "1.0", features = ["derive"] } // serde_json = "1.0" // chrono = "0.4" // base64 = "0.21" use crate::utils::crypto::crypto::CryptoUtils; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use chrono::{Duration, Utc}; use openssl::{ hash::MessageDigest, pkey::{PKey, Private}, rsa::Rsa, sign::{Signer, Verifier}, }; use serde::{Deserialize, Serialize}; use serde_json::{Value, from_str, from_value, to_string}; use std::collections::BTreeMap; use std::env; // --- Error Handling --- #[derive(Debug)] pub enum JWTError { OpenSsl(openssl::error::ErrorStack), SerdeJson(serde_json::Error), Base64(base64::DecodeError), Crypto(String), InvalidTokenFormat(String), Validation(String), Custom(String), } impl std::fmt::Display for JWTError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { JWTError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e), JWTError::SerdeJson(e) => write!(f, "JSON serialization error: {}", e), JWTError::Base64(e) => write!(f, "Base64 decoding error: {}", e), JWTError::Crypto(s) => write!(f, "Crypto error: {}", s), JWTError::InvalidTokenFormat(s) => write!(f, "Invalid token format: {}", s), JWTError::Validation(s) => write!(f, "Token validation failed: {}", s), JWTError::Custom(s) => write!(f, "JWT error: {}", s), } } } impl std::error::Error for JWTError {} impl From for JWTError { fn from(err: openssl::error::ErrorStack) -> JWTError { JWTError::OpenSsl(err) } } impl From for JWTError { fn from(err: serde_json::Error) -> JWTError { JWTError::SerdeJson(err) } } impl From for JWTError { fn from(err: base64::DecodeError) -> JWTError { JWTError::Base64(err) } } // --- Structures --- #[derive(Debug, Serialize, Deserialize, Clone)] struct JWTHeader { alg: String, typ: String, } #[derive(Debug, Serialize, Deserialize)] pub struct JWTOptions { pub algorithm: String, pub expires_in: i64, // seconds pub issuer: String, } impl Default for JWTOptions { fn default() -> Self { let expires_in_str = env::var("JWT_EXPIRES_IN").unwrap_or_else(|_| "3600".to_string()); let expires_in = expires_in_str.parse::().unwrap_or(3600); let issuer = env::var("JWT_ISSUER").unwrap_or_default(); JWTOptions { algorithm: "RS256".to_string(), expires_in, issuer, } } } pub struct JWTUtils { payload: Value, private_key: String, public_key: String, options: JWTOptions, } // --- Helper Functions --- fn base64url_encode>(input: T) -> String { URL_SAFE_NO_PAD.encode(input) } fn base64url_decode(input: &str) -> Result, base64::DecodeError> { URL_SAFE_NO_PAD.decode(input) } // --- Implementation --- impl JWTUtils { pub fn new( payload: Value, private_key: Option, public_key: Option, options: Option, ) -> Result { let opts = options.unwrap_or_default(); let (priv_key, pub_key) = match (private_key, public_key) { (Some(priv_k), Some(pub_k)) => (priv_k, pub_k), (priv_k, pub_k) => { let keys = CryptoUtils::load_keys_from_files("keys") .map_err(|e| JWTError::Crypto(e.to_string()))?; ( priv_k.unwrap_or(keys.private_key), pub_k.unwrap_or(keys.public_key), ) } }; Ok(JWTUtils { payload, private_key: priv_key, public_key: pub_key, options: opts, }) } /// Create JWT header fn create_header(&self) -> JWTHeader { JWTHeader { alg: self.options.algorithm.clone(), typ: "JWT".to_string(), } } /// Process payload with standard claims fn process_payload(&self) -> Result { let mut payload_obj = self .payload .as_object() .ok_or_else(|| JWTError::Custom("Payload must be a JSON object".to_string()))? .clone(); let now = Utc::now(); let iat = now.timestamp(); let exp = (now + Duration::seconds(self.options.expires_in)).timestamp(); payload_obj.insert("iat".to_string(), iat.into()); payload_obj.insert("exp".to_string(), exp.into()); payload_obj.insert("iss".to_string(), self.options.issuer.clone().into()); Ok(to_string(&payload_obj)?) } /// Sign the JWT components fn sign(&self, header_base64: &str, payload_base64: &str) -> Result { let signature_input = format!("{}.{}", header_base64, payload_base64); let keypair = Rsa::private_key_from_pem(self.private_key.as_bytes())?; let pkey = PKey::from_rsa(keypair)?; let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?; signer.update(signature_input.as_bytes())?; let signature = signer.sign_to_vec()?; Ok(base64url_encode(&signature)) } /// Create complete JWT pub fn create_token(&self) -> Result { let header = self.create_header(); let processed_payload = self.process_payload()?; let header_base64 = base64url_encode(to_string(&header)?); let payload_base64 = base64url_encode(processed_payload); let signature = self.sign(&header_base64, &payload_base64)?; Ok(format!( "{}.{}.{}", header_base64, payload_base64, signature )) } /// Verify JWT token signature pub fn verify(&self, token: &str) -> bool { let parts: Vec<&str> = token.split('.').collect(); if parts.len() != 3 { return false; } let signature_input = format!("{}.{}", parts[0], parts[1]); let signature = match base64url_decode(parts[2]) { Ok(s) => s, Err(_) => return false, }; let key = match PKey::public_key_from_pem(self.public_key.as_bytes()) { Ok(k) => k, Err(_) => return false, }; let mut verifier = match Verifier::new(MessageDigest::sha256(), &key) { Ok(v) => v, Err(_) => return false, }; verifier.update(signature_input.as_bytes()).unwrap(); verifier.verify(&signature).unwrap_or(false) } /// Decode JWT token without verification pub fn decode(token: &str) -> Result<(Value, Value), JWTError> { let parts: Vec<&str> = token.split('.').collect(); if parts.len() < 2 { return Err(JWTError::InvalidTokenFormat( "Token must have at least two parts".to_string(), )); } let header_json = String::from_utf8(base64url_decode(parts[0])?) .map_err(|_| JWTError::InvalidTokenFormat("Header is not valid UTF-8".to_string()))?; let payload_json = String::from_utf8(base64url_decode(parts[1])?) .map_err(|_| JWTError::InvalidTokenFormat("Payload is not valid UTF-8".to_string()))?; let header: Value = from_str(&header_json)?; let payload: Value = from_str(&payload_json)?; Ok((header, payload)) } /// Check if token is expired pub fn is_expired(token: &str) -> bool { match Self::decode(token) { Ok((_, payload)) => { if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) { let now = Utc::now().timestamp(); exp < now } else { true // No expiration claim, consider it expired for safety } } Err(_) => true, // Invalid token, consider it expired } } /// A combined decode and verify function pub fn decode_and_verify(token: &str) -> Result<(Value, Value), JWTError> { let jwt = JWTUtils::new(Value::Null, None, None, None)?; if !jwt.verify(token) { return Err(JWTError::Validation( "Signature verification failed".to_string(), )); } if Self::is_expired(token) { return Err(JWTError::Validation("Token has expired".to_string())); } Self::decode(token) } /// Refresh a token pub fn refresh_token(old_token: &str) -> Result { let (_, payload_val) = Self::decode(old_token)?; let mut payload_obj = payload_val .as_object() .ok_or_else(|| JWTError::Custom("Payload is not an object".to_string()))? .clone(); payload_obj.remove("exp"); payload_obj.remove("iat"); // The keys must be loaded from files for this static method let keys = CryptoUtils::load_keys_from_files("keys") .map_err(|e| JWTError::Crypto(e.to_string()))?; let jwt = JWTUtils::new( Value::Object(payload_obj), Some(keys.private_key), Some(keys.public_key), None, )?; jwt.create_token() } /// Validate that specific claims are present in the token pub fn validate_claims(token: &str, required_claims: &[&str]) -> bool { match Self::decode(token) { Ok((_, payload)) => { if let Some(payload_obj) = payload.as_object() { required_claims .iter() .all(|claim| payload_obj.contains_key(*claim)) } else { false } } Err(_) => false, } } }