huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Feature Request: Implement a Keras-Like API #1153

Open bernardo-sb opened 1 year ago

bernardo-sb commented 1 year ago

Feature Request: Implement a Keras-Like API

Description:

I would like to propose adding a high-level, Keras-like API to candle. This API would provide a more intuitive and user-friendly way to define, compile, and train neural network models.

Motivation:

Currently, candle provides a low-level API for building neural network models. A Keras-like API would greatly improve the library's usability and make it accessible to a wider audience.

Proposed Changes:

As an example, I've sketched an implementation based on the MNIST forward-pass example:

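```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

/// Collects layers before the model is "compiled".
struct Sequential {
    layers: Vec<Linear>,
}

impl Sequential {
    fn new() -> Self {
        Sequential { layers: Vec::new() }
    }

    fn add(&mut self, layer: Linear) {
        self.layers.push(layer);
    }

    /// Freezes the current list of layers into a runnable model.
    fn compile(&self) -> Model {
        Model::new(&self.layers)
    }
}

struct Model {
    layers: Vec<Linear>,
}

impl Model {
    fn new(layers: &[Linear]) -> Model {
        // Cloning is cheap: candle tensors are reference-counted.
        Model { layers: layers.to_vec() }
    }

    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let mut x = image.clone();
        let last = self.layers.len().saturating_sub(1);
        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.forward(&x)?;
            // Apply ReLU only between layers, so the final layer returns
            // raw logits as in the MNIST example.
            if i < last {
                x = x.relu()?;
            }
        }
        Ok(x)
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;

    let mut model = Sequential::new();
    // 784 input pixels -> 100 hidden units (weight shape is (out, in)).
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (100, 784), &device)?,
        Some(Tensor::randn(0f32, 1.0, (100,), &device)?),
    ));
    // 100 hidden units -> 10 digit classes.
    model.add(Linear::new(
        Tensor::randn(0f32, 1.0, (10, 100), &device)?,
        Some(Tensor::randn(0f32, 1.0, (10,), &device)?),
    ));

    let compiled_model = model.compile();

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let logits = compiled_model.forward(&dummy_image)?;
    // The predicted digit is the index of the largest logit.
    let digit = logits.argmax(1)?;
    println!("Digit: {digit}");
    Ok(())
}
```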
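The sketch above stores only `Linear` layers and hard-codes a ReLU inside `forward`, so activations or other layer types cannot be added as separate steps. One way to lift that restriction, closer to how Keras's `Sequential` accepts any layer, is to store boxed trait objects. The std-only sketch below assumes a hypothetical `Layer` trait and hypothetical `Dense`/`Relu` types, and uses `Vec<f32>` in place of candle's `Tensor` so it runs without dependencies. In candle, the analogous choice would presumably be a `Vec<Box<dyn Module>>`.

```rust
/// Hypothetical minimal layer trait over a flat vector of activations.
trait Layer {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

/// Fully connected layer: y = W x + b, with W stored row-major as (out, in).
struct Dense {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Layer for Dense {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weight
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// Element-wise ReLU, expressed as its own layer instead of being hard-coded.
struct Relu;

impl Layer for Relu {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        x.iter().map(|&v| v.max(0.0)).collect()
    }
}

/// Keras-style container holding heterogeneous layers behind trait objects.
struct Sequential {
    layers: Vec<Box<dyn Layer>>,
}

impl Sequential {
    fn new() -> Self {
        Sequential { layers: Vec::new() }
    }

    /// Builder-style `add` (takes `self` by value) so calls can be chained.
    fn add<L: Layer + 'static>(mut self, layer: L) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Feeds the input through each layer in insertion order.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.layers
            .iter()
            .fold(x.to_vec(), |acc, layer| layer.forward(&acc))
    }
}

fn main() {
    let model = Sequential::new()
        .add(Dense {
            weight: vec![vec![1.0, -1.0], vec![0.5, 0.5]],
            bias: vec![0.0, -1.0],
        })
        .add(Relu)
        .add(Dense {
            weight: vec![vec![2.0, 1.0]],
            bias: vec![0.5],
        });
    let out = model.forward(&[3.0, 1.0]);
    println!("{:?}", out); // prints [5.5]
}
```

The cost of this design is one dynamic dispatch per layer, which is usually small next to the matrix multiplications a real model performs.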

I believe that implementing a Keras-like API would greatly enhance the usability and appeal of candle. It would make it easier for users to define and train neural network models, making the crate more accessible. Feedback and discussion on this proposal are welcome.
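The proposal covers defining, compiling, and training models, but the sketch's `compile` only copies the layer list. In Keras, `compile` attaches a loss and an optimizer to the model and `fit` then runs the training loop. The std-only sketch below illustrates that split on a toy one-input linear model with mean squared error and plain full-batch gradient descent, using hand-derived gradients. All names (`Loss`, `LinearModel`, `Compiled`, `fit`, `evaluate`) and the learning rate are hypothetical; a candle-based version would presumably rely on candle's autograd and optimizers rather than hand-written gradients.

```rust
/// Hypothetical loss selector for a Keras-style `compile` step.
#[derive(Clone, Copy)]
enum Loss {
    MeanSquaredError,
}

/// Toy model with one input: y = w * x + b.
struct LinearModel {
    w: f32,
    b: f32,
}

impl LinearModel {
    fn predict(&self, x: f32) -> f32 {
        self.w * x + self.b
    }

    /// Binds the model to a loss and a learning rate (plain gradient descent).
    fn compile(self, loss: Loss, learning_rate: f32) -> Compiled {
        Compiled { model: self, loss, learning_rate }
    }
}

/// A model that is ready to be trained with `fit`.
struct Compiled {
    model: LinearModel,
    loss: Loss,
    learning_rate: f32,
}

impl Compiled {
    /// Mean loss over the dataset.
    fn evaluate(&self, xs: &[f32], ys: &[f32]) -> f32 {
        let n = xs.len() as f32;
        match self.loss {
            Loss::MeanSquaredError => {
                xs.iter()
                    .zip(ys)
                    .map(|(&x, &y)| (self.model.predict(x) - y).powi(2))
                    .sum::<f32>()
                    / n
            }
        }
    }

    /// Full-batch gradient descent; returns the loss recorded after each epoch.
    fn fit(&mut self, xs: &[f32], ys: &[f32], epochs: usize) -> Vec<f32> {
        let n = xs.len() as f32;
        let mut history = Vec::with_capacity(epochs);
        for _ in 0..epochs {
            let (mut grad_w, mut grad_b) = (0.0f32, 0.0f32);
            match self.loss {
                Loss::MeanSquaredError => {
                    // dL/dw = (2/n) * sum(err * x), dL/db = (2/n) * sum(err)
                    for (&x, &y) in xs.iter().zip(ys) {
                        let err = self.model.predict(x) - y;
                        grad_w += 2.0 * err * x / n;
                        grad_b += 2.0 * err / n;
                    }
                }
            }
            self.model.w -= self.learning_rate * grad_w;
            self.model.b -= self.learning_rate * grad_b;
            history.push(self.evaluate(xs, ys));
        }
        history
    }
}

fn main() {
    // Samples of y = 2x + 1.
    let xs: [f32; 4] = [0.0, 1.0, 2.0, 3.0];
    let ys: [f32; 4] = [1.0, 3.0, 5.0, 7.0];

    let mut model = LinearModel { w: 0.0, b: 0.0 }.compile(Loss::MeanSquaredError, 0.05);
    let history = model.fit(&xs, &ys, 500);
    println!(
        "w = {:.3}, b = {:.3}, final loss = {:e}",
        model.model.w, model.model.b, history[history.len() - 1]
    );
}
```

Returning a per-epoch loss history from `fit` mirrors what Keras users typically inspect after training.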

LaurentMazare commented 1 year ago

That sounds like a great thing to do. candle-nn aims to provide a low-level API similar to PyTorch, so there is certainly room for a higher-level API (or several) similar to Lightning or Keras.

I would suggest doing this in a separate repo and crate: this will make it easier to iterate on the design and implementation, and it avoids tying the design too closely to candle internals. We can keep this issue open, and I'll tag it with "help wanted" in case it attracts interested folks; then you and others can decide where to create the repo and start building it.

EricLBuehler commented 12 months ago

@bernardo-sb, if/when you implement this, please let me know, as I think this would be a great way to add LoRA to a model! I would be open to working with you, or whoever develops this crate, to integrate it seamlessly with candle-lora.
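For context on LoRA (low-rank adaptation): it keeps a pretrained weight matrix W0 frozen and learns a low-rank update, so a linear layer computes h = W0 x + (alpha / r) * B A x, where A has shape r × in, B has shape out × r, and the rank r is much smaller than the layer's dimensions. B starts at zero, so the adapted layer initially behaves exactly like the base layer. The std-only sketch below implements that forward pass on plain `Vec<f32>` data; `LoraDense`, `matvec`, and the example values are hypothetical and do not reflect candle-lora's API.

```rust
/// Multiplies a row-major (rows x cols) matrix by a vector.
fn matvec(m: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    m.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Frozen base weights plus a trainable low-rank update.
struct LoraDense {
    base: Vec<Vec<f32>>, // W0: out x in, frozen
    a: Vec<Vec<f32>>,    // A: r x in
    b: Vec<Vec<f32>>,    // B: out x r, zero-initialized
    scale: f32,          // alpha / r
}

impl LoraDense {
    fn new(base: Vec<Vec<f32>>, a: Vec<Vec<f32>>, alpha: f32) -> Self {
        let rank = a.len();
        let out = base.len();
        LoraDense {
            base,
            a,
            b: vec![vec![0.0; rank]; out],
            scale: alpha / rank as f32,
        }
    }

    /// h = W0 x + scale * B (A x)
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        let base_out = matvec(&self.base, x);
        let low_rank = matvec(&self.b, &matvec(&self.a, x));
        base_out
            .iter()
            .zip(&low_rank)
            .map(|(h, d)| h + self.scale * d)
            .collect()
    }
}

fn main() {
    let base = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let a = vec![vec![0.1, 0.2]]; // rank 1
    let mut layer = LoraDense::new(base, a, 2.0);
    let x = [1.0, 1.0];
    // With B = 0 the adapter is a no-op, so the output equals W0 x.
    println!("{:?}", layer.forward(&x)); // prints [3.0, 7.0]
    // Pretend training has updated B; only the first output changes.
    layer.b = vec![vec![1.0], vec![0.0]];
    println!("{:?}", layer.forward(&x));
}
```

Only A and B (r × (in + out) values) would be trained, which is what makes the approach cheap compared with fine-tuning W0 itself.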

shivance commented 11 months ago

Is this feature request still looking for contributions, @LaurentMazare @EricLBuehler? I'm interested and have started working on it at https://github.com/shivance/candle-keras. I plan to make it very similar to Keras's API. I'm very new to Rust and might need some guidance 😄

LaurentMazare commented 11 months ago

Really up to you. I don't know of any such higher-level API at the moment, so I guess there is room for this, but you should certainly build something that ends up being useful to you (or try to gather feedback from potential users). Maybe start with a simple version, then advertise it a bit to gauge the interest it raises and collect suggestions from potential users. Staying close to Keras sounds very reasonable, but I'm no expert there. Once you have something a bit polished, it would certainly be good to add it to the "external resources" section of the README, though at some stage we may split that section out into some form of "awesome candle" external repo.

EricLBuehler commented 11 months ago

Sounds interesting! When it is ready, I would be happy to contribute a "convert to LoRA" method. Perhaps you could build it using the candle-lora trait-swapping mechanism?
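To illustrate what a "convert to LoRA" method might do, the std-only sketch below swaps every dense layer in a toy sequential model for a LoRA-wrapped layer that keeps the original weights as its frozen base. It uses a hypothetical `Layer` enum rather than trait objects so the swap is a simple in-place replacement; it does not reproduce candle-lora's actual trait-swapping mechanism. As in LoRA, `B` starts at zero, so the converted model initially produces the same outputs. `A` is filled with a small constant purely for illustration, whereas LoRA initializes it randomly.

```rust
/// Multiplies a row-major (rows x cols) matrix by a vector.
fn matvec(m: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    m.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Hypothetical closed set of layers for a Sequential-style container.
enum Layer {
    Dense { weight: Vec<Vec<f32>> },
    Relu,
    Lora { base: Vec<Vec<f32>>, a: Vec<Vec<f32>>, b: Vec<Vec<f32>>, scale: f32 },
}

impl Layer {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        match self {
            Layer::Dense { weight } => matvec(weight, x),
            Layer::Relu => x.iter().map(|&v| v.max(0.0)).collect(),
            Layer::Lora { base, a, b, scale } => {
                // h = W0 x + scale * B (A x)
                let delta = matvec(b, &matvec(a, x));
                matvec(base, x)
                    .iter()
                    .zip(&delta)
                    .map(|(h, d)| h + scale * d)
                    .collect()
            }
        }
    }
}

struct Sequential {
    layers: Vec<Layer>,
}

impl Sequential {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.layers.iter().fold(x.to_vec(), |acc, l| l.forward(&acc))
    }

    /// Replaces every `Dense` layer with a `Lora` layer wrapping the same
    /// weights; returns how many layers were converted.
    fn convert_to_lora(&mut self, rank: usize, alpha: f32) -> usize {
        let mut converted = 0;
        for layer in self.layers.iter_mut() {
            if let Layer::Dense { weight } = layer {
                let out = weight.len();
                let inp = weight.first().map_or(0, Vec::len);
                let base = std::mem::take(weight);
                *layer = Layer::Lora {
                    base,
                    a: vec![vec![0.01; inp]; rank], // illustrative constant init
                    b: vec![vec![0.0; rank]; out],  // zero init keeps outputs unchanged
                    scale: alpha / rank as f32,
                };
                converted += 1;
            }
        }
        converted
    }
}

fn main() {
    let mut model = Sequential {
        layers: vec![
            Layer::Dense { weight: vec![vec![1.0, -1.0], vec![2.0, 0.5]] },
            Layer::Relu,
            Layer::Dense { weight: vec![vec![1.0, 1.0]] },
        ],
    };
    let x = [2.0, 1.0];
    let before = model.forward(&x);
    let n = model.convert_to_lora(4, 8.0);
    let after = model.forward(&x);
    println!("converted {n} layers: {before:?} -> {after:?}"); // converted 2 layers: [5.5] -> [5.5]
}
```

An enum keeps the swap simple but closes the set of layer types; a trait-object design would need some other way (for example, a conversion method on the trait) to let each layer decide how to wrap itself.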