charles-r-earp / autograph

A machine learning library for Rust.
Apache License 2.0
315 stars 17 forks source link

Added Sequential layer for reduced boilerplate. #28

Closed: charles-r-earp closed this pull request 4 years ago.

charles-r-earp commented 4 years ago

Addresses #21. Combined with the `flatten_layer` and `forward_requires_layer` branches, this lets you define a whole model as a single builder chain, for example:

fn lenet5(device: &Device) -> impl Forward<Ix4, OutputDim=Ix2> {
    Sequential::builder()
        .layer(
            Conv2d::builder()
                .device(&device)
                .inputs(1)
                .outputs(6)
                .kernel(5)
                .build();
        )
        .layer(Relu::default())
        .layer(
            MaxPool2d::builder()
                .args(
                    Pool2dArgs::default()
                        .kernel(2)
                        .strides(2)
                )
                .build()
        )
        .layer(
            Conv2d::builder()
                .device(&device)
                .inputs(6)
                .outputs(16)
                .kernel(5)
                .build()
            )
        )
        .layer(Relu::default())
        .layer(
            MaxPool2d::builder()
                .args(
                    Pool2dArgs::default()
                        .kernel(2)
                        .strides(2)
                )
                .build()
        )
        .layer(Flatten::default())
        .layer(
            Dense::builder()
                .device(&device)
                .inputs(256)
                .outputs(120)
                .build()
        )
        .layer(Relu::default())
        .layer(
            Dense::builder()
            .device(&device)
            .inputs(120)
            .outputs(84)
            .build()
        )
        .layer(Relu::default())
        .layer(
            Dense::builder()
                .device(&device)
                .inputs(84)
                .outputs(10)
                .bias()
                .build()
        )
        .build()
}