Compare commits

..

1 Commits

Author SHA1 Message Date
96b30b90cb first commit 2025-08-14 12:09:17 -04:00
28 changed files with 65 additions and 2851 deletions

2
.env
View File

@@ -2,4 +2,4 @@
RUST_LOG=info
BIND_ADDRESS=127.0.0.1:3000
MONGO_URI=mongodb://localhost:27017
# DATABASE_URL=postgres://gerard@localhost/db (not used yet)

1338
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -5,13 +5,7 @@ edition = "2024"
[dependencies]
axum = "0.8.4"
base64 = "0.22.1"
bson = { version = "2.15.0", features = ["chrono-0_4"] }
chrono = { version = "0.4.41", features = ["serde"] }
clap = { version = "4.5.45", features = ["derive"] }
dotenvy = "0.15.7"
mongodb = "3.2.4"
openssl = "0.10.73"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
sqlx = { version = "0.8.6", features = ["runtime-tokio", "tls-native-tls"] }
@@ -19,6 +13,3 @@ tokio = { version = "1.47.1", features = ["full", "rt-multi-thread", "signal"] }
tower-http = { version = "0.6.6", features = ["trace"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
[features]
no-auth = []

152
README.md
View File

@@ -1,152 +0,0 @@
# PureNotify Backend
This is the backend service for the PureNotify application, written in Rust. It's built to be a high-performance, reliable, and scalable foundation for sending notifications.
## 🚀 Features
- **Asynchronous:** Built with `tokio` and `axum` for non-blocking I/O and high concurrency.
- **Configurable:** Easily configure the application using environment variables.
- **Logging:** Integrated structured logging with `tracing` for better observability.
- **Graceful Shutdown:** Ensures the server shuts down cleanly without dropping active connections.
- **Health Check:** A dedicated endpoint to monitor the service's health.
## 🛠️ Technologies Used
- **[Rust](https://www.rust-lang.org/)**: The core programming language.
- **[Axum](https://github.com/tokio-rs/axum)**: A web application framework that focuses on ergonomics and modularity.
- **[Tokio](https://tokio.rs/)**: An asynchronous runtime for the Rust programming language.
- **[Serde](https://serde.rs/)**: A framework for serializing and deserializing Rust data structures efficiently.
- **[Dotenvy](https://github.com/dotenv-rs/dotenv)**: For loading environment variables from a `.env` file.
- **[Tracing](https://github.com/tokio-rs/tracing)**: A framework for instrumenting Rust programs to collect structured, event-based diagnostic information.
## ⚙️ Getting Started
Follow these instructions to get a copy of the project up and running on your local machine for development and testing purposes.
### Prerequisites
You need to have the Rust toolchain installed on your system. If you don't have it, you can install it from [rustup.rs](https://rustup.rs/).
```sh
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
### Installation & Running
1. **Clone the repository:**
```sh
git clone https://github.com/your-username/purenotify_backend.git
cd purenotify_backend
```
2. **Create a `.env` file:**
Copy the example environment file to create your own local configuration.
```sh
cp .env.example .env
```
You can modify the `.env` file to change the server's configuration.
3. **Build the project:**
```sh
cargo build
```
4. **Run the application:**
For development, you can run the project directly with `cargo run`:
```sh
cargo run
```
For a release build, run:
```sh
cargo run --release
```
The server will start, and you should see log output in your terminal indicating that it's running.
## 🔧 Configuration
The application is configured using environment variables. These can be set in a `.env` file in the project root or directly in your shell.
- `BIND_ADDRESS`: The IP address and port the server should listen on.
- **Default:** `127.0.0.1:3000`
- `RUST_LOG`: Controls the log level for the application.
- **Example:** `RUST_LOG=info,purenotify_backend=debug` will set the default log level to `info` and the log level for this crate to `debug`.
- **Default:** Reads from the environment; if not set, logging may be minimal.
## API Endpoints
Here are the available API endpoints for the service.
### Health Check
- **Endpoint:** `/health`
- **Method:** `GET`
- **Description:** Used to verify that the service is running and healthy.
- **Success Response:**
- **Code:** `200 OK`
- **Content:** `{
"message": "health check successful",
"data": {},
"success": true,
"error": false
}`
#### Example Usage
You can use `curl` to check the health of the service:
```sh
curl http://127.0.0.1:3000/health
```
**Expected Output:**
```json
{
"message": "health check successful",
"data": {},
"success": true,
"error": false
}
```
## 📂 Project Structure
The project follows a standard Rust project layout. The main application logic is located in the `src/` directory.
```
src/
├── main.rs # Application entry point, server setup
├── config.rs # Configuration management
├── handlers/ # Business logic for handling requests
│ └── health/
│ └── health.rs
├── routes/ # API route definitions
│ └── health/
│ └── health.rs
└── utils/ # Utility functions and shared modules
```
- `main.rs`: Initializes the server, logging, configuration, and wires up the routes.
- `config.rs`: Defines the `Config` struct and handles loading configuration from the environment.
- `handlers/`: Contains the core logic for each API endpoint. Each handler is responsible for processing a request and returning a response.
- `routes/`: Defines the Axum `Router` for different parts of the application. These modules map URL paths to their corresponding handlers.
- `utils/`: A place for helper functions or modules that are used across different parts of the application.
## 🤝 Contributing
Contributions are welcome! If you'd like to contribute, please fork the repository and use a feature branch. Pull requests are warmly welcome.
1. Fork the repository.
2. Create your feature branch (`git checkout -b feature/fooBar`).
3. Commit your changes (`git commit -am 'Add some fooBar'`).
4. Push to the branch (`git push origin feature/fooBar`).
5. Create a new Pull Request.
## 📄 License
This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details.

View File

@@ -4,19 +4,19 @@ use std::env;
use std::net::SocketAddr;
use std::str::FromStr;
#[cfg(feature = "no-auth")]
use tracing::error;
#[derive(Debug)]
pub struct Config {
pub bind_address: SocketAddr,
pub mongodb_uri: String,
pub database_url: Option<String>,
}
impl Config {
pub fn from_env() -> Result<Self, String> {
let bind_address_str =
env::var("BIND_ADDRESS").unwrap_or_else(|_| "127.0.0.1:3000".to_string());
let bind_address = SocketAddr::from_str(&bind_address_str)
.map_err(|e| format!("Invalid BIND_ADDRESS: {}", e))?;
@@ -26,12 +26,11 @@ impl Config {
return Err("In no-auth mode, BIND_ADDRESS must be 127.0.0.1".to_string());
}
let mongodb_uri =
env::var("MONGODB_URI").unwrap_or_else(|_| "mongodb://localhost:27017".to_string());
let database_url = env::var("DATABASE_URL").ok();
Ok(Self {
bind_address,
mongodb_uri,
database_url,
})
}
}

View File

@@ -1,103 +0,0 @@
// src/db/migrations.rs
use mongodb::bson::doc;
use mongodb::{
Client, Database, IndexModel,
options::{ClientOptions, IndexOptions},
};
use tracing::info;
pub struct Migrator {
db: Database,
}
impl Migrator {
pub async fn new(uri: &str) -> Result<Self, mongodb::error::Error> {
let client_options = ClientOptions::parse(uri).await?;
let client = Client::with_options(client_options)?;
let db = client.database("employee_tracking");
Ok(Self { db })
}
pub async fn run(&self) -> Result<(), mongodb::error::Error> {
info!("Running migrations...");
// Users collection
self.db.create_collection("users").await?;
self.db
.collection::<bson::Document>("users")
.create_index(
IndexModel::builder()
.keys(doc! { "email": 1 })
.options(IndexOptions::builder().sparse(true).unique(true).build())
.build(),
)
.await?;
// Employees collection
self.db.create_collection("employees").await?;
self.db
.collection::<bson::Document>("employees")
.create_index(
IndexModel::builder()
.keys(doc! { "email": 1 })
.options(IndexOptions::builder().sparse(true).unique(true).build())
.build(),
)
.await?;
// Punches collection
self.db.create_collection("punches").await?;
self.db
.collection::<bson::Document>("punches")
.create_index(
IndexModel::builder()
.keys(doc! { "employee_id": 1, "clock_out_at": 1 })
.options(
IndexOptions::builder()
.partial_filter_expression(doc! { "clock_out_at": null })
.unique(true)
.build(),
)
.build(),
)
.await?;
// Shifts collection
self.db.create_collection("shifts").await?;
self.db
.collection::<bson::Document>("shifts")
.create_index(
IndexModel::builder()
.keys(doc! { "employee_id": 1, "start_at": 1, "end_at": 1 })
.build(),
)
.await?;
// Leave requests collection
self.db.create_collection("leave_requests").await?;
self.db
.collection::<bson::Document>("leave_requests")
.create_index(
IndexModel::builder()
.keys(doc! { "employee_id": 1, "start_date": 1, "end_date": 1, "status": 1 })
.build(),
)
.await?;
// Inventory items collection (optional)
self.db.create_collection("inventory_items").await?;
self.db
.collection::<bson::Document>("inventory_items")
.create_index(
IndexModel::builder()
.keys(doc! { "sku": 1 })
.options(IndexOptions::builder().unique(true).build())
.build(),
)
.await?;
info!("Migrations completed.");
Ok(())
}
}

View File

@@ -1,4 +0,0 @@
// src/db/mod.rs
pub mod migrations;
pub mod seed;

View File

@@ -1,138 +0,0 @@
// src/db/seed.rs
use chrono::Utc;
use mongodb::{
Client, Database,
bson::{DateTime, doc, oid::ObjectId},
};
use tracing::info;
pub struct Seeder {
db: Database,
}
impl Seeder {
pub async fn new(uri: &str) -> Result<Self, mongodb::error::Error> {
let client = Client::with_uri_str(uri).await?;
let db = client.database("employee_tracking");
Ok(Self { db })
}
pub async fn run(&self) -> Result<(), mongodb::error::Error> {
info!("Seeding database...");
// Clear collections
self.db
.collection::<bson::Document>("users")
.delete_many(doc! {})
.await?;
self.db
.collection::<bson::Document>("employees")
.delete_many(doc! {})
.await?;
self.db
.collection::<bson::Document>("punches")
.delete_many(doc! {})
.await?;
self.db
.collection::<bson::Document>("shifts")
.delete_many(doc! {})
.await?;
self.db
.collection::<bson::Document>("leave_requests")
.delete_many(doc! {})
.await?;
// Seed users
let manager_id = ObjectId::new();
self.db
.collection("users")
.insert_one(doc! {
"_id": manager_id.clone(),
"role": "manager",
"email": "manager@example.com"
})
.await?;
// Seed employees
let emp1_id = ObjectId::new();
let emp2_id = ObjectId::new();
self.db
.collection("employees")
.insert_many(vec![
doc! {
"_id": emp1_id.clone(),
"full_name": "John Doe",
"email": "john.doe@example.com",
"position": "Developer",
"active": true,
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
},
doc! {
"_id": emp2_id.clone(),
"full_name": "Jane Smith",
"email": "jane.smith@example.com",
"position": "Designer",
"active": true,
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
},
])
.await?;
// Seed punches
self.db
.collection("punches")
.insert_one(doc! {
"_id": ObjectId::new(),
"employee_id": emp1_id.clone(),
"clock_in_at": DateTime::from_millis(Utc::now().timestamp_millis()),
"clock_out_at": null,
"source": "web",
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
})
.await?;
// Seed shifts
self.db
.collection("shifts")
.insert_one(doc! {
"_id": ObjectId::new(),
"employee_id": emp1_id.clone(),
"start_at": DateTime::from_millis(Utc::now().timestamp_millis()),
"end_at": DateTime::from_millis(Utc::now().timestamp_millis() + 8 * 3600 * 1000),
"created_by": manager_id.clone(),
"notes": "Morning shift",
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
})
.await?;
// Seed leave requests
self.db.collection("leave_requests").insert_many(vec![
doc! {
"_id": ObjectId::new(),
"employee_id": emp1_id.clone(),
"start_date": DateTime::from_millis(Utc::now().timestamp_millis() + 2 * 24 * 3600 * 1000),
"end_date": DateTime::from_millis(Utc::now().timestamp_millis() + 4 * 24 * 3600 * 1000),
"status": "approved",
"reason": "Vacation",
"reviewed_by": manager_id.clone(),
"reviewed_at": DateTime::from_millis(Utc::now().timestamp_millis()),
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
},
doc! {
"_id": ObjectId::new(),
"employee_id": emp2_id.clone(),
"start_date": DateTime::from_millis(Utc::now().timestamp_millis() + 5 * 24 * 3600 * 1000),
"end_date": DateTime::from_millis(Utc::now().timestamp_millis() + 6 * 24 * 3600 * 1000),
"status": "pending",
"reason": "Medical",
"reviewed_by": null,
"reviewed_at": null,
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
},
]).await?;
info!("Seeding completed.");
Ok(())
}
}

View File

@@ -1,20 +0,0 @@
// src/handlers/health/health.rs
use axum::Json;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde_json::json;
pub async fn health() -> impl IntoResponse {
(
StatusCode::OK,
Json(json!(
{
"message": "health check successful",
"data": {},
"success": true,
"error": false,
}
)),
)
}

View File

@@ -1,3 +0,0 @@
// src/handlers/health/mod.rs
pub mod health;

View File

@@ -1,5 +0,0 @@
// src/handlers/mod.rs
pub mod health;
pub mod shift;
pub mod user;

View File

@@ -1,3 +0,0 @@
// src/handlers/shift/mod.rs
pub mod shift;

View File

@@ -1,80 +0,0 @@
// src/handlers/shift/shift.rs
use axum::{Extension, Json, http::StatusCode};
use chrono::Utc;
use mongodb::{
Database,
bson::{DateTime, doc},
};
use serde_json::json;
pub async fn create_shift(
Extension(db): Extension<Database>,
Json(payload): Json<serde_json::Value>,
) -> impl axum::response::IntoResponse {
let employee_id = payload
.get("employee_id")
.and_then(|v| v.as_str())
.unwrap_or("");
let start_at = payload
.get("start_at")
.and_then(|v| v.as_i64())
.unwrap_or(0);
let end_at = payload.get("end_at").and_then(|v| v.as_i64()).unwrap_or(0);
// Validate no overlapping shifts
let shifts = db.collection::<bson::Document>("shifts");
let overlap = shifts
.find_one(doc! {
"employee_id": employee_id,
"$or": [
{ "start_at": { "$lte": DateTime::from_millis(end_at) } },
{ "end_at": { "$gte": DateTime::from_millis(start_at) } },
]
})
.await
.unwrap();
if overlap.is_some() {
return (
StatusCode::BAD_REQUEST,
Json(json!({
"message": "Shift overlaps with existing shift",
"success": false,
"error": true
})),
);
}
// Insert shift
let result = shifts
.insert_one(doc! {
"_id": bson::oid::ObjectId::new(),
"employee_id": employee_id,
"start_at": DateTime::from_millis(start_at),
"end_at": DateTime::from_millis(end_at),
"created_by": null,
"notes": payload.get("notes").and_then(|v| v.as_str()).unwrap_or(""),
"created_at": DateTime::from_millis(Utc::now().timestamp_millis())
})
.await;
match result {
Ok(_) => (
StatusCode::CREATED,
Json(json!({
"message": "Shift created successfully",
"success": true,
"error": false
})),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"message": format!("Failed to create shift: {}", e),
"success": false,
"error": true
})),
),
}
}

View File

@@ -1,3 +0,0 @@
// src/handlers/user/mod.rs
pub mod user;

View File

@@ -1,48 +0,0 @@
// src/handlers/user/user.rs
use axum::Json;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde_json::json;
pub async fn user() -> impl IntoResponse {
(
StatusCode::OK,
Json(json!(
{
"message": "health check successful",
"data": {
"id": "usr_123456789",
"username": "john_doe",
"email": "john.doe@example.com",
"first_name": "John",
"last_name": "Doe",
"role": "user",
"is_active": true,
"created_at": "2024-01-15T10:30:00Z",
"updated_at": "2024-08-15T14:20:00Z",
"profile": {
"avatar_url": "https://api.example.com/avatars/john_doe.png",
"bio": "Software developer passionate about Rust",
"location": "San Francisco, CA",
"website": "https://johndoe.dev"
},
"preferences": {
"theme": "dark",
"language": "en",
"notifications_enabled": true,
"email_verified": true
},
"stats": {
"total_posts": 42,
"total_comments": 156,
"total_likes": 523,
"account_age_days": 213
}
},
"success": true,
"error": false,
}
)),
)
}

10
src/health.rs Normal file
View File

@@ -0,0 +1,10 @@
// src/health.rs
use axum::Json;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde_json::json;
pub async fn health() -> impl IntoResponse {
(StatusCode::OK, Json(json!({ "status": "ok" })))
}

View File

@@ -1,33 +1,19 @@
// src/main.rs
use std::net::SocketAddr;
use std::process::exit;
use axum::Router;
use clap::{Parser, Subcommand};
use axum::{Router, routing::get};
use dotenvy::dotenv;
use mongodb::Client;
use tokio::signal;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
mod config;
use crate::config::Config;
mod db; // Updated to import db module instead of migrations and seed
mod handlers;
mod routes;
mod health;
#[derive(Parser)]
struct Cli {
#[command(subcommand)]
command: Option<Commands>,
}
#[derive(Subcommand)]
enum Commands {
Migrate,
Seed,
}
use config::Config;
#[tokio::main]
async fn main() {
@@ -49,50 +35,6 @@ async fn main() {
}
};
// Initialize MongoDB client
let client = match Client::with_uri_str(&config.mongodb_uri).await {
Ok(client) => client,
Err(e) => {
error!("Failed to initialize MongoDB client: {}", e);
exit(1);
}
};
let db = client.database("employee_tracking");
// Handle CLI commands
let cli = Cli::parse();
match cli.command {
Some(Commands::Migrate) => {
let migrator = match db::migrations::Migrator::new(&config.mongodb_uri).await {
Ok(m) => m,
Err(e) => {
error!("Failed to initialize migrator: {}", e);
exit(1);
}
};
migrator.run().await.unwrap_or_else(|e| {
error!("Failed to run migrations: {}", e);
exit(1);
});
return;
}
Some(Commands::Seed) => {
let seeder = match db::seed::Seeder::new(&config.mongodb_uri).await {
Ok(s) => s,
Err(e) => {
error!("Failed to initialize seeder: {}", e);
exit(1);
}
};
seeder.run().await.unwrap_or_else(|e| {
error!("Failed to run seed: {}", e);
exit(1);
});
return;
}
None => {}
}
#[cfg(feature = "no-auth")]
info!("NO-AUTH MODE ENABLED");
@@ -100,10 +42,7 @@ async fn main() {
// Build the Axum router
let app = Router::new()
.nest("/health", routes::health::health::health_routes())
.nest("/user", routes::user::user::user_routes())
.nest("/shift", routes::shift::shift::shift_routes())
.layer(axum::Extension(db)) // Pass MongoDB database to handlers
.route("/health", get(health::health))
.layer(TraceLayer::new_for_http());
// Run the server

View File

@@ -1,9 +0,0 @@
// src/routes/health/health.rs
use axum::{Router, routing::get};
use crate::handlers::health::health::health;
pub fn health_routes() -> Router {
Router::new().route("/", get(health))
}

View File

@@ -1,4 +0,0 @@
// src/routes/health/mod.rs
pub mod health;

View File

@@ -1,5 +0,0 @@
// src/routes/mod.rs
pub mod health;
pub mod shift;
pub mod user;

View File

@@ -1,3 +0,0 @@
// src/routes/shift/mod.rs
pub mod shift;

View File

@@ -1,9 +0,0 @@
// src/routes/shift/shift.rs
use axum::{Router, routing::post};
use crate::handlers::shift::shift::create_shift;
pub fn shift_routes() -> Router {
Router::new().route("/", post(create_shift))
}

View File

@@ -1,3 +0,0 @@
// src/routes/user/mod.rs
pub mod user;

View File

@@ -1,9 +0,0 @@
// src/routes/user/user.rs
use axum::{Router, routing::get};
use crate::handlers::user::user::user;
pub fn user_routes() -> Router {
Router::new().route("/", get(user))
}

View File

@@ -1,203 +0,0 @@
# Crypto Utility Module (`crypto.rs`)
This document provides detailed documentation for the `CryptoUtils` module, a Rust implementation for essential cryptographic operations. The module offers functionalities for symmetric encryption/decryption and asymmetric key pair management.
## Table of Contents
1. [Overview](#overview)
2. [Dependencies](#dependencies)
3. [Error Handling](#error-handling)
4. [Core Structures](#core-structures)
- [`KeyPair`](#keypair)
5. [Static Properties](#static-properties)
- [`KEY`](#key)
- [`IV`](#iv)
6. [API Functions](#api-functions)
- [Symmetric Encryption](#symmetric-encryption)
- [`encrypt`](#encrypt)
- [`decrypt`](#decrypt)
- [Asymmetric Key Management](#asymmetric-key-management)
- [`generate_key_pair`](#generate_key_pair)
- [`save_keys_to_files`](#save_keys_to_files)
- [`load_keys_from_files`](#load_keys_from_files)
- [`init_keys`](#init_keys)
7. [Usage Examples](#usage-examples)
---
### Overview
The `CryptoUtils` module provides a set of static methods to perform common cryptographic tasks. It is designed to be a centralized utility for handling both symmetric (AES-256-CBC) and asymmetric (RSA-4096) cryptography.
### Dependencies
This module requires the following dependencies to be added to your `Cargo.toml`:
```toml
[dependencies]
openssl = "0.10"
sha2 = "0.10"
hex = "0.4"
once_cell = "1.19" # For lazy static initialization
```
### Error Handling
The module defines a custom `CryptoError` enum to handle various failure scenarios, providing clear and specific error information.
- `OpenSsl`: Wraps errors from the `openssl` crate.
- `Io`: For file system I/O errors (e.g., reading/writing keys).
- `Hex`: For errors during hex encoding/decoding.
- `Utf8`: For errors converting byte slices to UTF-8 strings.
- `Custom`: For other specific, custom error messages.
### Core Structures
#### `KeyPair`
A public struct that holds a pair of RSA keys.
- `private_key: String`: The PEM-encoded private key.
- `public_key: String`: The PEM-encoded public key.
### Static Properties
#### `KEY`
A statically initialized 32-byte array used as the secret key for AES-256-CBC encryption and decryption. It is derived by applying SHA-256 to a hardcoded salt phrase, ensuring a consistent key across the application.
#### `IV`
A 16-byte initialization vector used for the AES-256-CBC algorithm.
### API Functions
All functions are implemented as static methods on the `CryptoUtils` struct.
#### Symmetric Encryption
##### `encrypt`
`pub fn encrypt(secret: &str) -> Result<String, CryptoError>`
Encrypts a string slice using AES-256-CBC.
- **Parameters**:
- `secret`: The plaintext string to encrypt.
- **Returns**: A `Result` containing the hex-encoded ciphertext string or a `CryptoError`.
##### `decrypt`
`pub fn decrypt(encrypted_secret: &str) -> Result<String, CryptoError>`
Decrypts a hex-encoded ciphertext string using AES-256-CBC.
- **Parameters**:
- `encrypted_secret`: The hex-encoded ciphertext.
- **Returns**: A `Result` containing the decrypted plaintext string or a `CryptoError`.
#### Asymmetric Key Management
##### `generate_key_pair`
`pub fn generate_key_pair() -> Result<KeyPair, CryptoError>`
Generates a new 4096-bit RSA key pair.
- **Returns**: A `Result` containing a `KeyPair` struct with the new PEM-encoded keys or a `CryptoError`.
##### `save_keys_to_files`
`pub fn save_keys_to_files(keys: &KeyPair, directory: &Path) -> Result<(), CryptoError>`
Saves a `KeyPair` to the specified directory in two files: `private.pem` and `public.pem`.
- **Parameters**:
- `keys`: A reference to the `KeyPair` to save.
- `directory`: The path to the directory where the keys will be saved.
##### `load_keys_from_files`
`pub fn load_keys_from_files(directory: &Path) -> Result<KeyPair, CryptoError>`
Loads an RSA key pair from `private.pem` and `public.pem` files in a given directory.
- **Parameters**:
- `directory`: The path to the directory containing the key files.
- **Returns**: A `Result` containing the loaded `KeyPair` or a `CryptoError`.
##### `init_keys`
`pub fn init_keys() -> Result<KeyPair, CryptoError>`
A convenience function that initializes the RSA key pair for the application. It first checks if the keys exist in the default `./keys` directory.
- If the keys exist, it loads them.
- If they do not exist, it generates a new pair and saves them to the `./keys` directory.
- **Returns**: A `Result` containing the initialized `KeyPair` or a `CryptoError`.
### Usage Examples
#### Example 1: AES Encryption and Decryption
```rust
use your_project::utils::crypto::crypto::CryptoUtils;
fn main() {
let secret_message = "This is a highly confidential message.";
// Encrypt the message
match CryptoUtils::encrypt(secret_message) {
Ok(encrypted) => {
println!("Original: {}", secret_message);
println!("Encrypted: {}", encrypted);
// Decrypt the message
match CryptoUtils::decrypt(&encrypted) {
Ok(decrypted) => {
println!("Decrypted: {}", decrypted);
assert_eq!(secret_message, decrypted);
},
Err(e) => eprintln!("Decryption failed: {}", e),
}
},
Err(e) => eprintln!("Encryption failed: {}", e),
}
}
```
#### Example 2: RSA Key Pair Initialization and Management
This example demonstrates how to ensure RSA keys are available for the application.
```rust
use your_project::utils::crypto::crypto::CryptoUtils;
use std::fs;
use std::path::Path;
fn main() {
// Clean up previous keys for demonstration purposes
if Path::new("keys").exists() {
fs::remove_dir_all("keys").unwrap();
}
println!("Attempting to initialize keys...");
// Use init_keys to either generate or load keys
match CryptoUtils::init_keys() {
Ok(keys) => {
println!("Keys initialized successfully.");
println!("Public Key (first 50 chars): {}...", &keys.public_key[..50]);
// Calling it again should now load the existing keys
println!("\nCalling init_keys again...");
let loaded_keys = CryptoUtils::init_keys().unwrap();
assert_eq!(keys.public_key, loaded_keys.public_key);
println!("Keys loaded successfully from files.");
},
Err(e) => {
eprintln!("Failed to initialize keys: {}", e);
}
}
}
```

View File

@@ -1,167 +0,0 @@
use once_cell::sync::Lazy;
use openssl::error::ErrorStack;
use openssl::pkey::PKey;
use openssl::rsa::{Padding, Rsa};
use openssl::symm::{Cipher, Crypter, Mode};
use sha2::{Digest, Sha256};
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
// --- Error Handling ---
#[derive(Debug)]
pub enum CryptoError {
OpenSsl(ErrorStack),
Io(io::Error),
Hex(hex::FromHexError),
Utf8(std::string::FromUtf8Error),
Custom(String),
}
impl From<ErrorStack> for CryptoError {
fn from(err: ErrorStack) -> Self {
CryptoError::OpenSsl(err)
}
}
impl From<io::Error> for CryptoError {
fn from(err: io::Error) -> Self {
CryptoError::Io(err)
}
}
impl From<hex::FromHexError> for CryptoError {
fn from(err: hex::FromHexError) -> Self {
CryptoError::Hex(err)
}
}
impl From<std::string::FromUtf8Error> for CryptoError {
fn from(err: std::string::FromUtf8Error) -> Self {
CryptoError::Utf8(err)
}
}
impl std::fmt::Display for CryptoError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
CryptoError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e),
CryptoError::Io(e) => write!(f, "IO error: {}", e),
CryptoError::Hex(e) => write!(f, "Hex decoding error: {}", e),
CryptoError::Utf8(e) => write!(f, "UTF-8 conversion error: {}", e),
CryptoError::Custom(s) => write!(f, "Crypto error: {}", s),
}
}
}
impl std::error::Error for CryptoError {}
// --- KeyPair Structure ---
#[derive(Debug, Clone)]
pub struct KeyPair {
pub private_key: String,
pub public_key: String,
}
// --- Crypto Utility ---
pub struct CryptoUtils;
static KEY: Lazy<[u8; 32]> = Lazy::new(|| {
let mut hasher = Sha256::new();
hasher.update(b"z&R*3mN@wS5gY!8c*P#L5bQm&8wT3vNxE!UW4ex7HJKLfghRT");
hasher.finalize().into()
});
const IV: &[u8; 16] = b"6234567890123456";
impl CryptoUtils {
/// Encrypts a string using AES-256-CBC.
pub fn encrypt(secret: &str) -> Result<String, CryptoError> {
let cipher = Cipher::aes_256_cbc();
let data = secret.as_bytes();
let mut encrypter = Crypter::new(cipher, Mode::Encrypt, &KEY, Some(IV))?;
encrypter.pad(true);
let mut encrypted = vec![0; data.len() + cipher.block_size()];
let count = encrypter.update(data, &mut encrypted)?;
let rest = encrypter.finalize(&mut encrypted[count..])?;
encrypted.truncate(count + rest);
Ok(hex::encode(encrypted))
}
/// Decrypts a string using AES-256-CBC.
pub fn decrypt(encrypted_secret: &str) -> Result<String, CryptoError> {
let cipher = Cipher::aes_256_cbc();
let data = hex::decode(encrypted_secret)?;
let mut decrypter = Crypter::new(cipher, Mode::Decrypt, &KEY, Some(IV))?;
decrypter.pad(true);
let mut decrypted = vec![0; data.len() + cipher.block_size()];
let count = decrypter.update(&data, &mut decrypted)?;
let rest = decrypter.finalize(&mut decrypted[count..])?;
decrypted.truncate(count + rest);
Ok(String::from_utf8(decrypted)?)
}
/// Generates a 4096-bit RSA key pair in PEM format.
pub fn generate_key_pair() -> Result<KeyPair, CryptoError> {
let rsa = Rsa::generate(4096)?;
let pkey = PKey::from_rsa(rsa)?;
let private_key = pkey
.private_key_to_pem_pkcs8()?
.iter()
.map(|&c| c as char)
.collect();
let public_key = pkey
.public_key_to_pem()?
.iter()
.map(|&c| c as char)
.collect();
Ok(KeyPair {
private_key,
public_key,
})
}
/// Saves the given KeyPair to files in the specified directory.
pub fn save_keys_to_files(keys: &KeyPair, directory: &Path) -> Result<(), CryptoError> {
fs::create_dir_all(directory)?;
fs::write(directory.join("private.pem"), &keys.private_key)?;
fs::write(directory.join("public.pem"), &keys.public_key)?;
Ok(())
}
/// Loads a KeyPair from files in the specified directory.
pub fn load_keys_from_files(directory: &Path) -> Result<KeyPair, CryptoError> {
let private_key = fs::read_to_string(directory.join("private.pem"))?;
let public_key = fs::read_to_string(directory.join("public.pem"))?;
Ok(KeyPair {
private_key,
public_key,
})
}
/// Initializes RSA key pair.
/// If keys exist in the default 'keys' directory, they are loaded.
/// Otherwise, new keys are generated and saved.
pub fn init_keys() -> Result<KeyPair, CryptoError> {
let key_path = PathBuf::from("keys");
if key_path.join("private.pem").exists() && key_path.join("public.pem").exists() {
Self::load_keys_from_files(&key_path)
} else {
let keys = Self::generate_key_pair()?;
Self::save_keys_to_files(&keys, &key_path)?;
Ok(keys)
}
}
}

View File

@@ -1,186 +0,0 @@
# JWT Utility Module Documentation
## Overview
The `JWTUtils` module provides a comprehensive suite of tools for creating, signing, verifying, and managing JSON Web Tokens (JWTs) in Rust. It is designed to work seamlessly with the `CryptoUtils` module for RSA key management. This implementation handles JWTs manually using the `openssl` crate for cryptographic operations, avoiding external JWT-specific libraries.
The module supports standard claims like `exp` (expiration time), `iat` (issued at), and `iss` (issuer), and allows for custom payloads.
## Dependencies
To use this module, ensure the following dependencies are included in your `Cargo.toml` file:
```toml
[dependencies]
openssl = "0.10"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = "0.4"
base64 = "0.21"
```
## Core Components
### `JWTError` Enum
A custom error type that consolidates all potential failures within the module, including issues from `openssl`, `serde_json`, `base64`, and invalid token logic.
- `OpenSsl(openssl::error::ErrorStack)`: An error from the OpenSSL library.
- `SerdeJson(serde_json::Error)`: An error during JSON serialization or deserialization.
- `Base64(base64::DecodeError)`: An error during Base64 decoding.
- `Crypto(String)`: An error related to cryptographic key loading from `CryptoUtils`.
- `InvalidTokenFormat(String)`: The token string is malformed (e.g., wrong number of segments).
- `Validation(String)`: The token failed a validation check (e.g., signature invalid, expired).
- `Custom(String)`: A generic error for other specific issues.
### `JWTOptions` Struct
Defines the configurable options for creating a JWT.
- `algorithm: String`: The signing algorithm (defaults to `"RS256"`).
- `expires_in: i64`: The token's lifetime in seconds. Defaults to `3600` (1 hour) or the value of the `JWT_EXPIRES_IN` environment variable.
- `issuer: String`: The issuer of the token. Defaults to an empty string or the value of the `JWT_ISSUER` environment variable.
### `JWTUtils` Struct
The main struct for handling JWT operations. An instance of `JWTUtils` is typically created to generate a new token.
- `payload: Value`: The custom payload for the JWT, represented as a `serde_json::Value`.
- `private_key: String`: The PEM-encoded RSA private key for signing.
- `public_key: String`: The PEM-encoded RSA public key for verification.
- `options: JWTOptions`: The configuration options for the token.
## Instantiation
### `new(payload: Value, private_key: Option<String>, public_key: Option<String>, options: Option<JWTOptions>) -> Result<Self, JWTError>`
Creates a new `JWTUtils` instance.
- **payload**: The custom data to include in the token.
- **private_key / public_key**: Optional RSA keys. If not provided, the constructor will attempt to load them from the default `keys/` directory using `CryptoUtils::load_keys_from_files()`.
- **options**: Optional `JWTOptions`. If not provided, default values will be used.
## Instance Methods
### `create_token(&self) -> Result<String, JWTError>`
Generates and signs a complete JWT string. The process involves:
1. Creating the JWT header (`{"alg": "RS256", "typ": "JWT"}`).
2. Processing the payload by adding standard claims (`iat`, `exp`, `iss`).
3. Base64URL-encoding the header and payload.
4. Creating a signature by signing the encoded header and payload with the RSA private key.
5. Combining the three parts into the final `header.payload.signature` format.
### `verify(&self, token: &str) -> bool`
Verifies the signature of a given JWT using the public key stored in the `JWTUtils` instance.
- **Returns**: `true` if the signature is valid, `false` otherwise.
- **Note**: This method **only** checks the signature. It does not validate the expiration time or other claims. For comprehensive validation, use `decode_and_verify`.
## Static Methods
### `decode(token: &str) -> Result<(Value, Value), JWTError>`
Decodes a JWT string into its header and payload components without verifying the signature.
- **Returns**: A tuple `(header, payload)` where both elements are of type `serde_json::Value`.
- **Use Case**: Useful for inspecting token data when the signature's validity is not a concern (e.g., for logging or preliminary checks).
### `is_expired(token: &str) -> bool`
Checks if a token has expired. It decodes the token and compares the `exp` claim to the current UTC time.
- **Returns**: `true` if the token is expired, has no `exp` claim, or is malformed. `false` otherwise.
### `decode_and_verify(token: &str) -> Result<(Value, Value), JWTError>`
A comprehensive function that performs all necessary validations on a token:
1. Verifies the token's signature using the public key loaded from the default `keys/` directory.
2. Checks if the token has expired.
3. If both checks pass, it decodes and returns the header and payload.
- **Returns**: `Ok((header, payload))` on success, or a `JWTError` if validation fails for any reason.
### `refresh_token(old_token: &str) -> Result<String, JWTError>`
Generates a new token based on the payload of an old token. The `iat` and `exp` claims are stripped from the original payload and replaced with new ones.
- **Note**: This function loads the RSA keys from the `keys/` directory to sign the new token.
### `validate_claims(token: &str, required_claims: &[&str]) -> bool`
Checks for the presence of specific keys in the token's payload.
- **required_claims**: A slice of strings representing the keys that must be present.
- **Returns**: `true` if all required claims exist, `false` otherwise.
## Usage Examples
### Example 1: Creating a JWT
```rust
use serde_json::json;
use your_project::utils::jwt::JWTUtils;
fn create_new_token() {
// The custom data for the token
let payload = json!({
"user_id": "12345",
"roles": ["admin", "user"]
});
// Keys can be provided directly or loaded automatically from the 'keys/' directory
// If None, the constructor will try to load them from files.
let jwt_instance = JWTUtils::new(payload, None, None, None).unwrap();
match jwt_instance.create_token() {
Ok(token) => println!("Generated Token: {}", token),
Err(e) => eprintln!("Error creating token: {}", e),
}
}
```
### Example 2: Verifying and Decoding a JWT
```rust
use your_project::utils::jwt::JWTUtils;
fn validate_and_read_token(token: &str) {
match JWTUtils::decode_and_verify(token) {
Ok((header, payload)) => {
println!("Token is valid!");
println!("Header: {:?}", header);
println!("Payload: {:?}", payload);
},
Err(e) => {
eprintln!("Token validation failed: {}", e);
}
}
// You can also check for specific claims
if JWTUtils::validate_claims(token, &["user_id", "roles"]) {
println!("All required claims are present.");
}
}
```
### Example 3: Refreshing a Token
```rust
use your_project::utils::jwt::JWTUtils;
fn refresh_existing_token(old_token: &str) {
match JWTUtils::refresh_token(old_token) {
Ok(new_token) => {
println!("Token refreshed successfully!");
println!("New Token: {}", new_token);
},
Err(e) => {
eprintln!("Failed to refresh token: {}", e);
}
}
}
```

View File

@@ -1,320 +0,0 @@
// Add the following dependencies to your Cargo.toml file:
// openssl = "0.10"
// serde = { version = "1.0", features = ["derive"] }
// serde_json = "1.0"
// chrono = "0.4"
// base64 = "0.21"
use crate::utils::crypto::crypto::CryptoUtils;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{Duration, Utc};
use openssl::{
hash::MessageDigest,
pkey::{PKey, Private},
rsa::Rsa,
sign::{Signer, Verifier},
};
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_str, from_value, to_string};
use std::collections::BTreeMap;
use std::env;
// --- Error Handling ---
#[derive(Debug)]
pub enum JWTError {
OpenSsl(openssl::error::ErrorStack),
SerdeJson(serde_json::Error),
Base64(base64::DecodeError),
Crypto(String),
InvalidTokenFormat(String),
Validation(String),
Custom(String),
}
impl std::fmt::Display for JWTError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
JWTError::OpenSsl(e) => write!(f, "OpenSSL error: {}", e),
JWTError::SerdeJson(e) => write!(f, "JSON serialization error: {}", e),
JWTError::Base64(e) => write!(f, "Base64 decoding error: {}", e),
JWTError::Crypto(s) => write!(f, "Crypto error: {}", s),
JWTError::InvalidTokenFormat(s) => write!(f, "Invalid token format: {}", s),
JWTError::Validation(s) => write!(f, "Token validation failed: {}", s),
JWTError::Custom(s) => write!(f, "JWT error: {}", s),
}
}
}
impl std::error::Error for JWTError {}
impl From<openssl::error::ErrorStack> for JWTError {
fn from(err: openssl::error::ErrorStack) -> JWTError {
JWTError::OpenSsl(err)
}
}
impl From<serde_json::Error> for JWTError {
fn from(err: serde_json::Error) -> JWTError {
JWTError::SerdeJson(err)
}
}
impl From<base64::DecodeError> for JWTError {
fn from(err: base64::DecodeError) -> JWTError {
JWTError::Base64(err)
}
}
// --- Structures ---
#[derive(Debug, Serialize, Deserialize, Clone)]
struct JWTHeader {
alg: String,
typ: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct JWTOptions {
pub algorithm: String,
pub expires_in: i64, // seconds
pub issuer: String,
}
impl Default for JWTOptions {
fn default() -> Self {
let expires_in_str = env::var("JWT_EXPIRES_IN").unwrap_or_else(|_| "3600".to_string());
let expires_in = expires_in_str.parse::<i64>().unwrap_or(3600);
let issuer = env::var("JWT_ISSUER").unwrap_or_default();
JWTOptions {
algorithm: "RS256".to_string(),
expires_in,
issuer,
}
}
}
pub struct JWTUtils {
payload: Value,
private_key: String,
public_key: String,
options: JWTOptions,
}
// --- Helper Functions ---
fn base64url_encode<T: AsRef<[u8]>>(input: T) -> String {
URL_SAFE_NO_PAD.encode(input)
}
fn base64url_decode(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
URL_SAFE_NO_PAD.decode(input)
}
// --- Implementation ---
impl JWTUtils {
pub fn new(
payload: Value,
private_key: Option<String>,
public_key: Option<String>,
options: Option<JWTOptions>,
) -> Result<Self, JWTError> {
let opts = options.unwrap_or_default();
let (priv_key, pub_key) = match (private_key, public_key) {
(Some(priv_k), Some(pub_k)) => (priv_k, pub_k),
(priv_k, pub_k) => {
let keys = CryptoUtils::load_keys_from_files("keys")
.map_err(|e| JWTError::Crypto(e.to_string()))?;
(
priv_k.unwrap_or(keys.private_key),
pub_k.unwrap_or(keys.public_key),
)
}
};
Ok(JWTUtils {
payload,
private_key: priv_key,
public_key: pub_key,
options: opts,
})
}
/// Create JWT header
fn create_header(&self) -> JWTHeader {
JWTHeader {
alg: self.options.algorithm.clone(),
typ: "JWT".to_string(),
}
}
/// Process payload with standard claims
fn process_payload(&self) -> Result<String, JWTError> {
let mut payload_obj = self
.payload
.as_object()
.ok_or_else(|| JWTError::Custom("Payload must be a JSON object".to_string()))?
.clone();
let now = Utc::now();
let iat = now.timestamp();
let exp = (now + Duration::seconds(self.options.expires_in)).timestamp();
payload_obj.insert("iat".to_string(), iat.into());
payload_obj.insert("exp".to_string(), exp.into());
payload_obj.insert("iss".to_string(), self.options.issuer.clone().into());
Ok(to_string(&payload_obj)?)
}
/// Sign the JWT components
fn sign(&self, header_base64: &str, payload_base64: &str) -> Result<String, JWTError> {
let signature_input = format!("{}.{}", header_base64, payload_base64);
let keypair = Rsa::private_key_from_pem(self.private_key.as_bytes())?;
let pkey = PKey::from_rsa(keypair)?;
let mut signer = Signer::new(MessageDigest::sha256(), &pkey)?;
signer.update(signature_input.as_bytes())?;
let signature = signer.sign_to_vec()?;
Ok(base64url_encode(&signature))
}
/// Create complete JWT
pub fn create_token(&self) -> Result<String, JWTError> {
let header = self.create_header();
let processed_payload = self.process_payload()?;
let header_base64 = base64url_encode(to_string(&header)?);
let payload_base64 = base64url_encode(processed_payload);
let signature = self.sign(&header_base64, &payload_base64)?;
Ok(format!(
"{}.{}.{}",
header_base64, payload_base64, signature
))
}
/// Verify JWT token signature
pub fn verify(&self, token: &str) -> bool {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return false;
}
let signature_input = format!("{}.{}", parts[0], parts[1]);
let signature = match base64url_decode(parts[2]) {
Ok(s) => s,
Err(_) => return false,
};
let key = match PKey::public_key_from_pem(self.public_key.as_bytes()) {
Ok(k) => k,
Err(_) => return false,
};
let mut verifier = match Verifier::new(MessageDigest::sha256(), &key) {
Ok(v) => v,
Err(_) => return false,
};
verifier.update(signature_input.as_bytes()).unwrap();
verifier.verify(&signature).unwrap_or(false)
}
/// Decode JWT token without verification
pub fn decode(token: &str) -> Result<(Value, Value), JWTError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() < 2 {
return Err(JWTError::InvalidTokenFormat(
"Token must have at least two parts".to_string(),
));
}
let header_json = String::from_utf8(base64url_decode(parts[0])?)
.map_err(|_| JWTError::InvalidTokenFormat("Header is not valid UTF-8".to_string()))?;
let payload_json = String::from_utf8(base64url_decode(parts[1])?)
.map_err(|_| JWTError::InvalidTokenFormat("Payload is not valid UTF-8".to_string()))?;
let header: Value = from_str(&header_json)?;
let payload: Value = from_str(&payload_json)?;
Ok((header, payload))
}
/// Check if token is expired
pub fn is_expired(token: &str) -> bool {
match Self::decode(token) {
Ok((_, payload)) => {
if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
let now = Utc::now().timestamp();
exp < now
} else {
true // No expiration claim, consider it expired for safety
}
}
Err(_) => true, // Invalid token, consider it expired
}
}
/// A combined decode and verify function
pub fn decode_and_verify(token: &str) -> Result<(Value, Value), JWTError> {
let jwt = JWTUtils::new(Value::Null, None, None, None)?;
if !jwt.verify(token) {
return Err(JWTError::Validation(
"Signature verification failed".to_string(),
));
}
if Self::is_expired(token) {
return Err(JWTError::Validation("Token has expired".to_string()));
}
Self::decode(token)
}
/// Refresh a token
pub fn refresh_token(old_token: &str) -> Result<String, JWTError> {
let (_, payload_val) = Self::decode(old_token)?;
let mut payload_obj = payload_val
.as_object()
.ok_or_else(|| JWTError::Custom("Payload is not an object".to_string()))?
.clone();
payload_obj.remove("exp");
payload_obj.remove("iat");
// The keys must be loaded from files for this static method
let keys = CryptoUtils::load_keys_from_files("keys")
.map_err(|e| JWTError::Crypto(e.to_string()))?;
let jwt = JWTUtils::new(
Value::Object(payload_obj),
Some(keys.private_key),
Some(keys.public_key),
None,
)?;
jwt.create_token()
}
/// Validate that specific claims are present in the token
pub fn validate_claims(token: &str, required_claims: &[&str]) -> bool {
match Self::decode(token) {
Ok((_, payload)) => {
if let Some(payload_obj) = payload.as_object() {
required_claims
.iter()
.all(|claim| payload_obj.contains_key(*claim))
} else {
false
}
}
Err(_) => false,
}
}
}