#machine-learning #neural-network #machine #learning #network

tsuga

An early-stage machine learning library for building neural networks in Rust

2 releases

0.1.1 Sep 20, 2020
0.1.0 Aug 24, 2020


MIT license

38KB
566 lines


Tsuga

An early-stage machine learning library in Rust

Tsuga is an early-stage machine learning library in Rust for building neural networks. It uses ndarray as the linear algebra backend and operates primarily on two-dimensional f32 arrays (Array2<f32> types). At the moment, its primary function has been as a testbed for API ideas and as an educational exercise, and it probably isn't yet suitable for serious use. Most of the project's focus so far has been on the image-processing domain, although the tools and layout should be generally applicable to higher/lower-dimensional datasets as well.
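To make the "two-dimensional f32 arrays" convention concrete, here is a minimal sketch of what a fully-connected sigmoid layer computes on a batch, with one row per sample and one column per feature. It is not tsuga's implementation: it uses nested Vec<f32> instead of ndarray's Array2<f32> so that it has no dependencies, and the names sigmoid and dense_sigmoid are hypothetical.

```rust
// Illustrative only: plain Vec<Vec<f32>> stands in for ndarray's Array2<f32>.

// Logistic sigmoid, applied element-wise by the layer below.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Forward pass of one fully-connected layer:
//   output[n][j] = sigmoid(sum_k input[n][k] * weights[k][j] + bias[j])
// `input` has shape [samples, in_features],
// `weights` has shape [in_features, out_features],
// `bias` has length out_features.
fn dense_sigmoid(input: &[Vec<f32>], weights: &[Vec<f32>], bias: &[f32]) -> Vec<Vec<f32>> {
    input
        .iter()
        .map(|row| {
            bias.iter()
                .enumerate()
                .map(|(j, &b)| {
                    // Dot product of the sample with column j of the weights.
                    let z = row
                        .iter()
                        .zip(weights)
                        .map(|(&x, w_row)| x * w_row[j])
                        .sum::<f32>()
                        + b;
                    sigmoid(z)
                })
                .collect()
        })
        .collect()
}
```

Calling dense_sigmoid with a [3, 2] input and [2, 4] weights yields a [3, 4] output, which is the same shape rule that a [60_000, 784] MNIST input would follow through a 128-unit layer.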

To use tsuga as a library, add the following to your Cargo.toml file:

[dependencies]
tsuga = "0.1"
ndarray = "0.13"

For development, I recommend cloning only the most recent version. Training data was included in past commits, so the full history is unnecessarily large. A shallow clone avoids this:

$ git clone --depth=1 https://github.com/quietlychris/tsuga.git

Fully-Connected Network Example for MNIST

Tsuga currently uses the Builder pattern for constructing fully-connected networks. Since networks are complex compound structures, this pattern helps to make the layout of the network explicit and modular.
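As a rough illustration of the Builder pattern in general, the sketch below chains setters that each consume and return the builder, then finishes with build(). All of the types, the default values, and the add_layer method are hypothetical and chosen for this example; they are not tsuga's definitions.

```rust
// Hypothetical types for illustration only; these are not tsuga's definitions.
#[derive(Debug, Clone, PartialEq)]
struct LayerSpec {
    activation: String,
    units: usize,
}

#[derive(Debug, Clone, PartialEq)]
struct NetworkConfig {
    layers: Vec<LayerSpec>,
    iterations: usize,
    learnrate: f32,
    batch_size: usize,
}

struct NetworkBuilder {
    config: NetworkConfig,
}

impl NetworkBuilder {
    // Start from (assumed) defaults so callers only override what they need.
    fn new() -> Self {
        NetworkBuilder {
            config: NetworkConfig {
                layers: Vec::new(),
                iterations: 100,
                learnrate: 0.1,
                batch_size: 32,
            },
        }
    }

    // Each setter takes `self` by value and returns it, which enables chaining.
    fn add_layer(mut self, activation: &str, units: usize) -> Self {
        self.config.layers.push(LayerSpec {
            activation: activation.to_string(),
            units,
        });
        self
    }

    fn iterations(mut self, n: usize) -> Self {
        self.config.iterations = n;
        self
    }

    fn learnrate(mut self, lr: f32) -> Self {
        self.config.learnrate = lr;
        self
    }

    fn batch_size(mut self, n: usize) -> Self {
        self.config.batch_size = n;
        self
    }

    // Consume the builder and hand back the finished configuration.
    fn build(self) -> NetworkConfig {
        self.config
    }
}
```

With these definitions, NetworkBuilder::new().add_layer("sigmoid", 128).iterations(1000).build() produces a config with one layer and 1000 iterations, leaving the other fields at their defaults. Because every option is named at the call site, the layout of the network stays readable even as more settings are added.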

The following is a reduced-code example of building a network to train on and evaluate the MNIST (or Fashion MNIST) data set. Including the time taken to unpack the MNIST binary files, this network achieves:

  • An accuracy of ~91.5% over 1000 iterations in 3.65 seconds
  • An accuracy of ~97.1% over 10,000 iterations in 29.43 seconds

This example can be run using $ cargo run --release --example mnist

use ndarray::prelude::*;
use tsuga::prelude::*;

fn main() {
    // Builds the MNIST data from a binary into ndarray Array2<f32> structures
    // Labels are built with one-hot encoding format
    // ([60_000, 784], [60_000, 10], [10_000, 784], [10_000, 10] )
    let (input, output, test_input, test_output) = mnist_as_ndarray();
    println!("Successfully unpacked the MNIST dataset into Array2<f32> format!");

    let mut layers_cfg: Vec<FCLayer> = Vec::new();
    let sigmoid_layer_0 = FCLayer::new("sigmoid", 128);
    layers_cfg.push(sigmoid_layer_0);
    let sigmoid_layer_1 = FCLayer::new("sigmoid", 64);
    layers_cfg.push(sigmoid_layer_1);

    let mut fcn = FullyConnectedNetwork::default(input, output)
        .add_layers(layers_cfg)
        .iterations(1000)
        .learnrate(0.01)
        .batch_size(200)
        .build();

    fcn.train();

    println!("Test input shape = {:?}", test_input.shape());
    println!("Test output shape = {:?}", test_output.shape());

    let test_result = fcn.evaluate(test_input);
    compare_results(test_result, test_output);
}
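The example above feeds one-hot labels to the network and passes the test output to compare_results, whose body is not shown here. The sketch below is one plausible way to build one-hot rows and score predictions by comparing the argmax of each row. The function names one_hot, argmax, and accuracy are hypothetical, and nested Vec<f32> is used in place of Array2<f32> to keep the code dependency-free.

```rust
// Illustrative helpers; not taken from tsuga.

// Encode a class label as a one-hot row: 1.0 at `label`, 0.0 elsewhere.
// Panics if `label >= num_classes`.
fn one_hot(label: usize, num_classes: usize) -> Vec<f32> {
    let mut row = vec![0.0; num_classes];
    row[label] = 1.0;
    row
}

// Index of the largest value in a row (the first one wins on ties).
fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

// Fraction of rows whose predicted class matches the one-hot target.
fn accuracy(predictions: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    assert_eq!(predictions.len(), targets.len());
    if predictions.is_empty() {
        return 0.0;
    }
    let correct = predictions
        .iter()
        .zip(targets)
        .filter(|(p, t)| argmax(p) == argmax(t))
        .count();
    correct as f32 / predictions.len() as f32
}
```

For MNIST, each target row would have 10 entries, so a label of 3 becomes a row with 1.0 in position 3, matching the [60_000, 10] and [10_000, 10] label shapes noted in the example.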

Dependencies

Tsuga uses minifb to display sample images during development, which means you may need to install some system packages first. On Debian or Ubuntu, this can be done via

$ sudo apt install libxkbcommon-dev libwayland-cursor0 libwayland-dev
