tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
9.03k stars, 448 forks

Add "load pytorch tensor" section into the burn book #2316

Open · med1844 opened this issue 1 month ago

med1844 commented 1 month ago

Issue based on discussion #2315, @antimora

To the best of my knowledge, here's how to load a tensor saved by PyTorch:

  1. In Python: wrap the tensor in a dict before saving it. The dict key is the name the Rust side must use for the field, e.g.

    torch.save({"some_key": tensor}, "path/to/tensor.pt")
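  2. In Rust: declare a module whose field name matches the dict key, then load the file with `PyTorchFileRecorder`: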

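    // Requires the `burn` crate (with the `ndarray` feature) and `burn-import`.
    use burn::backend::NdArray;
    use burn::module::{Module, Param};
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn::tensor::{backend::Backend, Tensor};
    use burn_import::pytorch::PyTorchFileRecorder;
    
    // The field name must match the dict key used in `torch.save`;
    // `D` is the tensor rank and must match the saved tensor's number of dimensions.
    #[derive(Module, Debug)]
    struct FloatTensor<B: Backend, const D: usize> {
        some_key: Param<Tensor<B, D>>,
    }
    
    fn main() {
        type B = NdArray;
        let device = Default::default();
        // `#[derive(Module)]` generates the matching `FloatTensorRecord` type.
        let record: FloatTensorRecord<B, 3> =
            PyTorchFileRecorder::<FullPrecisionSettings>::new()
                .load("path/to/tensor.pt".into(), &device)
                .expect("failed to load the PyTorch file");
        // Read the tensor from the field named after the dict key.
        let tensor: Tensor<B, 3> = record.some_key.val();
        println!("{}", tensor);
    }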
csking101 commented 1 month ago

Hi, could I take up this issue?

antimora commented 1 month ago

> Hi, could I take up this issue?

Yes! Please go ahead. We would appreciate your contribution. Let me know if you need more info.