LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.
Apache License 2.0

Proposal for new more generic Module trait #284

Open surdus opened 3 years ago

surdus commented 3 years ago

The current Module and ModuleT traits accept only a single Tensor as input. This can be limiting for complex models that need more than one tensor as input. What about a trait like this for modules:

use tch::{nn::ModuleT, TchError, Tensor};

pub trait TypedModule {
    type Input;
    type Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError>;

    fn forward(&self, input: &Self::Input) -> Result<Self::Output, TchError> {
        self.forward_t(input, false)
    }
}
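A minimal sketch of what implementing the proposed trait could look like for a model with two inputs. It is self-contained and does not use tch: `Tensor` and `TchError` here are simplified stand-ins (a wrapper around `Vec<f64>` and a string error), and `AddModule` is a hypothetical module. A tuple serves as the associated `Input` type, and a length mismatch is reported through the `Result` rather than a panic.

```rust
// Simplified stand-in for tch::Tensor: a 1-D vector of f64 values.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor(pub Vec<f64>);

// Simplified stand-in for tch::TchError.
#[derive(Debug, PartialEq)]
pub struct TchError(pub String);

// The trait from the proposal, using the stand-in types.
pub trait TypedModule {
    type Input;
    type Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError>;

    fn forward(&self, input: &Self::Input) -> Result<Self::Output, TchError> {
        self.forward_t(input, false)
    }
}

// Hypothetical two-input module: element-wise sum of two tensors.
pub struct AddModule;

impl TypedModule for AddModule {
    // A tuple lets a single associated type carry several tensors.
    type Input = (Tensor, Tensor);
    type Output = Tensor;

    fn forward_t(&self, input: &Self::Input, _train: bool) -> Result<Self::Output, TchError> {
        let (a, b) = input;
        if a.0.len() != b.0.len() {
            return Err(TchError(format!(
                "length mismatch: {} vs {}",
                a.0.len(),
                b.0.len()
            )));
        }
        Ok(Tensor(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect()))
    }
}

fn main() {
    let m = AddModule;
    let out = m
        .forward(&(Tensor(vec![1.0, 2.0]), Tensor(vec![3.0, 4.0])))
        .unwrap();
    println!("{:?}", out); // Tensor([4.0, 6.0])
    // Mismatched lengths come back as an Err instead of panicking.
    assert!(m.forward(&(Tensor(vec![1.0]), Tensor(vec![1.0, 2.0]))).is_err());
}
```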

What changes:

For all current modules, we can add a blanket implementation like this:

impl<M> TypedModule for M
    where M: ModuleT
{
    type Input = Tensor;
    type Output = Tensor;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError> {
        // Call ModuleT's method explicitly, since both traits define `forward_t`.
        Ok(ModuleT::forward_t(self, input, train))
    }
}
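The blanket implementation can be checked in isolation with the sketch below. To keep it compilable without libtorch, `Tensor`, `TchError` and `ModuleT` are simplified stand-ins for the tch types (the stand-in `ModuleT` has only `forward_t`), and `Scale` is a hypothetical layer that implements only `ModuleT`. Inside the blanket impl, the inner call is written as `ModuleT::forward_t(self, input, train)` so it is explicit which trait's method runs.

```rust
// Simplified stand-in for tch::Tensor: a 1-D vector of f64 values.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor(pub Vec<f64>);

// Simplified stand-in for tch::TchError.
#[derive(Debug)]
pub struct TchError(pub String);

// Simplified stand-in for tch's ModuleT: one tensor in, one tensor out, no Result.
pub trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;
}

// The proposed trait.
pub trait TypedModule {
    type Input;
    type Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError>;

    fn forward(&self, input: &Self::Input) -> Result<Self::Output, TchError> {
        self.forward_t(input, false)
    }
}

// Blanket impl: every ModuleT becomes a TypedModule from Tensor to Tensor.
impl<M> TypedModule for M
where
    M: ModuleT,
{
    type Input = Tensor;
    type Output = Tensor;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError> {
        // Fully qualified syntax names ModuleT's method explicitly.
        Ok(ModuleT::forward_t(self, input, train))
    }
}

// Hypothetical existing layer that only implements ModuleT.
pub struct Scale(pub f64);

impl ModuleT for Scale {
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Tensor {
        Tensor(xs.0.iter().map(|v| v * self.0).collect())
    }
}

fn main() {
    let layer = Scale(2.0);
    let x = Tensor(vec![1.0, -3.0]);
    // Scale implements both traits, so `layer.forward_t(..)` would be
    // ambiguous (error E0034); the trait is named explicitly instead.
    let y = TypedModule::forward(&layer, &x).unwrap();
    assert_eq!(y, ModuleT::forward_t(&layer, &x, false));
    println!("{:?}", y); // Tensor([2.0, -6.0])
}
```

One rough edge shows up in `main`: a concrete type such as `Scale` implements both traits, and both define `forward_t`, so calling `layer.forward_t(..)` with method syntax is rejected as ambiguous while both traits are in scope, and callers have to name the trait. Assuming both `Module` and the proposed trait are in scope, the same would apply to `forward` for types implementing tch's `Module`.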
LaurentMazare commented 3 years ago

Sounds like an interesting idea. The main goal of the Module and ModuleT traits is to make it easy to chain layers, i.e. functions that take a tensor as input and return a tensor, and the current traits work well for that. It's a bit unclear to me what this parameterized version would buy us. To get a better sense of this, maybe you could define and use this trait in a crate that implements a model where it is actually useful (or, even better, in multiple models/use cases). Based on that, we would get a better feel for how general the approach is, as well as for the potential rough edges of the API.
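One way to probe what the parameterized version buys is to check that chaining, the main use of the current traits, still works. Below is a sketch of a hypothetical `Chain` combinator (not part of tch) over the proposed trait; `Tensor` and `TchError` are simplified stand-ins so it compiles without libtorch, and `Split` and `AddPair` are hypothetical modules. The bound `B: TypedModule<Input = A::Output>` means only stages whose types line up can be chained, and an error from either stage propagates through `?`.

```rust
// Simplified stand-ins for tch::Tensor and tch::TchError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor(pub Vec<f64>);

#[derive(Debug)]
pub struct TchError(pub String);

// The proposed trait.
pub trait TypedModule {
    type Input;
    type Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError>;

    fn forward(&self, input: &Self::Input) -> Result<Self::Output, TchError> {
        self.forward_t(input, false)
    }
}

// Hypothetical combinator: runs `A`, then feeds its output into `B`.
// A chain whose stage types do not match fails to compile.
pub struct Chain<A, B>(pub A, pub B);

impl<A, B> TypedModule for Chain<A, B>
where
    A: TypedModule,
    B: TypedModule<Input = A::Output>,
{
    type Input = A::Input;
    type Output = B::Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError> {
        let mid = self.0.forward_t(input, train)?;
        self.1.forward_t(&mid, train)
    }
}

// Hypothetical two-output module: splits a tensor at its midpoint.
pub struct Split;

impl TypedModule for Split {
    type Input = Tensor;
    type Output = (Tensor, Tensor);

    fn forward_t(&self, input: &Self::Input, _train: bool) -> Result<Self::Output, TchError> {
        let mid = input.0.len() / 2;
        Ok((Tensor(input.0[..mid].to_vec()), Tensor(input.0[mid..].to_vec())))
    }
}

// Hypothetical two-input module: element-wise sum of a pair of tensors.
pub struct AddPair;

impl TypedModule for AddPair {
    type Input = (Tensor, Tensor);
    type Output = Tensor;

    fn forward_t(&self, input: &Self::Input, _train: bool) -> Result<Self::Output, TchError> {
        let (a, b) = input;
        if a.0.len() != b.0.len() {
            return Err(TchError(format!("length mismatch: {} vs {}", a.0.len(), b.0.len())));
        }
        Ok(Tensor(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect()))
    }
}

fn main() {
    let model = Chain(Split, AddPair);
    let y = model.forward(&Tensor(vec![1.0, 2.0, 3.0, 4.0])).unwrap();
    println!("{:?}", y); // Tensor([4.0, 6.0])
    // An odd length yields halves of different sizes, so AddPair's error propagates.
    assert!(model.forward(&Tensor(vec![1.0, 2.0, 3.0])).is_err());
}
```

Because each stage can have different input and output types, the stages cannot all be stored in a single `Vec` of boxed trait objects; in this sketch a longer pipeline nests instead, e.g. `Chain(a, Chain(b, c))`.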

lostmsu commented 3 years ago

I am not the OP, but one use case would be modules that take multiple inputs or produce multiple outputs.
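For modules with several inputs and outputs, named structs can make the associated types self-describing. The sketch below models one step of a recurrent cell, which takes an input and a state and returns an output and a new state. `Tensor` and `TchError` are simplified stand-ins for the tch types so it compiles without libtorch, and `StepInput`, `StepOutput` and `ReluAccumulator` are hypothetical names chosen for illustration.

```rust
// Simplified stand-ins for tch::Tensor and tch::TchError.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor(pub Vec<f64>);

#[derive(Debug)]
pub struct TchError(pub String);

// The proposed trait.
pub trait TypedModule {
    type Input;
    type Output;

    fn forward_t(&self, input: &Self::Input, train: bool) -> Result<Self::Output, TchError>;

    fn forward(&self, input: &Self::Input) -> Result<Self::Output, TchError> {
        self.forward_t(input, false)
    }
}

// Hypothetical named input and output for one recurrent step.
pub struct StepInput {
    pub x: Tensor,
    pub state: Tensor,
}

#[derive(Debug, PartialEq)]
pub struct StepOutput {
    pub y: Tensor,
    pub state: Tensor,
}

// Hypothetical cell: new_state = state + x, y = max(new_state, 0) element-wise.
pub struct ReluAccumulator;

impl TypedModule for ReluAccumulator {
    type Input = StepInput;
    type Output = StepOutput;

    fn forward_t(&self, input: &Self::Input, _train: bool) -> Result<Self::Output, TchError> {
        if input.x.0.len() != input.state.0.len() {
            return Err(TchError("x and state must have the same length".to_string()));
        }
        let state: Vec<f64> = input.state.0.iter().zip(&input.x.0).map(|(s, x)| s + x).collect();
        let y: Vec<f64> = state.iter().map(|v| v.max(0.0)).collect();
        Ok(StepOutput { y: Tensor(y), state: Tensor(state) })
    }
}

fn main() {
    let cell = ReluAccumulator;
    let first = cell
        .forward(&StepInput { x: Tensor(vec![1.0, -5.0]), state: Tensor(vec![0.5, 1.0]) })
        .unwrap();
    assert_eq!(first.y, Tensor(vec![1.5, 0.0]));
    // The returned state is fed back in as the next step's input state.
    let second = cell
        .forward(&StepInput { x: Tensor(vec![1.0, 1.0]), state: first.state })
        .unwrap();
    println!("{:?}", second.y); // Tensor([2.5, 0.0])
}
```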