coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Unclear how to handle error type in `dfdx::nn::LoadFromNpz::load` #900

Open hovinen opened 9 months ago

hovinen commented 9 months ago

The return type of `dfdx::nn::LoadFromNpz::load` is `Result<(), NpzError>`. `NpzError` is defined inside a private module of dfdx (`tensor::numpy`), so it's not clear how downstream code can use this method:

```text
error[E0603]: module `numpy` is private
   --> src/abstract_model.rs:4:25
    |
4   |     prelude::{........, numpy::NpzError, ...},
    |                         ^^^^^  -------- enum `NpzError` is not publicly re-exported
    |                         |
    |                         private module
    |
note: the module `numpy` is defined here
   --> /.../dfdx-0.13.0/src/lib.rs:202:13
    |
202 |     pub use crate::tensor::*;
    |             ^^^^^^^^^^^^^
```
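A minimal sketch of one way to propagate the error without ever naming its type, assuming (not confirmed here) that the real `NpzError` implements `std::error::Error`: `?` converts any `E: Error` into `Box<dyn Error>`, so the private module path never has to be written. The module `fake_dfdx`, its `load` function, the `NotAnNpzFile` variant and `load_model` are hypothetical stand-ins that only mimic the private-module layout shown in the error above.

```rust
use std::error::Error;

// Hypothetical stand-in for dfdx's layout: `NpzError` is `pub`, but it lives
// in a private module, so code outside `fake_dfdx` cannot spell its path.
mod fake_dfdx {
    mod numpy {
        use std::fmt;

        #[derive(Debug)]
        pub enum NpzError {
            NotAnNpzFile,
        }

        // Assumption: the real `NpzError` also implements `Display` and `Error`.
        impl fmt::Display for NpzError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "not an .npz file")
            }
        }

        impl std::error::Error for NpzError {}
    }

    // Hypothetical stand-in for `model.load(path)`.
    pub fn load(path: &str) -> Result<(), numpy::NpzError> {
        if path.ends_with(".npz") {
            Ok(())
        } else {
            Err(numpy::NpzError::NotAnNpzFile)
        }
    }
}

// `?` boxes any `E: Error`, so the concrete error type is never named here.
fn load_model(path: &str) -> Result<(), Box<dyn Error>> {
    fake_dfdx::load(path)?;
    Ok(())
}

fn main() {
    if let Err(e) = load_model("weights.bin") {
        println!("load failed: {e}");
    }
}
```

The trade-off is that callers lose the concrete type and can only inspect the error through `Display` or `Debug`.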

Normally I would convert it to `std::io::Error`, but the required `From` implementation is missing:

```text
error[E0277]: `?` couldn't convert the error to `std::io::Error`
   --> src/abstract_model.rs:268:25
    |
268 |         model.load(path)?;
    |                         ^ the trait `From<tensor::numpy::NpzError>` is not implemented for `std::io::Error`
    |
    = note: the question mark operation (`?`) implicitly performs a conversion on the error value using the `From` trait
    = help: the following other types implement trait `From<T>`:
              <std::io::Error as From<rand::Error>>
              <std::io::Error as From<getrandom::error::Error>>
              <std::io::Error as From<zip::result::ZipError>>
              <std::io::Error as From<NulError>>
              <std::io::Error as From<IntoInnerError<W>>>
              <std::io::Error as From<ErrorKind>>
    = note: required for `Result<NeuralNetwork<Model, N_FEATURES, N_ACTIONS>, std::io::Error>` to implement `FromResidual<Result<Infallible, tensor::numpy::NpzError>>`
```
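If the enclosing function has to return `std::io::Error`, one workaround is to convert explicitly with `map_err` before applying `?`, using a helper that is generic over the error type so the type is never named. A minimal sketch, assuming only that the real `NpzError` implements `Debug` (which calling `.unwrap()` or `.expect()` on `load`'s result would also require); `fake_dfdx`, `to_io_error` and `load_model` are hypothetical names.

```rust
use std::fmt::Debug;
use std::io;

// Hypothetical stand-in for dfdx: `NpzError` is public but sits in a private
// module, so downstream code cannot name it in a `From` impl.
mod fake_dfdx {
    mod numpy {
        #[derive(Debug)]
        pub enum NpzError {
            NotAnNpzFile,
        }
    }

    // Hypothetical stand-in for `model.load(path)`.
    pub fn load(path: &str) -> Result<(), numpy::NpzError> {
        if path.ends_with(".npz") {
            Ok(())
        } else {
            Err(numpy::NpzError::NotAnNpzFile)
        }
    }
}

/// Converts any `Debug` error into an `io::Error` without naming its type.
fn to_io_error<E: Debug>(e: E) -> io::Error {
    io::Error::new(io::ErrorKind::Other, format!("{e:?}"))
}

fn load_model(path: &str) -> Result<(), io::Error> {
    // The explicit conversion stands in for the missing `From` impl.
    fake_dfdx::load(path).map_err(to_io_error)?;
    Ok(())
}

fn main() {
    if let Err(e) = load_model("weights.bin") {
        println!("load failed: {e}");
    }
}
```

On Rust 1.74 or newer, `io::Error::other(...)` is a shorter equivalent of `io::Error::new(io::ErrorKind::Other, ...)`. Either way, the format error is reported under `ErrorKind::Other`, so callers cannot distinguish it from other uncategorised I/O failures by kind alone.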

The documentation also doesn't show how to propagate the error to the caller: the example doesn't indicate what return type the enclosing function uses.
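To pass the error up with a concrete return type, one option is an application-level error enum that stores the dfdx error boxed. A minimal sketch, assuming the real `NpzError` implements `Error + Send + Sync + 'static`; `ModelError`, `load_model`, the sidecar config-file step and `fake_dfdx` are all hypothetical.

```rust
use std::error::Error;
use std::fmt;
use std::fs;
use std::io;

// Hypothetical stand-in for dfdx: `NpzError` is public but lives in a private module.
mod fake_dfdx {
    mod numpy {
        use std::fmt;

        #[derive(Debug)]
        pub enum NpzError {
            NotAnNpzFile,
        }

        // Assumption: the real `NpzError` implements `Display` and `Error`.
        impl fmt::Display for NpzError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "not an .npz file")
            }
        }

        impl std::error::Error for NpzError {}
    }

    // Hypothetical stand-in for `model.load(path)`.
    pub fn load(path: &str) -> Result<(), numpy::NpzError> {
        if path.ends_with(".npz") {
            Ok(())
        } else {
            Err(numpy::NpzError::NotAnNpzFile)
        }
    }
}

/// Hypothetical application error type returned to callers.
#[derive(Debug)]
enum ModelError {
    Io(io::Error),
    // The dfdx error is boxed because its concrete type cannot be named here.
    Load(Box<dyn Error + Send + Sync>),
}

impl fmt::Display for ModelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ModelError::Io(e) => write!(f, "I/O error: {e}"),
            ModelError::Load(e) => write!(f, "failed to load weights: {e}"),
        }
    }
}

impl Error for ModelError {}

// Lets `?` convert plain I/O errors automatically.
impl From<io::Error> for ModelError {
    fn from(e: io::Error) -> Self {
        ModelError::Io(e)
    }
}

/// Loads weights, then reads a hypothetical sidecar config file.
fn load_model(weights: &str, config: &str) -> Result<String, ModelError> {
    // A `From` impl would have to name `NpzError`, so convert explicitly.
    fake_dfdx::load(weights).map_err(|e| ModelError::Load(Box::new(e)))?;
    let cfg = fs::read_to_string(config)?;
    Ok(cfg)
}

fn main() {
    match load_model("weights.bin", "model.toml") {
        Ok(_) => println!("loaded"),
        Err(e) => println!("{e}"),
    }
}
```

Errors whose types can be named, such as `io::Error`, still flow through `?` via `From`; only the dfdx error needs the explicit `map_err`.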

How does one use the `load()` method in downstream code?