Ferrite: A Deep Learning Library in Rust
A deep learning framework written in pure Rust, inspired by PyTorch. I built this project to learn Rust and sharpen my deep learning fundamentals.
Features
- Dynamic Computational Graph: Build and modify neural networks on the fly
- Automatic Differentiation: Automatic computation of gradients through the backward() method
- Efficient Tensor Operations: Fast operations with broadcasting support
- Memory Safety: Leveraging Rust's ownership model for safe and efficient memory management
- Rich Tensor API: Comprehensive set of tensor operations including:
  - Element-wise operations (add, multiply, divide)
  - Matrix operations (matmul)
  - Reduction operations (sum, mean, product)
  - Shape manipulation (reshape, transpose, squeeze/unsqueeze)
  - Broadcasting support
Quick Start
use ferrite::*;
use ndarray::array;

fn main() {
    // Create tensors with gradient tracking
    let x = Tensor::from_ndarray(&array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], Some(true));
    let y = Tensor::from_ndarray(&array![[1.0, 1.0, 1.0]], Some(true));

    // Perform operations (y is broadcast across the rows of x)
    let z = x.mul_tensor(&y);
    let mut f = z.sum();

    // Compute gradients
    f.backward();

    // Access gradients
    println!("grad x: {:?}", x.grad());
    println!("grad y: {:?}", y.grad());
}
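In this example f = sum(x * y) = 21, so after backward() the gradient with respect to x is y broadcast to x's shape, [[1, 1, 1], [1, 1, 1]], and the gradient with respect to y is x summed over the broadcast row dimension, [[5, 7, 9]].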
Architecture
ferrite is built with a modular architecture:
- TensorStorage: Core tensor storage and operations
- Tensor: High-level tensor interface with autograd support
- Module: Base trait for neural network modules
- Autograd: Automatic differentiation engine
- Parameter: Trainable parameters for neural networks
Implementation Details
- Uses Rc<RefCell<...>> for shared ownership and interior mutability (see the sketch after this list)
- Implements efficient broadcasting with stride-based computation
- Supports n-dimensional tensors with arbitrary shape
- Provides complete automatic differentiation for supported operations
- Uses traits for clean abstraction of tensor operations
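As a rough illustration of the Rc<RefCell<...>> pattern (the TensorHandle, Storage, and accumulate_grad names below are hypothetical stand-ins, not Ferrite's actual internals), a tensor handle that shares one mutable storage between clones might look like this:

use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical stand-in for a TensorStorage-like buffer holding data and an
// optional gradient of the same length.
struct Storage {
    data: Vec<f64>,
    grad: Option<Vec<f64>>,
}

// Cloning the handle is cheap: both copies point at the same Storage, and
// RefCell provides the interior mutability needed to write gradients later.
#[derive(Clone)]
struct TensorHandle {
    storage: Rc<RefCell<Storage>>,
}

impl TensorHandle {
    fn new(data: Vec<f64>) -> Self {
        TensorHandle {
            storage: Rc::new(RefCell::new(Storage { data, grad: None })),
        }
    }

    // Accumulate a gradient into the shared storage, creating it on first use.
    fn accumulate_grad(&self, g: &[f64]) {
        let mut s = self.storage.borrow_mut();
        let grad = s.grad.get_or_insert_with(|| vec![0.0; g.len()]);
        for (gi, gj) in grad.iter_mut().zip(g) {
            *gi += *gj;
        }
    }
}

fn main() {
    let a = TensorHandle::new(vec![1.0, 2.0, 3.0]);
    let b = a.clone(); // shares the same storage as `a`
    b.accumulate_grad(&[0.5, 0.5, 0.5]);
    a.accumulate_grad(&[0.5, 0.5, 0.5]);
    // Both handles observe the accumulated gradient: Some([1.0, 1.0, 1.0])
    println!("{:?}", a.storage.borrow().grad);
}

Because every clone points at the same Rc<RefCell<...>>, the backward pass can accumulate gradients through whichever handle the computational graph holds, without unsafe code.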
Usage Examples
Creating Tensors
// Create tensor filled with zeros
let zeros = Tensor::zeros(vec![2, 3], Some(true));
// Create tensor from ndarray
let data = Tensor::from_ndarray(&array![[1.0, 2.0], [3.0, 4.0]], Some(true));
Neural Network Module (Will update when finished)
// Sketch of the planned interface; exact types and traits may change
struct SimpleNetwork {
    module: Module,
}

impl SimpleNetwork {
    fn new() -> Self {
        let module = Module::new();
        SimpleNetwork { module }
    }
}

impl Segment for SimpleNetwork {
    fn forward(input: Tensor) -> Tensor {
        // Implement your network logic here; for now, pass the input through unchanged
        input
    }
}
Future Plans
- Finish building Neural Network interface
- Add more operations
- Add CUDA support
- Optimize performance for large tensors
- Add more loss functions
- Implement data loading utilities
- Add serialization support
Acknowledgments
- PyTorch (for inspiration)
- Claude (for teaching me Rust)