5 releases
0.1.5 | Mar 19, 2023 |
---|---|
0.1.4 | Mar 17, 2023 |
0.1.3 | Mar 2, 2023 |
0.1.2 | Nov 26, 2022 |
0.1.0 | Nov 20, 2022 |
#794 in Math
42 downloads per month
91KB
2.5K
SLoC
microtensor
Automatic differentiation for tensor operations.
Requires Rust nightly.
Features
-
Safe auto-grad — Non-differentiable operations return a separate type that cannot be back-propagated, revealing gaps in your computation graph at compile time.
-
Broadcasting — Tensors with differing but compatible shapes get broadcasted to matching dimensions automatically for most operations.
-
Arbitrary inner types — Tensors can store almost any data type and compute gradients for any inner type that satisfies scalar::Real.
-
Zero-copy views — Tensors may be sliced, indexed, reshaped, transposed and broadcasted without actually copying any data in most situations.
-
Graph recycling — Computation graphs, created by tracing an eager computation, can be reevaluated at a later time with new input data. They can also be serialized and loaded elsewhere, without access to the original code.
Examples
Evaluating and minimizing a non-linear function:
use microtensor::{prelude::*, Tensor};
// Create variables from tensors
let w = Tensor::randn(&[2, 16]).trained();
let b = Tensor::zeros(&[16]).trained();
for _ in 0..100 {
// Do some computation
let x = Tensor::vec(&[1.0, 2.0]).tracked();
let loss = ((x.mm(&w) + &b).sigmoid() - 0.5).sqr().mean(0);
// Compute gradients
loss.backward();
// Nudge w and b in order to minimize loss
for mut param in loss.parameters() {
param -= param.grad().unwrap() * 0.01;
}
// Reset gradients
loss.reset();
}
Automatic broadcasting:
use microtensor::{prelude::*, Tensor};
let a = Tensor::arrange(&[2, 16], 0., 1.);
let b = Tensor::ones(&[2]);
let c = &a - b.unsqueeze(-1) + 1.;
assert_eq!(a, c);
Generic return types:
use microtensor::{prelude::*, Tensor};
let t = Tensor::<f32>::randn(&[16]);
let _a: u8 = t.argmax(0).item();
let _b: u16 = t.argmax(0).item(); // argmax will produce a Tensor<u16> here
Optional features
Some features can be toggled in your Cargo.toml
.
unsafe
(default) — Accelerated matrix math using [matrixmultiply] crate.threading
(default) — Thread safety & multi-threaded operation over batch dimensions.
More examples
Check the /examples
folder for more example code.
License
MIT
Dependencies
~2.3–9.5MB
~77K SLoC