autograph

A machine learning library for Rust.

To use autograph in your crate, add it as a dependency in Cargo.toml:

[dependencies]
autograph = "0.1.1"

Requirements

Tests

  • To check that you have a valid device, run cargo test device_new --features device_tests.
  • Run all the tests with cargo test --features "full device_tests".

Custom Shader Code

You can write your own shaders and execute them with autograph.

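// shader/src/lib.rs

// `UVec3` comes from glam, which spirv-std re-exports. How the `#[spirv(..)]` attribute is
// enabled depends on the rust-gpu version in use; see the Hello Compute example for the setup.
use spirv_std::glam::UVec3;
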
// Declare the push constants. Use `#[repr(C)]` to ensure that fields
// are not reordered.
#[repr(C)]
pub struct PushConsts {
    n: u32,
}

/// Computes `y = a + b`.
///
/// `threads` can be up to 3 dimensions (x, y, z). This is the size of the `WorkGroup`. Generally
/// this should be a multiple of the hardware-specific size: NVIDIA calls this the `warp size`,
/// which is 32 on NVIDIA GPUs, and AMD's equivalent (the wavefront) is generally 64. 64 is a good
/// default. autograph automatically chooses the number of work groups to execute from the global
/// size, so the function submitting the shader does not need to know the work group size.
///
/// # Note
/// autograph does check the size of the push constants, as well as the mutability of buffers. It
/// DOES NOT check their types. For example, a buffer can be declared like `&[u32]` but bound to a
/// `Slice<u8>`.
#[allow(unused)]
#[spirv(compute(threads(64)))]
pub fn add(
    // This is the unique id of the invocation, and is 3D (x, y, z) even though we are just using x.
    // This tells the invocation what index to compute.
    #[spirv(global_invocation_id)] global_id: UVec3,
    // Buffer `a`. As of now, `storage_buffer`, `descriptor_set`, and `binding` must all be
    // specified.
    // Because this is not modified, it can be bound to a `Slice`.
    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[u32],
    // Buffer `b`.
    // Because this is not modified, it can be bound to a `Slice`.
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] b: &[u32],
    // Buffer `y`, the output.
    // This can only be bound to a `SliceMut`.
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] y: &mut [u32],
    // Push constants, i.e. additional arguments passed at runtime.
    #[spirv(push_constant)] push_consts: &PushConsts,
) {
    let gid = global_id.x as usize;
    // Only process up to n, which is the length of the buffers.
    if global_id.x < push_consts.n {
        // The indexing operation is implemented by rust-gpu and is the only way to access
        // the data; using `[T]::get()` or dereferencing &T or *const T will fail to compile.
        y[gid] = a[gid] + b[gid];
    }
}

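// main.rs

// `Result` and `anyhow!` come from the anyhow crate. `Device`, `Buffer` and `Slice` come from
// autograph (their module paths vary between versions), and `module()` stands for a helper,
// defined in the Hello Compute example, that returns the compiled shader module.
use anyhow::{anyhow, Result};

/// Adds `a` to `b`.
fn add(a: Slice<u32>, b: Slice<u32>) -> Result<Buffer<u32>> {
    if a.len() != b.len() {
        return Err(anyhow!("{} != {}", a.len(), b.len()));
    }
    // `Buffer::alloc` would skip initialization, but it is unsafe.
    // `zeros()` runs a shader to fill the buffer, which is wasted work when every element is
    // about to be overwritten, as it is here.
    let mut y = Buffer::zeros(a.device(), a.len())?;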
    // The shader executes in "WorkGroups", which in the shader is defined to be [64, 1, 1]. This
    // means that even though we have just 1 item to process, it will actually run 64 invocations
    // aka threads. We have to pass `n` to prevent the extra invocations from writing outside of
    // the buffer.
    let n = y.len() as u32;
    let builder = module()?
        .compute_pass("add")?
        // `storage_buffer` at binding 0, must not be modified in the shader.
        .slice(a)?
        // `storage_buffer` at binding 1, must not be modified in the shader.
        .slice(b)?
        // `storage_buffer` at binding 2
        .slice_mut(y.as_slice_mut())?
        .push(n)?;
    unsafe {
        // Enqueues the shader with global size [n, 1, 1].
        // This method validates the arguments and compiles the module for the device on first
        // use. Apart from that, it doesn't block: the internal device thread submits the work to
        // the device driver when it is ready.
        builder.submit([n, 1, 1])?;
    }
    Ok(y)
}

#[tokio::main]
async fn main() -> Result<()> {
    let device = Device::new()?;
    let x_in = [2];
    // Here we create a Slice<u32> from a &[u32].
    // We could also create a Buffer from a Vec, without copying.
    let x = Slice::from(x_in.as_ref())
        // Host -> Device transfers are non-blocking: the copy is enqueued on the device thread,
        // so this await resolves immediately.
        .into_device(device.clone())
        .await?;
    // Get the result of the addition.
    let y = add(x.as_slice(), x.as_slice())?;
    // Print out the result!
    println!("{:?} + {:?} = {:?}", x_in, x_in, y.read().await?.as_slice());
    Ok(())
}

See the Hello Compute example.
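The comments above note that a global size of 1 still launches a whole 64-thread work group, which is why the shader needs `n`. The sketch below emulates that dispatch on the CPU, assuming a 1-D global size and a work group count chosen by ceiling division. `work_groups` and `emulate_add` are hypothetical helpers for illustration, not autograph APIs.

```rust
/// Number of work groups needed to cover `global` invocations with work groups
/// of `local` threads each (ceiling division; assumed, not taken from autograph).
fn work_groups(global: u32, local: u32) -> u32 {
    (global + local - 1) / local
}

/// CPU emulation of the `add` shader: every invocation of every work group runs,
/// but only invocations whose id is below `n` write to `y`.
fn emulate_add(a: &[u32], b: &[u32], local: u32) -> Vec<u32> {
    let n = a.len() as u32;
    let mut y = vec![0u32; a.len()];
    for group in 0..work_groups(n, local) {
        for thread in 0..local {
            // Plays the role of `global_invocation_id.x`.
            let gid = group * local + thread;
            // The same bounds check as the shader; without it, ids >= n would index
            // past the end of the buffers.
            if gid < n {
                let i = gid as usize;
                y[i] = a[i] + b[i];
            }
        }
    }
    y
}
```

For example, `emulate_add(&[2], &[2], 64)` runs 64 emulated invocations, of which only the first writes, giving `[4]`.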

Machine Learning

KMeans

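// This runs inside an async fn returning a Result; `s!` is ndarray's slicing macro.
use ndarray::s;
use std::iter::once;
// Create the device.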
let device = Device::new()?;
// Create the dataset.
let iris = Iris::new();
// The flower dimensions are the inputs to the model.
let x_array = iris.dimensions();
// Select only petal length and petal width (columns 2 and 3).
// These are the primary dimensions, and using only two of them makes plotting easier.
let x_array = x_array.slice(&s![.., 2..]);
// Create the KMeans model.
let kmeans = KMeans::new(iris.class_names().len())
    .into_device(device.clone())
    .await?;
// For small datasets, we can load the entire dataset into the device.
// For larger datasets, the data can be streamed as an iterator.
let x = CowTensor::from(x_array.view())
    .into_device(device)
    // Note that despite the await this will resolve immediately.
    // Host -> Device transfers are batched with other operations
    // asynchronously on the device thread.
    .await?;
// Construct a trainer.
let mut trainer = KMeansTrainer::from(kmeans);
// Initialize the model (KMeans++).
// Here we provide an iterator of n iterators, such that the trainer can
// visit the data n times. In this case, once for each centroid.
trainer.init(|n| std::iter::from_fn(|| Some(once(Ok(x.view().into())))).take(n))?;
// Train the model (1 epoch).
trainer.train(once(Ok(x.view().into())))?;
// Get the model back.
let kmeans = KMeans::from(trainer);
// Get the trained centroids.
// For multiple reads, batch them by getting the futures first.
let centroids_fut = kmeans.centroids()
    // The centroids are in a FloatArcTensor, which can either be f32 or bf16.
    // This will convert to f32 if necessary.
    .cast_to::<f32>()?
    .read();
// Get the predicted classes.
let pred = kmeans.predict(&x.view().into())?
    .into_dimensionality()?
    .read()
    // Here we wait on all previous operations, including centroids_fut.
    .await?;
// This will resolve immediately.
let centroids = centroids_fut.await?;
// Get the flower classes from the dataset.
let classes = iris.classes().map(|c| *c as u32);
// Plot the results to "plot.png".
// Note that since KMeans is an unsupervised method the predicted classes will be arbitrary and
// not align to the order of the true classes (ie the colors won't be the same in the plot).
plot(&x_array.view(), &classes.view(), &pred.as_array(), &centroids.as_array())?;

See the KMeans Iris example.
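For intuition about what `trainer.init` and `trainer.train` compute, here is a minimal CPU sketch of k-means++ seeding followed by one Lloyd iteration (assign each point to its nearest centroid, then move each centroid to the mean of its points). It is plain Rust under stated assumptions, not autograph's GPU implementation: points are 2-D `[f32; 2]` like the petal columns above, and the hypothetical `Lcg` type is a tiny deterministic generator standing in for a real random number source.

```rust
/// Hypothetical minimal pseudo-random generator (64-bit LCG), a stand-in for a real RNG.
struct Lcg(u64);

impl Lcg {
    /// Returns a float in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Squared Euclidean distance between two 2-D points.
fn dist2(a: [f32; 2], b: [f32; 2]) -> f32 {
    let (dx, dy) = (a[0] - b[0], a[1] - b[1]);
    dx * dx + dy * dy
}

/// Index of the centroid nearest to `p`.
fn nearest(p: [f32; 2], centroids: &[[f32; 2]]) -> usize {
    let mut best = 0;
    for (i, &c) in centroids.iter().enumerate() {
        if dist2(p, c) < dist2(p, centroids[best]) {
            best = i;
        }
    }
    best
}

/// k-means++ seeding (expects a non-empty `x`): the first centroid is drawn uniformly,
/// each further one with probability proportional to its squared distance to the
/// nearest centroid chosen so far.
fn kmeans_pp(x: &[[f32; 2]], k: usize, rng: &mut Lcg) -> Vec<[f32; 2]> {
    let first = ((rng.next_f32() * x.len() as f32) as usize).min(x.len() - 1);
    let mut centroids = vec![x[first]];
    while centroids.len() < k {
        let d: Vec<f32> = x
            .iter()
            .map(|&p| dist2(p, centroids[nearest(p, &centroids)]))
            .collect();
        let total: f32 = d.iter().sum();
        // Walk the cumulative distribution until the random target falls in a point's share.
        let mut target = rng.next_f32() * total;
        let mut chosen = x.len() - 1;
        for (i, &di) in d.iter().enumerate() {
            if target < di {
                chosen = i;
                break;
            }
            target -= di;
        }
        centroids.push(x[chosen]);
    }
    centroids
}

/// One Lloyd iteration: label each point with its nearest centroid, then move each
/// centroid to the mean of its points (an empty cluster keeps its old centroid).
fn lloyd_step(x: &[[f32; 2]], centroids: &mut [[f32; 2]]) -> Vec<usize> {
    let labels: Vec<usize> = x.iter().map(|&p| nearest(p, &*centroids)).collect();
    let mut sums = vec![[0.0f32; 2]; centroids.len()];
    let mut counts = vec![0usize; centroids.len()];
    for (&p, &l) in x.iter().zip(&labels) {
        sums[l][0] += p[0];
        sums[l][1] += p[1];
        counts[l] += 1;
    }
    for ((c, s), &n) in centroids.iter_mut().zip(&sums).zip(&counts) {
        if n > 0 {
            *c = [s[0] / n as f32, s[1] / n as f32];
        }
    }
    labels
}
```

Seeding once with `kmeans_pp` and then repeating `lloyd_step` until the labels stop changing gives the classic batch k-means loop.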

Neural Networks

#[derive(Layer, Forward, Clone, Debug)]
struct Lenet5 {
    #[autograph(layer)]
    conv1: Conv,
    #[autograph(layer)]
    relu1: Relu,
    #[autograph(layer)]
    pool1: MaxPool,
    #[autograph(layer)]
    conv2: Conv,
    #[autograph(layer)]
    relu2: Relu,
    #[autograph(layer)]
    pool2: MaxPool,
    #[autograph(layer)]
    dense1: Dense,
    #[autograph(layer)]
    relu3: Relu,
    #[autograph(layer)]
    dense2: Dense,
    #[autograph(layer)]
    relu4: Relu,
    #[autograph(layer)]
    dense3: Dense,
}

impl Lenet5 {
    fn new() -> Result<Self> {
        let conv1 = Conv::from_inputs_outputs_kernel(1, 6, [5, 5]);
        let relu1 = Relu::default();
        let pool1 = MaxPool::from_kernel([2, 2])
            .with_strides(2)?;
        let conv2 = Conv::from_inputs_outputs_kernel(6, 16, [5, 5]);
        let relu2 = Relu::default();
        let pool2 = MaxPool::from_kernel([2, 2])
            .with_strides(2)?;
        let dense1 = Dense::from_inputs_outputs(256, 120);
        let relu3 = Relu::default();
        let dense2 = Dense::from_inputs_outputs(120, 84);
        let relu4 = Relu::default();
        let dense3 = Dense::from_inputs_outputs(84, 10)
            .with_bias(true)?;
        Ok(Self {
            conv1,
            relu1,
            pool1,
            conv2,
            relu2,
            pool2,
            dense1,
            relu3,
            dense2,
            relu4,
            dense3,
        })
    }
}

See the Neural Network MNIST example.
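The 256 inputs of `dense1` follow from the layer shapes. The sketch below reproduces that arithmetic, assuming 28x28 single-channel MNIST images and convolutions with stride 1 and no padding (those defaults are an assumption here), together with the 2x2, stride-2 pooling configured above. `out_size` and `lenet5_dense1_inputs` are hypothetical helpers, not autograph functions.

```rust
/// Output length of one spatial dimension for a convolution or pooling window,
/// assuming no padding and no dilation.
fn out_size(input: usize, kernel: usize, stride: usize) -> usize {
    (input - kernel) / stride + 1
}

/// Number of features entering `dense1` for a square, single-channel input of
/// side `side` (28 for MNIST), following the LeNet-5 layers above.
fn lenet5_dense1_inputs(side: usize) -> usize {
    let s = out_size(side, 5, 1); // conv1: 5x5 kernel, assumed stride 1
    let s = out_size(s, 2, 2); // pool1: 2x2 kernel, stride 2
    let s = out_size(s, 5, 1); // conv2: 5x5 kernel, assumed stride 1
    let s = out_size(s, 2, 2); // pool2: 2x2 kernel, stride 2
    16 * s * s // conv2 produces 16 channels
}
```

For 28x28 inputs the spatial size goes 28 -> 24 -> 12 -> 8 -> 4, so `dense1` receives 16 * 4 * 4 = 256 features, matching `Dense::from_inputs_outputs(256, 120)`.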

Benchmarks

NVIDIA GeForce GTX 1060 with Max-Q Design

+-----------+------------+---------------+-----------------------+----------------------------------+
| Library   | Best Epoch | Best Accuracy | Time To Best Accuracy | Mean Epoch Time to Best Accuracy |
+===========+============+===============+=======================+==================================+
| autograph | 69         | 99.04%        | 127.38s               | 1.85s                            |
+-----------+------------+---------------+-----------------------+----------------------------------+
| tch       | 32         | 99.12%        | 22.03s                | 688.31ms                         |
+-----------+------------+---------------+-----------------------+----------------------------------+

See the Neural Network benchmark.
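The last column appears to be Time To Best Accuracy divided by Best Epoch: 127.38 s / 69 ≈ 1.85 s for autograph and 22.03 s / 32 ≈ 688 ms for tch (the small gap to 688.31 ms is consistent with rounding of the total). A small sketch of that arithmetic, with hypothetical helper names:

```rust
/// Mean time per epoch, in seconds, up to the best epoch.
fn mean_epoch_time(time_to_best_s: f64, best_epoch: u32) -> f64 {
    time_to_best_s / f64::from(best_epoch)
}

/// How many times longer one library's mean epoch takes than another's.
fn slowdown(mean_a_s: f64, mean_b_s: f64) -> f64 {
    mean_a_s / mean_b_s
}
```

By this measure, autograph's epochs took roughly 2.7 times as long as tch's on this GPU.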

Profiling

Profiling currently requires nightly Rust and the "profile" feature. Set the AUTOGRAPH_PROFILE environment variable to 1 or True to produce a table of statistics for the compute passes that are executed.

AUTOGRAPH_PROFILE=1 cargo +nightly run --features profile

This will create a file "autograph_profile_summary.txt" like this:

+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| Module                          | Entry                              | Time %  | Invocations | Mean Time | Total Time |
+=================================+====================================+=========+=============+===========+============+
| gemm_f32                        | main                               | 84.75 % | 61103787    | 145.00ns  | 8.94s      |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | kernel::im2col_2d_convolution_f32  | 8.39 %  | 1306434     | 677.00ns  | 885.21ms   |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | pool::max_pool_2d_backward_f32     | 3.58 %  | 2134342     | 176.00ns  | 377.61ms   |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | reorder::as_standard_layout_6d_u32 | 0.97 %  | 717356446   | 0.00ns    | 102.39ms   |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | fill::fill_u32                     | 0.69 %  | 1141592100  | 0.00ns    | 73.20ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| reduce_sum_final_f32            | main                               | 0.32 %  | 50226       | 661.00ns  | 33.25ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | pool::max_pool_indices_2d_f32      | 0.31 %  | 10233126    | 3.00ns    | 33.04ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | activation::relu_backward_f32      | 0.31 %  | 351834488   | 0.00ns    | 32.89ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | activation::relu_f32               | 0.25 %  | 398674888   | 0.00ns    | 26.86ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| scaled_add_f32                  | main                               | 0.14 %  | 33371028    | 0.00ns    | 14.45ms    |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| cross_entropy_loss_f32_64       | main                               | 0.06 %  | 107890      | 60.00ns   | 6.53ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| bias_backward_f32               | main                               | 0.05 %  | 49566       | 111.00ns  | 5.53ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | cast::scale_u8_f32                 | 0.05 %  | 16717490    | 0.00ns    | 5.50ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| core                            | pool::max_pool_2d_f32              | 0.04 %  | 1360260     | 3.00ns    | 4.25ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| reduce_argmax_final_f32         | main                               | 0.03 %  | 107890      | 24.00ns   | 2.66ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| one_hot_u8_f32                  | main                               | 0.02 %  | 107890      | 22.00ns   | 2.39ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| accuracy_u8                     | main                               | 0.02 %  | 107890      | 19.00ns   | 2.10ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| cross_entropy_loss_backward_f32 | main                               | 0.02 %  | 97630       | 18.00ns   | 1.78ms     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
| scaled_cast_f32_f32             | main                               | 0.00 %  | 132         | 31.00ns   | 4.10µs     |
+---------------------------------+------------------------------------+---------+-------------+-----------+------------+
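The Time % column is each row's Total Time as a share of the sum over all rows (for example, 885.21 ms is about 8.39 % of the roughly 10.55 s total). A minimal sketch of that kind of aggregation, grouping timed compute passes by module and entry point, follows. `Stats`, `summarize`, and `percent_and_mean` are hypothetical and are not autograph's profiler; the sketch counts each record as one invocation, which may not match what autograph reports in the Invocations column.

```rust
use std::collections::HashMap;
use std::time::Duration;

/// Hypothetical per-(module, entry) statistics, mirroring columns of the summary table.
#[derive(Debug, Default, Clone, Copy)]
struct Stats {
    invocations: u64,
    total: Duration,
}

/// Groups timed records `(module, entry, duration)` by module and entry point.
/// Each record counts as one invocation in this sketch.
fn summarize<'a>(records: &[(&'a str, &'a str, Duration)]) -> HashMap<(&'a str, &'a str), Stats> {
    let mut map = HashMap::new();
    for &(module, entry, dt) in records {
        let s: &mut Stats = map.entry((module, entry)).or_default();
        s.invocations += 1;
        s.total += dt;
    }
    map
}

/// Returns the "Time %" (share of `overall`) and the mean time per invocation.
fn percent_and_mean(s: &Stats, overall: Duration) -> (f64, Duration) {
    let pct = 100.0 * s.total.as_secs_f64() / overall.as_secs_f64();
    let mean = Duration::from_secs_f64(s.total.as_secs_f64() / s.invocations as f64);
    (pct, mean)
}
```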

Note

If autograph is a dependency of your crate, enable the feature as autograph/profile, for example: cargo +nightly run --features autograph/profile

Development Platforms

  1. Ubuntu 18.04 | (Vulkan) NVIDIA GeForce GTX 1060 with Max-Q Design
  2. Windows 10 Home | (Vulkan + DX12) AMD RX 580 / (DX12) Microsoft Basic Render Driver.

Shaders are tested on GitHub Actions:

  • Windows Server 2019 | (DX12) Microsoft Basic Render Driver.

Metal

Shaders are untested on Metal / Apple platforms. If you have problems, please create an issue!

License

Dual-licensed to be compatible with the Rust project.

Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 or the MIT license http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.

Contributors

albertogp, charles-r-earp, nkconnor