Open wcshds opened 7 months ago
As a reference, the plan is not to support everything so that there is a 1-to-1 mapping with PyTorch. The PyTorchFileRecorder is there to make the transition easier by letting you import your pre-trained weights. There are a couple of different ways to work around this, such as modifying the PyTorch model when it is saved.
But for basic supported layers such as LSTM it's a bit different. This might be something we need to add.
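For reference, one concrete form of such a workaround (for pure key-name mismatches only, not for splitting concatenated tensors) is the regex key remapping that PyTorchFileRecorder supports through LoadArgs. A rough sketch, assuming the Model type from the reproduction further down and that your burn-import version exposes LoadArgs::with_key_remap; the regex itself is only illustrative:
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

fn main() {
    let device = NdArrayDevice::Cpu;

    // Rename tensors before they are matched against the Burn module's
    // field names, e.g. strip an illustrative "net." prefix from every key.
    let load_args = LoadArgs::new("./example.pt".into())
        .with_key_remap("net\\.(.+)", "$1");

    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, &device)
        .expect("Should decode state successfully");
    let _model: Model<NdArray> = Model::new(&device).load_record(record);
}
This only renames keys at load time; it cannot chunk PyTorch's concatenated LSTM weights, which is the issue discussed below.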
If you could provide a minimal example, we will look into it. The PyTorch screenshot appears to store the weights as an array, which is supported by PyTorchFileRecorder.
Sure.
Rust code:
use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    record::{FullPrecisionSettings, Recorder},
    tensor::{backend::Backend, Tensor},
};
use burn_import::pytorch::PyTorchFileRecorder;
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            lstm: LstmConfig::new(10, 20, true).init(device),
            linear: LinearConfig::new(20, 30).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, out) = self.lstm.forward(input, None);
        self.linear.forward(out)
    }
}
fn main() {
    type Backend = NdArray;
    let device = NdArrayDevice::Cpu;

    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./example.pt".into(), &device)
        .expect("Should decode state successfully");
    let model: Model<Backend> = Model::new(&device).load_record(record);
}
Python code:
import torch
from torch import nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20)
        self.linear = nn.Linear(20, 30)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(input)
        return self.linear(out)


if __name__ == "__main__":
    model = Model()
    input = torch.randn(2, 8, 10)
    res = model(input)
    print(res.shape)
    print([name for name, _ in model.named_parameters()])
    torch.save(model.state_dict(), "./example.pt")
@antimora @laggui Therefore, importing LSTM weights from PyTorch should apply a remap along these lines (each PyTorch tensor is split across the four Burn gate parameters):
lstm.weight_ih_l0 => lstm.input_gate.input_transform.weight => lstm.forget_gate.input_transform.weight => lstm.cell_gate.input_transform.weight => lstm.output_gate.input_transform.weight
lstm.bias_ih_l0 => lstm.input_gate.input_transform.bias => lstm.forget_gate.input_transform.bias => lstm.cell_gate.input_transform.bias => lstm.output_gate.input_transform.bias
lstm.weight_hh_l0 => lstm.input_gate.hidden_transform.weight => lstm.forget_gate.hidden_transform.weight => lstm.cell_gate.hidden_transform.weight => lstm.output_gate.hidden_transform.weight
lstm.bias_hh_l0 => lstm.input_gate.hidden_transform.bias => lstm.forget_gate.hidden_transform.bias => lstm.cell_gate.hidden_transform.bias => lstm.output_gate.hidden_transform.bias
Here is what I would recommend to unblock you quickly, since it will be a while before we handle this use case.
Create a Burn module for the LSTM whose attributes match the source PyTorch names, so you can load the source PyTorch record without issues. I recommend the following module structure to match the source:
#[derive(Module, Debug)]
pub struct LstmIntermediate<B: Backend> {
    pub weight_ih_l0: Param<Tensor<B, 2>>,
    pub weight_hh_l0: Param<Tensor<B, 2>>,
    pub bias_ih_l0: Param<Tensor<B, 1>>,
    pub bias_hh_l0: Param<Tensor<B, 1>>,
}

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        ...
    }
}
Once you have it loaded, you can extract the data and chunk it up. You can put that logic in into_lstm; a sketch follows below.
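A minimal sketch of what into_lstm could look like, building on the LstmIntermediate above. It assumes PyTorch's gate ordering (input, forget, cell, output concatenated along dimension 0), that Burn's Linear stores its weight as [d_input, d_output] (hence the transposes), and that the gate and transform fields of Lstm are publicly accessible in your Burn version; Param::from_tensor may be spelled differently in older releases (e.g. Param::from), so treat this as a starting point rather than a drop-in implementation:
use burn::module::Param;
use burn::nn::{Lstm, LstmConfig};
use burn::tensor::backend::Backend;

impl<B: Backend> LstmIntermediate<B> {
    fn into_lstm(self) -> Lstm<B> {
        // PyTorch shapes: weight_ih_l0 is [4 * d_hidden, d_input],
        // weight_hh_l0 is [4 * d_hidden, d_hidden], biases are [4 * d_hidden],
        // with the gates concatenated in (input, forget, cell, output) order.
        let weight_ih = self.weight_ih_l0.val();
        let device = weight_ih.device();
        let [four_hidden, d_input] = weight_ih.dims();
        let d_hidden = four_hidden / 4;

        // Split every tensor into its four per-gate pieces along dimension 0.
        let w_ih = weight_ih.chunk(4, 0);
        let w_hh = self.weight_hh_l0.val().chunk(4, 0);
        let b_ih = self.bias_ih_l0.val().chunk(4, 0);
        let b_hh = self.bias_hh_l0.val().chunk(4, 0);

        // Start from a freshly initialized Lstm and overwrite its parameters.
        let mut lstm = LstmConfig::new(d_input, d_hidden, true).init(&device);
        let gates = [
            &mut lstm.input_gate,
            &mut lstm.forget_gate,
            &mut lstm.cell_gate,
            &mut lstm.output_gate,
        ];
        for (i, gate) in gates.into_iter().enumerate() {
            // Burn's Linear stores its weight as [d_input, d_output], whereas
            // PyTorch stores [d_output, d_input]; hence the transpose.
            gate.input_transform.weight = Param::from_tensor(w_ih[i].clone().transpose());
            gate.input_transform.bias = Some(Param::from_tensor(b_ih[i].clone()));
            gate.hidden_transform.weight = Param::from_tensor(w_hh[i].clone().transpose());
            gate.hidden_transform.bias = Some(Param::from_tensor(b_hh[i].clone()));
        }
        lstm
    }
}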
After you have made the transformation, just save the model in Burn's record format so you can load it without any transformation next time.
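For instance (a sketch, assuming the Model type from the reproduction above and Burn's NamedMpkFileRecorder; exact save_file/load_file signatures may differ slightly between Burn versions):
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Module;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};

type Backend = NdArray;

// Hypothetical helper: persist the converted model once in Burn's own format.
// The recorder appends its own file extension (.mpk) to the path.
fn save_converted(model: Model<Backend>) {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
    model
        .save_file("./example-burn", &recorder)
        .expect("Should save the model record");
}

// On later runs, load the Burn record directly, skipping the PyTorch import.
fn load_converted(device: &NdArrayDevice) -> Model<Backend> {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
    Model::new(device)
        .load_file("./example-burn", &recorder, device)
        .expect("Should load the model record")
}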
Description
I encountered difficulties when importing the weights of an LSTM from PyTorch. PyTorch concatenates the per-gate weights into single tensors, so the structure of LSTM weights in PyTorch differs significantly from that in Burn, where the weights of an LSTM are stored in eight separate Linear layers. Therefore, I think PyTorchFileRecorder needs to provide a way to chunk the weights before importing them into Burn.