descent

Toy library for neural networks in Rust using Vulkan compute shaders.

Features

Example Network

The top-level API of neural network building blocks can be used to compactly describe multi-layer networks. Here is a small convolutional neural network with dropout and leaky ReLU activations, built using this API:

struct ConvNet {
    conv1: Conv2D,
    conv2: Conv2D,
    fc1: Dense,
    fc2: Dense,
}

impl ConvNet {
    fn new(env: &mut Environment) -> Self {
        // create and store parameters for layers that require them
        let c1 = 16;
        let c2 = 32;
        let hidden = 128;
        Self {
            conv1: Conv2D::builder(1, c1, 3, 3).with_pad(1).build(env),
            conv2: Conv2D::builder(c1, c2, 3, 3)
                .with_pad(1)
                .with_groups(2)
                .build(env),
            fc1: Dense::builder(7 * 7 * c2, hidden).build(env),
            fc2: Dense::builder(hidden, 10).build(env),
        }
    }
}

impl Module for ConvNet {
    fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> {
        // generates ops for the value (forwards) and gradient (backwards) through the layers
        input
            .apply(&self.conv1, ctx)
            .leaky_relu(0.01)
            .max_pool2d((2, 2), (2, 2))
            .apply(&self.conv2, ctx)
            .leaky_relu(0.01)
            .max_pool2d((2, 2), (2, 2))
            .flatten()
            .apply(&Dropout::new(0.5), ctx)
            .apply(&self.fc1, ctx)
            .leaky_relu(0.01)
            .apply(&self.fc2, ctx)
    }
}
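
The comment in eval above refers to each DualArray pairing a value with its gradient. As a rough, self-contained illustration of that value-plus-gradient pairing, here is a forward-mode dual-number sketch in plain Rust; the Dual type below is hypothetical and purely illustrative, not descent's DualArray, which instead records forward (value) and backward (gradient) ops in a computation graph.

// A minimal forward-mode "dual number": each scalar carries its derivative.
// This is only a conceptual sketch, not how descent represents arrays.
#[derive(Clone, Copy, Debug)]
struct Dual {
    value: f32,
    grad: f32,
}

impl Dual {
    // Leaky ReLU applied to the value, with the matching rule for the derivative.
    fn leaky_relu(self, slope: f32) -> Dual {
        if self.value > 0.0 {
            self
        } else {
            Dual {
                value: slope * self.value,
                grad: slope * self.grad,
            }
        }
    }
}

fn main() {
    // Derivative of leaky_relu at x = -2.0 with slope 0.01.
    let x = Dual { value: -2.0, grad: 1.0 };
    let y = x.leaky_relu(0.01);
    println!("value = {}, dvalue/dx = {}", y.value, y.grad); // -0.02, 0.01
}

Training uses reverse mode instead, propagating gradients from the loss back through the recorded ops, but the basic idea of tracking a value together with its gradient is the same.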

See the fashion_mnist example for more networks using this API.

Examples

Follow the link in each example's name for a more detailed description of that example.

array_api: Demonstrates the low-level Array API for building computation graphs. See the README for more details.
fashion_mnist: Trains a few different network types on the Fashion-MNIST dataset. Demonstrates the use of anti-aliasing during max pooling for improved accuracy (a rough sketch of this idea follows the list). See the README for a comparison of network performance.
image_fit: Overfits a few different network types to a single RGB image. Compares ReLU with positional encoding to a SIREN network. Update: now also compares to a multi-level hash encoding.
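
The anti-aliased max pooling mentioned for fashion_mnist follows the blur-pool idea: take the window maximum densely (stride 1), then low-pass filter before downsampling, which makes the result less sensitive to small shifts of the input. Below is a rough, self-contained 1-D sketch of that idea in plain Rust; the function names are hypothetical and this is not descent's actual implementation (edge handling and padding are ignored here).

// Plain 1-D max pooling: window of 2, stride 2.
fn max_pool_1d(input: &[f32]) -> Vec<f32> {
    input.chunks_exact(2).map(|w| w[0].max(w[1])).collect()
}

// Anti-aliased variant: dense (stride-1) max, then a [0.25, 0.5, 0.25] blur
// before keeping every other sample.
fn max_pool_1d_aa(input: &[f32]) -> Vec<f32> {
    let dense_max: Vec<f32> = input.windows(2).map(|w| w[0].max(w[1])).collect();
    dense_max
        .windows(3)
        .map(|w| 0.25 * w[0] + 0.5 * w[1] + 0.25 * w[2])
        .step_by(2)
        .collect()
}

fn main() {
    // An impulse, and the same impulse shifted by one sample.
    let x = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let x_shifted = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    // Plain pooling moves the impulse to a different output position,
    // while the anti-aliased output changes much less under the shift.
    println!("{:?} vs {:?}", max_pool_1d(&x), max_pool_1d(&x_shifted));
    println!("{:?} vs {:?}", max_pool_1d_aa(&x), max_pool_1d_aa(&x_shifted));
}

In the 2-D case used by the example, the same low-pass filtering would be applied along both spatial axes before the strided subsample.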

Dependencies

The following crates have been very useful in developing this project:

Potential Future Work