feat: Implement Crypto and JWT utility modules in Rust
This commit introduces a Rust implementation of the cryptographic and JWT handling utilities, translated from the original TypeScript codebase. The new `CryptoUtils` module provides core cryptographic functionalities, including: - AES-256-CBC encryption and decryption. - Generation, saving, and loading of RSA-4096 key pairs. - It leverages the `openssl`, `sha2`, and `hex` crates. The new `JWTUtils` module handles JSON Web Tokens manually, without relying on the `jsonwebtoken` crate. Its features include: - Creating and signing JWTs using RSA-SHA256. - Verifying the signature and expiration of tokens. - Decoding tokens and validating claims. - This implementation uses the `openssl` crate for signing and verification, ensuring alignment with the `CryptoUtils` module. Additionally, minor compiler warnings, such as unused imports in `main.rs` and `config.rs`, have been resolved.
This commit is contained in:
@@ -4,8 +4,6 @@ use std::env;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub bind_address: SocketAddr,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// src/main.rs
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::process::exit;
|
||||
|
||||
use axum::{Router, routing::get};
|
||||
|
||||
167
src/utils/crypto/crypto.rs
Normal file
167
src/utils/crypto/crypto.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use openssl::error::ErrorStack;
|
||||
use openssl::pkey::PKey;
|
||||
use openssl::rsa::{Padding, Rsa};
|
||||
use openssl::symm::{Cipher, Crypter, Mode};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// --- Error Handling ---
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CryptoError {
|
||||
OpenSsl(ErrorStack),
|
||||
Io(io::Error),
|
||||
Hex(hex::FromHexError),
|
||||
Utf8(std::string::FromUtf8Error),
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl From<ErrorStack> for CryptoError {
|
||||
fn from(err: ErrorStack) -> Self {
|
||||
CryptoError::OpenSsl(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for CryptoError {
|
||||
fn from(err: io::Error) -> Self {
|
||||
CryptoError::Io(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<hex::FromHexError> for CryptoError {
|
||||
fn from(err: hex::FromHexError) -> Self {
|
||||
CryptoError::Hex(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::string::FromUtf8Error> for CryptoError {
|
||||
fn from(err: std::string::FromUtf8Error) -> Self {
|
||||
CryptoError::Utf8(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CryptoError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
CryptoError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e),
|
||||
CryptoError::Io(e) => write!(f, "IO error: {}", e),
|
||||
CryptoError::Hex(e) => write!(f, "Hex decoding error: {}", e),
|
||||
CryptoError::Utf8(e) => write!(f, "UTF-8 conversion error: {}", e),
|
||||
CryptoError::Custom(s) => write!(f, "Crypto error: {}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CryptoError {}
|
||||
|
||||
// --- KeyPair Structure ---
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyPair {
|
||||
pub private_key: String,
|
||||
pub public_key: String,
|
||||
}
|
||||
|
||||
// --- Crypto Utility ---
|
||||
|
||||
pub struct CryptoUtils;
|
||||
|
||||
static KEY: Lazy<[u8; 32]> = Lazy::new(|| {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(b"z&R*3mN@wS5gY!8c*P#L5bQm&8wT3vNxE!UW4ex7HJKLfghRT");
|
||||
hasher.finalize().into()
|
||||
});
|
||||
|
||||
const IV: &[u8; 16] = b"6234567890123456";
|
||||
|
||||
impl CryptoUtils {
|
||||
/// Encrypts a string using AES-256-CBC.
|
||||
pub fn encrypt(secret: &str) -> Result<String, CryptoError> {
|
||||
let cipher = Cipher::aes_256_cbc();
|
||||
let data = secret.as_bytes();
|
||||
|
||||
let mut encrypter = Crypter::new(cipher, Mode::Encrypt, &KEY, Some(IV))?;
|
||||
encrypter.pad(true);
|
||||
|
||||
let mut encrypted = vec![0; data.len() + cipher.block_size()];
|
||||
let count = encrypter.update(data, &mut encrypted)?;
|
||||
let rest = encrypter.finalize(&mut encrypted[count..])?;
|
||||
|
||||
encrypted.truncate(count + rest);
|
||||
Ok(hex::encode(encrypted))
|
||||
}
|
||||
|
||||
/// Decrypts a string using AES-256-CBC.
|
||||
pub fn decrypt(encrypted_secret: &str) -> Result<String, CryptoError> {
|
||||
let cipher = Cipher::aes_256_cbc();
|
||||
let data = hex::decode(encrypted_secret)?;
|
||||
|
||||
let mut decrypter = Crypter::new(cipher, Mode::Decrypt, &KEY, Some(IV))?;
|
||||
decrypter.pad(true);
|
||||
|
||||
let mut decrypted = vec![0; data.len() + cipher.block_size()];
|
||||
let count = decrypter.update(&data, &mut decrypted)?;
|
||||
let rest = decrypter.finalize(&mut decrypted[count..])?;
|
||||
|
||||
decrypted.truncate(count + rest);
|
||||
Ok(String::from_utf8(decrypted)?)
|
||||
}
|
||||
|
||||
/// Generates a 4096-bit RSA key pair in PEM format.
|
||||
pub fn generate_key_pair() -> Result<KeyPair, CryptoError> {
|
||||
let rsa = Rsa::generate(4096)?;
|
||||
let pkey = PKey::from_rsa(rsa)?;
|
||||
|
||||
let private_key = pkey
|
||||
.private_key_to_pem_pkcs8()?
|
||||
.iter()
|
||||
.map(|&c| c as char)
|
||||
.collect();
|
||||
let public_key = pkey
|
||||
.public_key_to_pem()?
|
||||
.iter()
|
||||
.map(|&c| c as char)
|
||||
.collect();
|
||||
|
||||
Ok(KeyPair {
|
||||
private_key,
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Saves the given KeyPair to files in the specified directory.
|
||||
pub fn save_keys_to_files(keys: &KeyPair, directory: &Path) -> Result<(), CryptoError> {
|
||||
fs::create_dir_all(directory)?;
|
||||
fs::write(directory.join("private.pem"), &keys.private_key)?;
|
||||
fs::write(directory.join("public.pem"), &keys.public_key)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loads a KeyPair from files in the specified directory.
|
||||
pub fn load_keys_from_files(directory: &Path) -> Result<KeyPair, CryptoError> {
|
||||
let private_key = fs::read_to_string(directory.join("private.pem"))?;
|
||||
let public_key = fs::read_to_string(directory.join("public.pem"))?;
|
||||
Ok(KeyPair {
|
||||
private_key,
|
||||
public_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Initializes RSA key pair.
|
||||
/// If keys exist in the default 'keys' directory, they are loaded.
|
||||
/// Otherwise, new keys are generated and saved.
|
||||
pub fn init_keys() -> Result<KeyPair, CryptoError> {
|
||||
let key_path = PathBuf::from("keys");
|
||||
|
||||
if key_path.join("private.pem").exists() && key_path.join("public.pem").exists() {
|
||||
Self::load_keys_from_files(&key_path)
|
||||
} else {
|
||||
let keys = Self::generate_key_pair()?;
|
||||
Self::save_keys_to_files(&keys, &key_path)?;
|
||||
Ok(keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
320
src/utils/jwt/jwt.rs
Normal file
320
src/utils/jwt/jwt.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
// Add the following dependencies to your Cargo.toml file:
|
||||
// openssl = "0.10"
|
||||
// serde = { version = "1.0", features = ["derive"] }
|
||||
// serde_json = "1.0"
|
||||
// chrono = "0.4"
|
||||
// base64 = "0.21"
|
||||
|
||||
use crate::utils::crypto::crypto::CryptoUtils;
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
||||
use chrono::{Duration, Utc};
|
||||
use openssl::{
|
||||
hash::MessageDigest,
|
||||
pkey::{PKey, Private},
|
||||
rsa::Rsa,
|
||||
sign::{Signer, Verifier},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, from_str, from_value, to_string};
|
||||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
|
||||
// --- Error Handling ---
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum JWTError {
|
||||
OpenSsl(openssl::error::ErrorStack),
|
||||
SerdeJson(serde_json::Error),
|
||||
Base64(base64::DecodeError),
|
||||
Crypto(String),
|
||||
InvalidTokenFormat(String),
|
||||
Validation(String),
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for JWTError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
JWTError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e),
|
||||
JWTError::SerdeJson(e) => write!(f, "JSON serialization error: {}", e),
|
||||
JWTError::Base64(e) => write!(f, "Base64 decoding error: {}", e),
|
||||
JWTError::Crypto(s) => write!(f, "Crypto error: {}", s),
|
||||
JWTError::InvalidTokenFormat(s) => write!(f, "Invalid token format: {}", s),
|
||||
JWTError::Validation(s) => write!(f, "Token validation failed: {}", s),
|
||||
JWTError::Custom(s) => write!(f, "JWT error: {}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for JWTError {}
|
||||
|
||||
impl From<openssl::error::ErrorStack> for JWTError {
|
||||
fn from(err: openssl::error::ErrorStack) -> JWTError {
|
||||
JWTError::OpenSsl(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for JWTError {
|
||||
fn from(err: serde_json::Error) -> JWTError {
|
||||
JWTError::SerdeJson(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<base64::DecodeError> for JWTError {
|
||||
fn from(err: base64::DecodeError) -> JWTError {
|
||||
JWTError::Base64(err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Structures ---
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct JWTHeader {
|
||||
alg: String,
|
||||
typ: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct JWTOptions {
|
||||
pub algorithm: String,
|
||||
pub expires_in: i64, // seconds
|
||||
pub issuer: String,
|
||||
}
|
||||
|
||||
impl Default for JWTOptions {
|
||||
fn default() -> Self {
|
||||
let expires_in_str = env::var("JWT_EXPIRES_IN").unwrap_or_else(|_| "3600".to_string());
|
||||
let expires_in = expires_in_str.parse::<i64>().unwrap_or(3600);
|
||||
let issuer = env::var("JWT_ISSUER").unwrap_or_default();
|
||||
|
||||
JWTOptions {
|
||||
algorithm: "RS256".to_string(),
|
||||
expires_in,
|
||||
issuer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JWTUtils {
|
||||
payload: Value,
|
||||
private_key: String,
|
||||
public_key: String,
|
||||
options: JWTOptions,
|
||||
}
|
||||
|
||||
// --- Helper Functions ---
|
||||
|
||||
fn base64url_encode<T: AsRef<[u8]>>(input: T) -> String {
|
||||
URL_SAFE_NO_PAD.encode(input)
|
||||
}
|
||||
|
||||
fn base64url_decode(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
URL_SAFE_NO_PAD.decode(input)
|
||||
}
|
||||
|
||||
// --- Implementation ---
|
||||
|
||||
impl JWTUtils {
|
||||
pub fn new(
|
||||
payload: Value,
|
||||
private_key: Option<String>,
|
||||
public_key: Option<String>,
|
||||
options: Option<JWTOptions>,
|
||||
) -> Result<Self, JWTError> {
|
||||
let opts = options.unwrap_or_default();
|
||||
|
||||
let (priv_key, pub_key) = match (private_key, public_key) {
|
||||
(Some(priv_k), Some(pub_k)) => (priv_k, pub_k),
|
||||
(priv_k, pub_k) => {
|
||||
let keys = CryptoUtils::load_keys_from_files("keys")
|
||||
.map_err(|e| JWTError::Crypto(e.to_string()))?;
|
||||
(
|
||||
priv_k.unwrap_or(keys.private_key),
|
||||
pub_k.unwrap_or(keys.public_key),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(JWTUtils {
|
||||
payload,
|
||||
private_key: priv_key,
|
||||
public_key: pub_key,
|
||||
options: opts,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create JWT header
|
||||
fn create_header(&self) -> JWTHeader {
|
||||
JWTHeader {
|
||||
alg: self.options.algorithm.clone(),
|
||||
typ: "JWT".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process payload with standard claims
|
||||
fn process_payload(&self) -> Result<String, JWTError> {
|
||||
let mut payload_obj = self
|
||||
.payload
|
||||
.as_object()
|
||||
.ok_or_else(|| JWTError::Custom("Payload must be a JSON object".to_string()))?
|
||||
.clone();
|
||||
|
||||
let now = Utc::now();
|
||||
let iat = now.timestamp();
|
||||
let exp = (now + Duration::seconds(self.options.expires_in)).timestamp();
|
||||
|
||||
payload_obj.insert("iat".to_string(), iat.into());
|
||||
payload_obj.insert("exp".to_string(), exp.into());
|
||||
payload_obj.insert("iss".to_string(), self.options.issuer.clone().into());
|
||||
|
||||
Ok(to_string(&payload_obj)?)
|
||||
}
|
||||
|
||||
/// Sign the JWT components
|
||||
fn sign(&self, header_base64: &str, payload_base64: &str) -> Result<String, JWTError> {
|
||||
let signature_input = format!("{}.{}", header_base64, payload_base64);
|
||||
let keypair = Rsa::private_key_from_pem(self.private_key.as_bytes())?;
|
||||
let pkey = PKey::from_rsa(keypair)?;
|
||||
|
||||
let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?;
|
||||
signer.update(signature_input.as_bytes())?;
|
||||
let signature = signer.sign_to_vec()?;
|
||||
|
||||
Ok(base64url_encode(&signature))
|
||||
}
|
||||
|
||||
/// Create complete JWT
|
||||
pub fn create_token(&self) -> Result<String, JWTError> {
|
||||
let header = self.create_header();
|
||||
let processed_payload = self.process_payload()?;
|
||||
|
||||
let header_base64 = base64url_encode(to_string(&header)?);
|
||||
let payload_base64 = base64url_encode(processed_payload);
|
||||
let signature = self.sign(&header_base64, &payload_base64)?;
|
||||
|
||||
Ok(format!(
|
||||
"{}.{}.{}",
|
||||
header_base64, payload_base64, signature
|
||||
))
|
||||
}
|
||||
|
||||
/// Verify JWT token signature
|
||||
pub fn verify(&self, token: &str) -> bool {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let signature_input = format!("{}.{}", parts[0], parts[1]);
|
||||
let signature = match base64url_decode(parts[2]) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
let key = match PKey::public_key_from_pem(self.public_key.as_bytes()) {
|
||||
Ok(k) => k,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
let mut verifier = match Verifier::new(MessageDigest::sha256(), &key) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return false,
|
||||
};
|
||||
verifier.update(signature_input.as_bytes()).unwrap();
|
||||
|
||||
verifier.verify(&signature).unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Decode JWT token without verification
|
||||
pub fn decode(token: &str) -> Result<(Value, Value), JWTError> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() < 2 {
|
||||
return Err(JWTError::InvalidTokenFormat(
|
||||
"Token must have at least two parts".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let header_json = String::from_utf8(base64url_decode(parts[0])?)
|
||||
.map_err(|_| JWTError::InvalidTokenFormat("Header is not valid UTF-8".to_string()))?;
|
||||
let payload_json = String::from_utf8(base64url_decode(parts[1])?)
|
||||
.map_err(|_| JWTError::InvalidTokenFormat("Payload is not valid UTF-8".to_string()))?;
|
||||
|
||||
let header: Value = from_str(&header_json)?;
|
||||
let payload: Value = from_str(&payload_json)?;
|
||||
|
||||
Ok((header, payload))
|
||||
}
|
||||
|
||||
/// Check if token is expired
|
||||
pub fn is_expired(token: &str) -> bool {
|
||||
match Self::decode(token) {
|
||||
Ok((_, payload)) => {
|
||||
if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
|
||||
let now = Utc::now().timestamp();
|
||||
exp < now
|
||||
} else {
|
||||
true // No expiration claim, consider it expired for safety
|
||||
}
|
||||
}
|
||||
Err(_) => true, // Invalid token, consider it expired
|
||||
}
|
||||
}
|
||||
|
||||
/// A combined decode and verify function
|
||||
pub fn decode_and_verify(token: &str) -> Result<(Value, Value), JWTError> {
|
||||
let jwt = JWTUtils::new(Value::Null, None, None, None)?;
|
||||
|
||||
if !jwt.verify(token) {
|
||||
return Err(JWTError::Validation(
|
||||
"Signature verification failed".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if Self::is_expired(token) {
|
||||
return Err(JWTError::Validation("Token has expired".to_string()));
|
||||
}
|
||||
|
||||
Self::decode(token)
|
||||
}
|
||||
|
||||
/// Refresh a token
|
||||
pub fn refresh_token(old_token: &str) -> Result<String, JWTError> {
|
||||
let (_, payload_val) = Self::decode(old_token)?;
|
||||
let mut payload_obj = payload_val
|
||||
.as_object()
|
||||
.ok_or_else(|| JWTError::Custom("Payload is not an object".to_string()))?
|
||||
.clone();
|
||||
|
||||
payload_obj.remove("exp");
|
||||
payload_obj.remove("iat");
|
||||
|
||||
// The keys must be loaded from files for this static method
|
||||
let keys = CryptoUtils::load_keys_from_files("keys")
|
||||
.map_err(|e| JWTError::Crypto(e.to_string()))?;
|
||||
|
||||
let jwt = JWTUtils::new(
|
||||
Value::Object(payload_obj),
|
||||
Some(keys.private_key),
|
||||
Some(keys.public_key),
|
||||
None,
|
||||
)?;
|
||||
|
||||
jwt.create_token()
|
||||
}
|
||||
|
||||
/// Validate that specific claims are present in the token
|
||||
pub fn validate_claims(token: &str, required_claims: &[&str]) -> bool {
|
||||
match Self::decode(token) {
|
||||
Ok((_, payload)) => {
|
||||
if let Some(payload_obj) = payload.as_object() {
|
||||
required_claims
|
||||
.iter()
|
||||
.all(|claim| payload_obj.contains_key(*claim))
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user