tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Provide a method to chunk the weights before importing using PyTorchFileRecorder #1632

Open wcshds opened 7 months ago

wcshds commented 7 months ago

Description

I encountered difficulties when importing the weights of an LSTM from PyTorch. PyTorch concatenates the weights of the LSTM gates into single tensors, so the structure of LSTM weights in PyTorch differs significantly from the structure in Burn. I therefore think PyTorchFileRecorder should provide a way to chunk the weights before importing them into Burn.

In Burn, the LSTM weights are stored in eight Linear layers: an input transform and a hidden transform for each of the four gates.

laggui commented 7 months ago

For reference, the plan is not to support everything as a 1-to-1 mapping with PyTorch. The PyTorchFileRecorder is there to make it easier to import your pre-trained weights. There are a couple of ways to work around this, such as modifying the PyTorch model when saving it.

But basic supported layers such as LSTM are a different case. This might be something we need to add.

antimora commented 7 months ago

If you could provide a minimal code example, we will look into this. The PyTorch screenshot appears to store the weights as arrays, which PyTorchFileRecorder supports.

wcshds commented 7 months ago

Sure.

Rust code:

```rust
use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    record::{FullPrecisionSettings, Recorder},
    tensor::{backend::Backend, Tensor},
};
use burn_import::pytorch::PyTorchFileRecorder;

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            // d_input = 10, d_hidden = 20, with bias (same sizes as nn.LSTM(10, 20))
            lstm: LstmConfig::new(10, 20, true).init(device),
            linear: LinearConfig::new(20, 30).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, out) = self.lstm.forward(input, None);
        self.linear.forward(out)
    }
}

fn main() {
    type Backend = NdArray;
    let device = NdArrayDevice::Cpu;

    // The LSTM parameters in example.pt (weight_ih_l0, weight_hh_l0, bias_ih_l0,
    // bias_hh_l0) do not match Burn's per-gate Linear layers, so they cannot be loaded as-is.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./example.pt".into(), &device)
        .expect("Should decode state successfully");

    let model: Model<Backend> = Model::new(&device).load_record(record);
}
```

Python code:

```python
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.lstm = nn.LSTM(10, 20)
        self.linear = nn.Linear(20, 30)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(input)
        return self.linear(out)

if __name__ == "__main__":
    model = Model()
    input = torch.randn(2, 8, 10)
    res = model(input)
    print(res.shape)  # torch.Size([2, 8, 30])
    # ['lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'linear.weight', 'linear.bias']
    print([name for name, _ in model.named_parameters()])

    torch.save(model.state_dict(), "./example.pt")
```
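To see why the LSTM part does not load directly, it helps to compare the parameter shapes on both sides. The sketch below uses hypothetical helpers in plain Rust (no Burn dependency) and assumes a single-layer, unidirectional nn.LSTM with bias, plus Burn's Linear layout of weight [d_input, d_output] and bias [d_output]:

```rust
/// Parameter shapes PyTorch stores for a single-layer, unidirectional nn.LSTM
/// with bias: the four gates are stacked along the first dimension.
fn pytorch_lstm_shapes(input: usize, hidden: usize) -> Vec<(&'static str, Vec<usize>)> {
    vec![
        ("lstm.weight_ih_l0", vec![4 * hidden, input]),
        ("lstm.weight_hh_l0", vec![4 * hidden, hidden]),
        ("lstm.bias_ih_l0", vec![4 * hidden]),
        ("lstm.bias_hh_l0", vec![4 * hidden]),
    ]
}

/// Parameter shapes of ONE Burn LSTM gate (there are four of them).
/// Assumes Burn's Linear layout: weight [d_input, d_output], bias [d_output].
fn burn_gate_shapes(input: usize, hidden: usize) -> Vec<(&'static str, Vec<usize>)> {
    vec![
        ("input_transform.weight", vec![input, hidden]),
        ("input_transform.bias", vec![hidden]),
        ("hidden_transform.weight", vec![hidden, hidden]),
        ("hidden_transform.bias", vec![hidden]),
    ]
}

/// Total number of scalars across a list of shapes.
fn param_count(shapes: &[(&'static str, Vec<usize>)]) -> usize {
    shapes.iter().map(|(_, shape)| shape.iter().product::<usize>()).sum()
}
```

For the sizes in this example (input 10, hidden 20), PyTorch stores weight_ih_l0 as [80, 10] and weight_hh_l0 as [80, 20], while each Burn gate expects [10, 20] and [20, 20] weights. Both sides hold 2560 parameters in total, so all the data is there; it is just laid out differently.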

wcshds commented 7 months ago

@antimora @laggui Therefore, importing LSTM weights from PyTorch requires a remapping like this:

lstm.weight_ih_l0 => lstm.input_gate.input_transform.weight
                  => lstm.forget_gate.input_transform.weight
                  => lstm.cell_gate.input_transform.weight
                  => lstm.output_gate.input_transform.weight

lstm.bias_ih_l0   => lstm.input_gate.input_transform.bias
                  => lstm.forget_gate.input_transform.bias
                  => lstm.cell_gate.input_transform.bias
                  => lstm.output_gate.input_transform.bias

lstm.weight_hh_l0 => lstm.input_gate.hidden_transform.weight
                  => lstm.forget_gate.hidden_transform.weight
                  => lstm.cell_gate.hidden_transform.weight
                  => lstm.output_gate.hidden_transform.weight

lstm.bias_hh_l0   => lstm.input_gate.hidden_transform.bias
                  => lstm.forget_gate.hidden_transform.bias
                  => lstm.cell_gate.hidden_transform.bias
                  => lstm.output_gate.hidden_transform.bias
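The remap above relies on PyTorch's gate order: the PyTorch documentation describes weight_ih_l0 as (W_ii|W_if|W_ig|W_io), i.e. the input, forget, cell and output gates stacked along the first dimension, so gate k owns rows k*hidden to (k+1)*hidden. A small sketch of that bookkeeping, with hypothetical helper names and assuming only single-layer (_l0) keys:

```rust
use std::ops::Range;

/// Gate order PyTorch uses when stacking LSTM parameters along dim 0
/// (W_ii|W_if|W_ig|W_io), named after the Burn gates in the remap above.
const GATES: [&str; 4] = ["input_gate", "forget_gate", "cell_gate", "output_gate"];

/// Row range of a PyTorch LSTM tensor (4 * hidden rows) owned by each gate.
fn gate_row_ranges(hidden: usize) -> Vec<(&'static str, Range<usize>)> {
    GATES
        .iter()
        .enumerate()
        .map(|(k, &gate)| (gate, k * hidden..(k + 1) * hidden))
        .collect()
}

/// Burn parameter name for one gate chunk of a PyTorch LSTM parameter,
/// e.g. ("lstm.weight_ih_l0", "input_gate") -> "lstm.input_gate.input_transform.weight".
/// Returns None for keys that are not single-layer LSTM parameters.
fn burn_key(pytorch_key: &str, gate: &str) -> Option<String> {
    let (prefix, name) = pytorch_key.rsplit_once('.')?;
    let (transform, param) = match name {
        "weight_ih_l0" => ("input_transform", "weight"),
        "bias_ih_l0" => ("input_transform", "bias"),
        "weight_hh_l0" => ("hidden_transform", "weight"),
        "bias_hh_l0" => ("hidden_transform", "bias"),
        _ => return None,
    };
    Some(format!("{prefix}.{gate}.{transform}.{param}"))
}
```

With the hidden size of 20 used here, rows 40..60 of lstm.weight_hh_l0, for instance, map to lstm.cell_gate.hidden_transform.weight. This only selects rows; the weight chunks may still need to be brought into Burn's Linear layout.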
antimora commented 7 months ago

Here is what I would recommend to unblock you quickly, since it will be a while before we handle this use case.

Create an intermediate Burn module for the LSTM whose fields match the source PyTorch attributes, so you can load the source PyTorch record without issues. I recommend the following module structure to match the source:

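```rust
use burn::{
    module::{Module, Param},
    nn::Lstm,
    tensor::{backend::Backend, Tensor},
};

/// Mirrors the parameter names of a single-layer PyTorch nn.LSTM
/// so that PyTorchFileRecorder can load them directly.
#[derive(Module, Debug)]
pub struct LstmIntermediate<B: Backend> {
    pub weight_ih_l0: Param<Tensor<B, 2>>,
    pub weight_hh_l0: Param<Tensor<B, 2>>,
    pub bias_ih_l0: Param<Tensor<B, 1>>,
    pub bias_hh_l0: Param<Tensor<B, 1>>,
}

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        // ... split each tensor into the four gate chunks and build the Lstm here
        todo!()
    }
}
```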

Once it is loaded, you can extract the data and chunk it up. You can put this logic in into_lstm.
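A sketch of the chunking logic that could go into into_lstm, written on plain row-major Vec<f32> data so it can be checked without Burn. PyTorchLstm, LinearData and GateData are hypothetical stand-ins for the Burn tensors and modules. One assumption worth checking: Burn's Linear stores its weight as [d_input, d_output], the transpose of PyTorch's [out_features, in_features], and since LstmIntermediate holds raw Param fields rather than Linear modules, it is safest to assume the recorder leaves them in PyTorch's layout. The sketch therefore transposes each weight chunk after splitting:

```rust
/// Plain row-major stand-in for the tensors in LstmIntermediate (hypothetical).
struct PyTorchLstm {
    input: usize,
    hidden: usize,
    weight_ih: Vec<f32>, // [4 * hidden, input]
    weight_hh: Vec<f32>, // [4 * hidden, hidden]
    bias_ih: Vec<f32>,   // [4 * hidden]
    bias_hh: Vec<f32>,   // [4 * hidden]
}

/// One Linear in the assumed Burn layout: weight [d_input, d_output], bias [d_output].
#[derive(Debug, PartialEq)]
struct LinearData {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

/// One gate: an input transform and a hidden transform.
#[derive(Debug, PartialEq)]
struct GateData {
    input_transform: LinearData,
    hidden_transform: LinearData,
}

/// Take rows [k*hidden, (k+1)*hidden) of a row-major matrix with `cols` columns
/// and transpose them from [hidden, cols] to [cols, hidden].
fn gate_weight(data: &[f32], k: usize, hidden: usize, cols: usize) -> Vec<f32> {
    let rows = &data[k * hidden * cols..(k + 1) * hidden * cols];
    let mut out = vec![0.0; hidden * cols];
    for r in 0..hidden {
        for c in 0..cols {
            out[c * hidden + r] = rows[r * cols + c];
        }
    }
    out
}

/// Take elements [k*hidden, (k+1)*hidden) of a stacked bias vector.
fn gate_bias(data: &[f32], k: usize, hidden: usize) -> Vec<f32> {
    data[k * hidden..(k + 1) * hidden].to_vec()
}

impl PyTorchLstm {
    /// Returns the gates in PyTorch's order: input, forget, cell, output.
    fn into_gates(self) -> [GateData; 4] {
        let h = self.hidden;
        std::array::from_fn(|k| GateData {
            input_transform: LinearData {
                weight: gate_weight(&self.weight_ih, k, h, self.input),
                bias: gate_bias(&self.bias_ih, k, h),
            },
            hidden_transform: LinearData {
                weight: gate_weight(&self.weight_hh, k, h, h),
                bias: gate_bias(&self.bias_hh, k, h),
            },
        })
    }
}
```

In Burn itself the same two steps would be tensor operations on the loaded parameters (for example chunk(4, 0) followed by transpose() on each weight chunk) before assigning the results to the eight Linear layers.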

After you have made the transformation, save the model in Burn's record format so you can load it without the transformation next time.