potocpav / npy-rs

NumPy file format (de-)serialization in Rust
30 stars 7 forks source link

How to serialize multi-dimensional array? #5

Closed 0b01 closed 6 years ago

0b01 commented 6 years ago

So I have a 3D array of shape [batch_size, time_step, input_dim]. What is the easiest way to serialize this?

This is what I currently have:

#[macro_use] extern crate npy_derive;
extern crate npy;
extern crate byteorder;

use std::io::{Write, Read};
use npy::{DType,Serializable};
use byteorder::{WriteBytesExt, LittleEndian};
use byteorder::ByteOrder;

/// [batch_size, time_steps, input_dimensions]
/// 
#[derive(NpyRecord, Debug, PartialEq, Clone)]
struct Input {
    a: i32,
    b: f32,
}

// steps

#[derive(NpyRecord)]
struct Steps {
    steps: Steps_100
}

#[derive(Debug)]
struct Steps_100(Vec<Input>);

impl Serializable for Steps_100 {
    fn dtype() -> DType {
        DType { ty: "<i4,f4", shape: vec![2] }
    }

    fn n_bytes() -> usize { 100 * 2 * 4 }

    fn read(buf: &[u8]) -> Self {
        let mut ret = Steps_100(vec![]);
        for _ in 0..100 {
            ret.0.push( Input {
                a: LittleEndian::read_i32(buf),
                b: LittleEndian::read_f32(buf)
            });
        }
        ret
    }

    fn write<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
        for i in 0..100 {
            writer.write_i32::<LittleEndian>(self.0[i].a)?;
            writer.write_f32::<LittleEndian>(self.0[i].b)?;
        }
        Ok(())
    }
}

fn main() {
    let pi = std::f32::consts::PI;
    let mut batches = vec![];
    for i in 0..100i32 {
        let mut time_steps = vec![];
        for j in 0..360i32 {
            time_steps.push(Input { a: j, b: (j as f32 * pi / 180.0).sin() });
        }
        batches.push(Steps{ steps: Steps_100(time_steps) });
    }

    npy::to_file("roundtrip.npy", batches).unwrap();

}

fn read() {
    // let mut buf = vec![];
    // std::fs::File::open("roundtrip.npy").unwrap()
    //     .read_to_end(&mut buf).unwrap();

    // for (i, arr) in npy::NpyData::from_bytes(&buf).unwrap().into_iter().enumerate() {
    //     assert_eq!(Input { a: i as i32, b: (i as f32 * pi / 180.0).sin() }, arr);
    // }
}

However, this outputs:

[([( 0,  0.        ), ( 1,  0.01745241)],)
 ([( 2,  0.0348995 ), ( 3,  0.05233596)],)
 ([( 4,  0.06975647), ( 5,  0.08715574)],)
 ([( 6,  0.10452846), ( 7,  0.12186935)],)
 ([( 8,  0.13917311), ( 9,  0.15643448)],)
 ([(10,  0.17364819), (11,  0.19080901)],)
 ([(12,  0.2079117 ), (13,  0.22495106)],)
 ([(14,  0.24192192), (15,  0.25881904)],)
 ([(16,  0.27563736), (17,  0.29237172)],)
 ([(18,  0.309017  ), (19,  0.32556814)],)
 ([(20,  0.34202015), (21,  0.35836795)],)
 ([(22,  0.37460661), (23,  0.39073113)],)
 ([(24,  0.40673664), (25,  0.42261827)],)
 ([(26,  0.43837118), (27,  0.45399052)],)
 ([(28,  0.4694716 ), (29,  0.48480961)],)
 ([(30,  0.5       ), (31,  0.51503813)],)
 ([(32,  0.52991927), (33,  0.54463905)],)
 ([(34,  0.55919296), (35,  0.57357651)],)
 ([(36,  0.58778524), (37,  0.60181504)],)
 ([(38,  0.6156615 ), (39,  0.62932038)],)
 ([(40,  0.64278764), (41,  0.65605903)],)
 ([(42,  0.66913062), (43,  0.68199837)],)
 ([(44,  0.6946584 ), (45,  0.70710677)],)
 ([(46,  0.71933979), (47,  0.73135376)],)
 ([(48,  0.74314487), (49,  0.7547096 )],)
 ([(50,  0.76604444), (51,  0.77714604)],)
 ([(52,  0.78801078), (53,  0.79863554)],)
 ([(54,  0.809017  ), (55,  0.81915206)],)
 ([(56,  0.82903761), (57,  0.83867061)],)
 ([(58,  0.84804809), (59,  0.8571673 )],)
 ([(60,  0.86602545), (61,  0.87461972)],)
 ([(62,  0.88294762), (63,  0.89100659)],)
 ([(64,  0.89879405), (65,  0.90630782)],)
 ([(66,  0.91354549), (67,  0.92050487)],)
 ([(68,  0.92718387), (69,  0.9335804 )],)
 ([(70,  0.93969268), (71,  0.94551855)],)
 ([(72,  0.95105654), (73,  0.95630479)],)
 ([(74,  0.96126169), (75,  0.96592587)],)
 ([(76,  0.97029573), (77,  0.97437006)],)
 ([(78,  0.97814763), (79,  0.98162717)],)
 ([(80,  0.98480779), (81,  0.98768836)],)
 ([(82,  0.99026805), (83,  0.99254614)],)
 ([(84,  0.99452192), (85,  0.99619472)],)
 ([(86,  0.99756408), (87,  0.99862951)],)
 ([(88,  0.99939084), (89,  0.99984771)],)
 ([(90,  1.        ), (91,  0.99984771)],)
 ([(92,  0.99939084), (93,  0.99862951)],)
 ([(94,  0.99756402), (95,  0.99619466)],)
 ([(96,  0.99452192), (97,  0.99254614)],)
 ([(98,  0.99026805), (99,  0.98768836)],)
 ([( 0,  0.        ), ( 1,  0.01745241)],)
 ([( 2,  0.0348995 ), ( 3,  0.05233596)],)
 ([( 4,  0.06975647), ( 5,  0.08715574)],)
 ([( 6,  0.10452846), ( 7,  0.12186935)],)
 ([( 8,  0.13917311), ( 9,  0.15643448)],)
 ([(10,  0.17364819), (11,  0.19080901)],)
 ([(12,  0.2079117 ), (13,  0.22495106)],)
 ([(14,  0.24192192), (15,  0.25881904)],)
 ([(16,  0.27563736), (17,  0.29237172)],)
 ([(18,  0.309017  ), (19,  0.32556814)],)
 ([(20,  0.34202015), (21,  0.35836795)],)
 ([(22,  0.37460661), (23,  0.39073113)],)
 ([(24,  0.40673664), (25,  0.42261827)],)
 ([(26,  0.43837118), (27,  0.45399052)],)
 ([(28,  0.4694716 ), (29,  0.48480961)],)
 ([(30,  0.5       ), (31,  0.51503813)],)
 ([(32,  0.52991927), (33,  0.54463905)],)
 ([(34,  0.55919296), (35,  0.57357651)],)
 ([(36,  0.58778524), (37,  0.60181504)],)
 ([(38,  0.6156615 ), (39,  0.62932038)],)
 ([(40,  0.64278764), (41,  0.65605903)],)
 ([(42,  0.66913062), (43,  0.68199837)],)
 ([(44,  0.6946584 ), (45,  0.70710677)],)
 ([(46,  0.71933979), (47,  0.73135376)],)
 ([(48,  0.74314487), (49,  0.7547096 )],)
 ([(50,  0.76604444), (51,  0.77714604)],)
 ([(52,  0.78801078), (53,  0.79863554)],)
 ([(54,  0.809017  ), (55,  0.81915206)],)
 ([(56,  0.82903761), (57,  0.83867061)],)
 ([(58,  0.84804809), (59,  0.8571673 )],)
 ([(60,  0.86602545), (61,  0.87461972)],)
 ([(62,  0.88294762), (63,  0.89100659)],)
 ([(64,  0.89879405), (65,  0.90630782)],)
 ([(66,  0.91354549), (67,  0.92050487)],)
 ([(68,  0.92718387), (69,  0.9335804 )],)
 ([(70,  0.93969268), (71,  0.94551855)],)
 ([(72,  0.95105654), (73,  0.95630479)],)
 ([(74,  0.96126169), (75,  0.96592587)],)
 ([(76,  0.97029573), (77,  0.97437006)],)
 ([(78,  0.97814763), (79,  0.98162717)],)
 ([(80,  0.98480779), (81,  0.98768836)],)
 ([(82,  0.99026805), (83,  0.99254614)],)
 ([(84,  0.99452192), (85,  0.99619472)],)
 ([(86,  0.99756408), (87,  0.99862951)],)
 ([(88,  0.99939084), (89,  0.99984771)],)
 ([(90,  1.        ), (91,  0.99984771)],)
 ([(92,  0.99939084), (93,  0.99862951)],)
 ([(94,  0.99756402), (95,  0.99619466)],)
 ([(96,  0.99452192), (97,  0.99254614)],)
 ([(98,  0.99026805), (99,  0.98768836)],)]
0b01 commented 6 years ago

I wrote a blog article on how to serialize to .npy format: http://rickyhan.com/jekyll/update/2017/11/16/preparing-training-set-with-rust-rayon-npy-format.html