krshrimali / Digit-Recognition-MNIST-SVHN-PyTorch-CPP

Implementing a CNN for digit recognition (MNIST and SVHN datasets) using the PyTorch C++ API
https://krshrimali.github.io/PyTorch-C++-API/
MIT License

dataloader to script module #5

Open · kelseyjd opened this issue 9 months ago


Hi, and thank you for this example code. I am trying to solve the same MNIST problem using a PyTorch model that has been converted with torch.jit.script(). However, jit modules take input as torch::Tensor, not at::Tensor, and I haven't found a way to do this type conversion. Is there a way to modify the data loader code so that it outputs a torch::Tensor instead of an at::Tensor?

auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(torch::data::datasets::MNIST("/flash/kelseyd/mnist/MNIST_ORG")
                  .map(torch::data::transforms::Normalize<>(0.13707, 0.3081))
                  .map(torch::data::transforms::Stack<>())),
    64);