silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Give a simple example of classification? #68

Open · ghost opened this issue 3 years ago

ghost commented 3 years ago

Hi @silvandeleemput,

Thanks for your code! Could you give a simple example of how to do classification using memcnn?

silvandeleemput commented 3 years ago

Hi @cubicgate, thanks for your interest in MemCNN. Your question is very broad; do you have a specific application in mind? Are you interested in inference, training, or both?

You can use the RevNet implementations as classification examples.

Below is a simple example for instantiating a RevNet36 model and performing prediction on some random input data:


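```python
import torch
from memcnn.models.resnet import ResNet, RevBasicBlock

# RevNet36 configuration for small 32x32 inputs (e.g. CIFAR-10)
revnet36 = ResNet(
    block=RevBasicBlock,
    layers=[3, 3, 3],
    channels_per_layer=[32, 32, 64, 112],
    strides=[1, 1, 2, 2],
    init_max_pool=False,
    init_kernel_size=3,
    batch_norm_fix=False
)

# use the revnet36 model for training and validation as you would normally do

# inference: switch to eval mode and disable gradient tracking
revnet36.eval()
with torch.no_grad():
    x = torch.rand(2, 3, 32, 32)  # random batch of 2 RGB images, 32x32 pixels
    y = revnet36(x)               # class scores (logits), one row per image
    pred = y.argmax(dim=1)        # predicted class index for each image
```
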
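The snippet leaves training to the usual PyTorch workflow. Below is a minimal sketch of that workflow, assuming a standard cross-entropy classification setup. The helper names `train_one_epoch` and `evaluate` are hypothetical, the SGD settings are arbitrary, and the random `TensorDataset` is a placeholder for a real dataset such as CIFAR-10. To keep the sketch runnable without memcnn installed, it uses a tiny stand-in `nn.Sequential` classifier; the `revnet36` model above should be usable in its place, as long as its output has one score per class in your dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Hypothetical helper: one epoch of supervised training, returns mean loss."""
    model.train()  # training mode (e.g. batch-norm uses and updates batch statistics)
    total_loss, num_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)                  # shape: (batch, num_classes)
        loss = F.cross_entropy(logits, labels)  # expects raw logits, not softmax
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        num_samples += images.size(0)
    return total_loss / num_samples


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Hypothetical helper: top-1 accuracy of `model` over `loader`."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # predicted class index per sample
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total


torch.manual_seed(0)
# Stand-in classifier (assumption, not part of memcnn): maps (N, 3, 32, 32)
# images to (N, 10) logits; swap in revnet36 for the real model.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
# Random CIFAR-10-shaped placeholder data: 64 images, 10 classes
data = TensorDataset(torch.rand(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

loss = train_one_epoch(model, loader, optimizer)
acc = evaluate(model, loader)
print(f"loss={loss:.3f} acc={acc:.3f}")
```

Because `F.cross_entropy` applies log-softmax internally, the model should return raw logits; ending the network with an extra softmax layer would distort the loss and slow down training.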
ghost commented 3 years ago

Thanks @silvandeleemput for this example! It is very helpful!