tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

The `burn load_record` operation seems to have damaged the structure of the model. #1655

zemelLeong closed this issue 7 months ago

zemelLeong commented 7 months ago

The project: https://github.com/zemelLeong/table-structure-recognition.git

Code snippet:

pub fn get_downsample_len(&self) -> usize {
    self.layer1.layers.first().unwrap().downsample.as_ref().unwrap().layers.len()
}

let model = LoreDetectModel::init(device);
println!("LoreModel new_with init downsample layers len: {}", model.get_downsample_len());
let model = model.load_record(record.model);
println!("LoreModel new_with load_record downsample layers len: {}", model.get_downsample_len());

Output logs:

LoreModel new_with init downsample layers len: 2
LoreModel new_with load_record downsample layers len: 0

The length of downsample layers has become 0 after executing load_record, so I believe that the load_record operation has damaged the structure of the model.

laggui commented 7 months ago

Looking at the full linked code, it might be an interaction with the PyTorchFileRecorder:

impl<B: Backend> LoreModel<B> {
    pub fn new_with(model_path: &str, device: &B::Device) -> Self {
        let load_args = LoadArgs::new(model_path.into())
            .with_key_remap("model.ax.2", "model.ax.1")
            .with_key_remap("model.ax.4", "model.ax.2")
            .with_key_remap("model.ax.6", "model.ax.3")
            .with_key_remap("model.ax.8", "model.ax.4")
            .with_key_remap("model.ax", "model.ax.layers")
            .with_key_remap("model.cr.2", "model.cr.1")
            .with_key_remap("model.cr.4", "model.cr.2")
            .with_key_remap("model.cr.6", "model.cr.3")
            .with_key_remap("model.cr.8", "model.cr.4")
            .with_key_remap("model.cr", "model.cr.layers")
            .with_key_remap("model.hm.2", "model.hm.1")
            .with_key_remap("model.hm.4", "model.hm.2")
            .with_key_remap("model.hm.6", "model.hm.3")
            .with_key_remap("model.hm.8", "model.hm.4")
            .with_key_remap("model.hm", "model.hm.layers")
            .with_key_remap("model.reg.2", "model.reg.1")
            .with_key_remap("model.reg.4", "model.reg.2")
            .with_key_remap("model.reg.6", "model.reg.3")
            .with_key_remap("model.reg.8", "model.reg.4")
            .with_key_remap("model.reg", "model.reg.layers")
            .with_key_remap("model.st.2", "model.st.1")
            .with_key_remap("model.st.4", "model.st.2")
            .with_key_remap("model.st.6", "model.st.3")
            .with_key_remap("model.st.8", "model.st.4")
            .with_key_remap("model.st", "model.st.layers")
            .with_key_remap("model.wh.2", "model.wh.1")
            .with_key_remap("model.wh.4", "model.wh.2")
            .with_key_remap("model.wh.6", "model.wh.3")
            .with_key_remap("model.wh.8", "model.wh.4")
            .with_key_remap("model.wh", "model.wh.layers")
            .with_key_remap("model.adaptionU1", "model.adaption_u1")
            .with_key_remap(
                "processor.stacker.tsfm.decoder.linear.2",
                "processor.stacker.tsfm.decoder.linear.1",
            )
            .with_key_remap(
                "processor.stacker.tsfm.decoder.linear",
                "processor.stacker.tsfm.decoder.linear.layers",
            )
            .with_key_remap(
                "processor.tsfm_axis.decoder.linear.2",
                "processor.tsfm_axis.decoder.linear.1",
            )
            .with_key_remap(
                "processor.tsfm_axis.decoder.linear",
                "processor.tsfm_axis.decoder.linear.layers",
            )
            .with_key_remap(
                "processor.stacker.logi_encoder.2",
                "processor.stacker.logi_encoder.1",
            )
            .with_key_remap(
                "processor.stacker.logi_encoder",
                "processor.stacker.logi_encoder.layers",
            )
            .with_key_remap(r"(model.layer\d)", r"$1.layers")
            .with_key_remap(r"(model.deconv_layers\d)", r"$1.layers")
            .with_key_remap("downsample.", "downsample.layers");
        let record: LoreModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load(load_args, device)
            .unwrap();

        let model = LoreDetectModel::init(device);
        println!("LoreModel new_with init downsample layers len: {}", model.get_downsample_len());
        let model = model.load_record(record.model);
        println!("LoreModel new_with load_record downsample layers len: {}", model.get_downsample_len());
        let processor = LoreProcessModel::init(device).load_record(record.processor);

        Self { model, processor }
    }
}

If the record contains no matching entry for the vector of downsample layers, then the vector will be initialized as empty.

Can you check if the record contains any downsample layers before loading the record with your model?
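
One way to run that check is to list the tensor keys after remapping and count how many fall under the path a `layers: Vec<_>` field would need. The sketch below uses plain strings only; the helper `count_layer_keys` and the sample keys are hypothetical and are not part of Burn's API.

```rust
/// Hypothetical helper: counts keys that contain `<module>.layers.<digit>`,
/// the layout a `layers: Vec<_>` field is assumed to need in this sketch.
fn count_layer_keys(keys: &[&str], module: &str) -> usize {
    let prefix = format!("{}.layers.", module);
    let mut count = 0;
    for &key in keys {
        if let Some(pos) = key.find(prefix.as_str()) {
            // Text right after the prefix must start with a vector index.
            let rest = &key[pos + prefix.len()..];
            if rest.starts_with(|c: char| c.is_ascii_digit()) {
                count += 1;
            }
        }
    }
    count
}

fn main() {
    // Illustrative keys only; they are not copied from the real checkpoint.
    let missing_dot = ["layer1.0.downsample.layers0.weight"];
    let with_dot = ["layer1.0.downsample.layers.0.weight"];
    println!("missing dot: {}", count_layer_keys(&missing_dot, "downsample"));
    println!("with dot:    {}", count_layer_keys(&with_dot, "downsample"));
}
```

A count of zero for a module that should hold layers would be consistent with the vector ending up empty after `load_record`.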

zemelLeong commented 7 months ago

I have checked, and the record does contain downsample layers.

These are all the keys: https://github.com/tracel-ai/burn/discussions/1413#discussioncomment-8687355

Here is the downsample layer data: [image]

Also, the key of the downsample layer has been remapped: https://github.com/zemelLeong/table-structure-recognition/blob/826e3cb16e884c6946a3bf39f92130107dac45d3/src/model_lore.rs#L153

laggui commented 7 months ago

Ok I'll take a look

zemelLeong commented 7 months ago

Sorry, I may have made a mistake in the mapping for downsample. It should be modified as follows:

.with_key_remap("downsample.", "downsample.layers.")
.with_debug_print();
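
To make the effect of the missing trailing dot concrete, here is a minimal sketch that applies both replacements to one key. It assumes plain substring replacement through a hypothetical `remap` helper, whereas the recorder's remap patterns are regular expressions, and the sample key is illustrative rather than taken from the checkpoint.

```rust
/// Stand-in for a key remap using plain substring replacement.
/// Assumption: this ignores regex semantics on purpose to keep the sketch std-only.
fn remap(key: &str, from: &str, to: &str) -> String {
    key.replace(from, to)
}

fn main() {
    let key = "layer1.0.downsample.0.weight"; // illustrative key

    // Without the trailing dot, the index is glued onto `layers`.
    println!("{}", remap(key, "downsample.", "downsample.layers"));
    // prints "layer1.0.downsample.layers0.weight"

    // With the trailing dot, the index stays a separate path segment.
    println!("{}", remap(key, "downsample.", "downsample.layers."));
    // prints "layer1.0.downsample.layers.0.weight"
}
```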

After the modification, I encountered this error and I currently do not have a solution.

thread 'main' panicked at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-tensor\src\tensor\data.rs:385:46:
range end index 4 out of range for slice of length 1
stack backtrace:
   0: std::panicking::begin_panic_handler
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\std\src\panicking.rs:645
   1: core::panicking::panic_fmt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\core\src\panicking.rs:72
   2: core::slice::index::slice_end_index_len_fail_rt

laggui commented 7 months ago

Ahhh ok so it seems the record was loaded properly now, correct?

Mind posting the full trace with RUST_BACKTRACE=1?

zemelLeong commented 7 months ago

It seems to still be an issue with key mapping, I'll take a look at it myself first.

zemelLeong commented 7 months ago

I debugged it and found that the data belongs to the DownSample BatchNorm, but the call stack went through Conv2d. My guess is that, because both modules contain a weight, the deserializer handed the data that belongs to BatchNorm to Conv2d. Below are the definition of DownSample and the error logs.

#[derive(Module, Debug)]
pub enum Layers<B: Backend> {
    Conv2d(Conv2d<B>),
    BatchNorm(BatchNorm<B, 2>),
    ConvTranspose2d(ConvTranspose2d<B>),
}

#[derive(Module, Debug)]
pub struct DownSample<B: Backend> {
    layers: Vec<Layers<B>>,
}

struct DownSampleConfig {
    inplanes: usize,
    planes: usize,
    stride: usize,
}

impl DownSampleConfig {
    pub fn new(inplanes: usize, planes: usize, stride: usize) -> Self {
        Self {
            inplanes,
            planes,
            stride,
        }
    }

    pub fn build<B: Backend>(self, device: &B::Device) -> DownSample<B> {
        let layers = vec![
            Layers::Conv2d(
                Conv2dConfig::new([self.inplanes, self.planes], [1, 1])
                    .with_stride([self.stride, self.stride])
                    .with_bias(false)
                    .with_padding(PaddingConfig2d::Explicit(0, 0))
                    .init(device),
            ),
            Layers::BatchNorm(
                BatchNormConfig::new(self.planes)
                    .with_momentum(BN_MOMENTUM)
                    .init(device),
            ),
        ];
        DownSample { layers }
    }
}

thread 'main' panicked at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-tensor\src\tensor\data.rs:385:46:
range end index 4 out of range for slice of length 1
stack backtrace:
   0: std::panicking::begin_panic_handler
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\std\src\panicking.rs:645
   1: core::panicking::panic_fmt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\core\src\panicking.rs:72
   2: core::slice::index::slice_end_index_len_fail_rt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\core\src\slice\index.rs:76
   3: core::slice::index::slice_end_index_len_fail
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112/library\core\src\slice\index.rs:68
   4: core::slice::index::impl$4::index<usize>
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\slice\index.rs:408
   5: core::slice::index::impl$5::index<usize>
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\slice\index.rs:455
   6: alloc::vec::impl$12::index<usize,core::ops::range::RangeTo<usize>,alloc::alloc::Global>
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\mod.rs:2732
   7: burn_tensor::tensor::data::impl$13::from<f64,4>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-tensor\src\tensor\data.rs:385
   8: core::convert::impl$3::into<burn_tensor::tensor::data::DataSerialize<f64>,burn_tensor::tensor::data::Data<f64,4> >
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\convert\mod.rs:757
   9: burn_tensor::tensor::api::base::Tensor<burn_ndarray::backend::NdArray<f64>,4,burn_tensor::tensor::api::kind::Float>::from_data<burn_ndarray::backend::NdArray<f64>,4,burn_tensor::tensor::api::kind::Float,burn_tensor::tensor::data::DataSerialize<f64> >
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-tensor\src\tensor\api\base.rs:563
  10: burn_core::record::tensor::impl$6::from_item<burn_ndarray::backend::NdArray<f64>,4,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\tensor.rs:100
  11: burn_core::record::primitive::impl$6::from_item<burn_ndarray::backend::NdArray<f64>,4,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:185
  12: burn_core::nn::conv::conv2d::impl$14::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\nn\conv\conv2d.rs:51
  13: table_structure_recognition::lore_detector::impl$33::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\lore_detector.rs:88
  14: burn_core::record::primitive::impl$1::from_item::closure$0<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:45
  15: core::iter::adapters::map::map_fold::closure$0<enum2$<table_structure_recognition::lore_detector::LayersRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings> >,enum2$<table_structure_recognition::lore_detector:
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\adapters\map.rs:84
  16: core::iter::traits::iterator::Iterator::fold<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_detector::LayersRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings> >,alloc::alloc::Global
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:2640
  17: core::iter::adapters::map::impl$2::fold<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_detector::LayersRecordItem<burn_ndarra
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\adapters\map.rs:124
  18: core::iter::traits::iterator::Iterator::for_each<core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_detector::LayersRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecis
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:858
  19: alloc::vec::Vec<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,alloc::alloc::Global>::extend_trusted<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f6
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\mod.rs:2885
  20: alloc::vec::spec_extend::impl$1::spec_extend<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_de
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\spec_extend.rs:26
  21: alloc::vec::spec_from_iter_nested::impl$1::from_iter<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition:
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\spec_from_iter_nested.rs:62
  22: alloc::vec::in_place_collect::impl$1::from_iter<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\in_place_collect.rs:167
  23: alloc::vec::impl$14::from_iter<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_detector::Layers
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\mod.rs:2753
  24: core::iter::traits::iterator::Iterator::collect<core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<enum2$<table_structure_recognition::lore_detector::LayersRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisi
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:2054
  25: burn_core::record::primitive::impl$1::from_item<enum2$<table_structure_recognition::lore_detector::LayersRecord<burn_ndarray::backend::NdArray<f64> > >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:44
  26: table_structure_recognition::lore_detector::impl$39::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\lore_detector.rs:105
  27: burn_core::record::primitive::impl$2::from_item::closure$0<table_structure_recognition::lore_detector::DownSampleRecord<burn_ndarray::backend::NdArray<f64> >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:62
  28: enum2$<core::option::Option<table_structure_recognition::lore_detector::DownSampleRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings> > >::map<table_structure_recognition::lore_detector::DownSampleRecordItem<
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\option.rs:1072
  29: burn_core::record::primitive::impl$2::from_item<table_structure_recognition::lore_detector::DownSampleRecord<burn_ndarray::backend::NdArray<f64> >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:62
  30: table_structure_recognition::lore_detector::impl$51::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\lore_detector.rs:298
  31: burn_core::record::primitive::impl$1::from_item::closure$0<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:45
  32: core::iter::adapters::map::map_fold::closure$0<table_structure_recognition::lore_detector::BasicBlockRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>,table_structure_recognition::lore_detector::BasicBlockR
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\adapters\map.rs:84
  33: core::iter::traits::iterator::Iterator::fold<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::BasicBlockRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>,alloc::alloc::Global>,tup
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:2640
  34: core::iter::adapters::map::impl$2::fold<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::BasicBlockRecordItem<burn_ndarray::backe
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\adapters\map.rs:124
  35: core::iter::traits::iterator::Iterator::for_each<core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::BasicBlockRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecision
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:858
  36: alloc::vec::Vec<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,alloc::alloc::Global>::extend_trusted<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,all
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\mod.rs:2885
  37: alloc::vec::spec_extend::impl$1::spec_extend<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::Basi
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\spec_extend.rs:26
  38: alloc::vec::spec_from_iter_nested::impl$1::from_iter<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detect
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\spec_from_iter_nested.rs:62
  39: alloc::vec::in_place_collect::impl$1::from_iter<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::B
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\in_place_collect.rs:167
  40: alloc::vec::impl$14::from_iter<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::BasicBlockRecordIt
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\alloc\src\vec\mod.rs:2753
  41: core::iter::traits::iterator::Iterator::collect<core::iter::adapters::map::Map<alloc::vec::into_iter::IntoIter<table_structure_recognition::lore_detector::BasicBlockRecordItem<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionS
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\iter\traits\iterator.rs:2054
  42: burn_core::record::primitive::impl$1::from_item<table_structure_recognition::lore_detector::BasicBlockRecord<burn_ndarray::backend::NdArray<f64> >,burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-core\src\record\primitive.rs:44
  43: table_structure_recognition::lore_detector::impl$15::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\lore_detector.rs:12
  44: table_structure_recognition::lore_detector::impl$57::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\lore_detector.rs:420
  45: table_structure_recognition::model_lore::impl$5::from_item<burn_ndarray::backend::NdArray<f64>,burn_core::record::settings::FullPrecisionSettings>
             at .\src\model_lore.rs:13
  46: burn_import::pytorch::recorder::impl$0::load<burn_core::record::settings::FullPrecisionSettings,burn_ndarray::backend::NdArray<f64>,table_structure_recognition::model_lore::LoreModelRecord<burn_ndarray::backend::NdArray<f64> > >
             at C:\Users\11989\.cargo\git\checkouts\burn-178c6829f420dae1\2a721a9\crates\burn-import\src\pytorch\recorder.rs:53
  47: table_structure_recognition::model_lore::LoreModel<burn_ndarray::backend::NdArray<f64> >::new_with<burn_ndarray::backend::NdArray<f64> >
             at .\src\model_lore.rs:157
  48: table_structure_recognition::main
             at .\src\main.rs:13
  49: core::ops::function::FnOnce::call_once<void (*)(),tuple$<> >
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\ops\function.rs:250
  50: core::hint::black_box
             at /rustc/82e1608dfa6e0b5569232559e3d385fea5a93112\library\core\src\hint.rs:286
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
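
Frames 6 to 8 suggest that a shape of length 1 is sliced with `..4` while the data is converted into a rank-4 tensor, which would match the message "range end index 4 out of range for slice of length 1". The sketch below reproduces that step in isolation. The helper `fixed_shape` is hypothetical, the shapes are illustrative, and it uses `get` so that the mismatch shows up as `None` instead of a panic.

```rust
/// Hypothetical helper mirroring the failing step: take the first `D`
/// dimensions of a dynamically sized shape, or `None` if there are fewer.
fn fixed_shape<const D: usize>(shape: &[usize]) -> Option<[usize; D]> {
    let head = shape.get(..D)?;
    let mut dims = [0usize; D];
    dims.copy_from_slice(head);
    Some(dims)
}

fn main() {
    let conv_weight = vec![64, 64, 1, 1]; // rank 4, like a 2D convolution weight
    let bn_weight = vec![64]; // rank 1, like a batch-norm weight

    println!("{:?}", fixed_shape::<4>(&conv_weight)); // prints "Some([64, 64, 1, 1])"
    println!("{:?}", fixed_shape::<4>(&bn_weight)); // prints "None"
}
```

Indexing with `shape[..D]` instead of `shape.get(..D)` would panic on the rank-1 input, which is the behaviour the trace shows.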

zemelLeong commented 7 months ago

It seems that this is the reason. The only solution I can think of right now is to replace every Sequential-like structure in the model with an explicit struct, like the following:

#[derive(Module, Debug)]
pub struct DownSample<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
}

struct DownSampleConfig {
    inplanes: usize,
    planes: usize,
    stride: usize,
}

impl DownSampleConfig {
    pub fn new(inplanes: usize, planes: usize, stride: usize) -> Self {
        Self {
            inplanes,
            planes,
            stride,
        }
    }

    pub fn build<B: Backend>(self, device: &B::Device) -> DownSample<B> {
        DownSample {
            conv: Conv2dConfig::new([self.inplanes, self.planes], [1, 1])
                .with_stride([self.stride, self.stride])
                .with_bias(false)
                .with_padding(PaddingConfig2d::Explicit(0, 0))
                .init(device),
            bn: BatchNormConfig::new(self.planes)
                .with_momentum(BN_MOMENTUM)
                .init(device),
        }
    }
}
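
With named fields, the keys for this module would presumably need to end in `downsample.conv.*` and `downsample.bn.*` instead of positional indices. A minimal sketch of such a rename follows, assuming the source keys use `downsample.0.` and `downsample.1.` and using plain substring replacement in a hypothetical `rename_downsample` helper rather than the recorder's regex remapping.

```rust
/// Hypothetical rename for the struct-based `DownSample`: positional indices
/// become the named fields `conv` and `bn`.
fn rename_downsample(key: &str) -> String {
    key.replace("downsample.0.", "downsample.conv.")
        .replace("downsample.1.", "downsample.bn.")
}

fn main() {
    // Illustrative keys only; they are not copied from the real checkpoint.
    for key in ["layer1.0.downsample.0.weight", "layer1.0.downsample.1.running_mean"] {
        println!("{} -> {}", key, rename_downsample(key));
    }
}
```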

wcshds commented 7 months ago

I previously asked a similar question on Discord, and @antimora responded:

In the book I have described that PyTorch does not export the enum variant name. So the best thing we can do is to match structure (attribute names and types). If you have the exact same structure for variants, it will match first one - limitation you should be aware of. From your example it appears LayerNorm and Conv2 are clashing. I would recommend renaming if the prior context name is identifiable (if you can use regex to remap)

It seems that PyTorchFileRecorder is currently not very flexible when importing weights. My workaround is to export the Burn model to a JSON file using PrettyJsonFileRecorder, modify the JSON file manually according to the pretrained weights in Python, and finally re-import the modified JSON file into Burn. This way I can import weights without modifying the structure of the Burn model.

laggui commented 7 months ago

That's weird, I've been able to successfully import the weights for ResNet and YOLOX, which both contain enums. I haven't had any issues once the key remapping was completed.

But yes, as noted in your reply enums are simply variants that "structurally" will be at the same place in your architecture. So the mapping is based on the contents (untagged), which means that modules with the same parameters might clash.
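
The clash described above can be sketched without any framework code. The matcher below is hypothetical: it tries variants in declaration order and accepts the first one whose required fields are all present. The field names follow PyTorch conventions, and the convolution is assumed to need only `weight` because its bias is disabled in the thread's config.

```rust
#[derive(Debug, PartialEq)]
enum Variant {
    Conv2d,
    BatchNorm,
}

// Assumed required fields per variant, for illustration only.
const CONV2D_FIELDS: &[&str] = &["weight"];
const BATCH_NORM_FIELDS: &[&str] = &["weight", "bias", "running_mean", "running_var"];

/// Hypothetical untagged matcher: first variant whose required fields
/// are all present wins, regardless of extra fields or tensor shapes.
fn first_match(fields: &[&str]) -> Option<Variant> {
    let candidates = [
        (Variant::Conv2d, CONV2D_FIELDS),
        (Variant::BatchNorm, BATCH_NORM_FIELDS),
    ];
    for (variant, required) in candidates {
        if required.iter().all(|name| fields.contains(name)) {
            return Some(variant);
        }
    }
    None
}

fn main() {
    // A batch-norm record also has a `weight`, so the earlier variant claims it.
    let bn_record = ["weight", "bias", "running_mean", "running_var"];
    println!("{:?}", first_match(&bn_record)); // prints "Some(Conv2d)"
}
```

In this sketch, a matcher that also compared tensor ranks, or a model with one named field per module, would avoid the ambiguity.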

laggui commented 7 months ago

It seems that this is the reason. The solution I can think of right now is to change all model structures similar to Sequential to Struct, like the following:

It is quite possible that sequential isn't helping here. I am personally not a big fan of using sequential everywhere because it's just too flexible 😅 I'd rather be more explicit

antimora commented 7 months ago

I haven't dived too deeply into your model but I'd like to bring up a couple of immediate suggestions:

  1. Use more regex patterns to remap items in one go (instead of mapping them one by one), and review your changes with the with_debug_print() option.
  2. The record automatically handles non-contiguous indices in the source model, so you do not need to remap 2 => 1, 4 => 2, etc.
  3. To get around problems with non-unique enum variant structures, try using name prefixes to separate them into standalone modules for straightforward mapping.
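
As an illustration of what point 2 means for remaps such as `model.ax.2` to `model.ax.1`, the sketch below compacts a set of non-contiguous indices by sorted rank. It is a hypothetical helper that shows the idea only, not how Burn implements it.

```rust
use std::collections::BTreeSet;

/// Hypothetical illustration of index compaction: each distinct source index
/// is paired with its rank in sorted order, as `(old, new)`.
fn compact_indices(indices: &[usize]) -> Vec<(usize, usize)> {
    let unique: BTreeSet<usize> = indices.iter().copied().collect();
    unique
        .into_iter()
        .enumerate()
        .map(|(new, old)| (old, new))
        .collect()
}

fn main() {
    // Indices 0, 2, 4, 6, 8 as in the `model.ax.*` remaps earlier in the thread.
    println!("{:?}", compact_indices(&[0, 2, 4, 6, 8]));
    // prints "[(0, 0), (2, 1), (4, 2), (6, 3), (8, 4)]"
}
```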