fastai / fastai2

Temporary home for fastai v2 while it's being developed
https://dev.fast.ai
Apache License 2.0

latest fastai2 v0.0.24 learner summary() RuntimeError: Input type (torch.cuda.FloatTensor) != weight type (torch.FloatTensor) #517

Closed bguan closed 4 years ago

bguan commented 4 years ago

The bug can be reproduced with the latest fastai2 v0.0.24 using this Colab notebook:

https://colab.research.google.com/gist/bguan/2e2fd854b12d26e3d14e05215c9b133b/fastai2-test.ipynb

This works in v0.0.23 but breaks in v0.0.24.

With today's (2020 Aug 9th) GitHub pull of the latest fastai2 v0.0.24, learner.summary() raises RuntimeError: "Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same"

To Reproduce

Steps to reproduce the behavior:

  1. Make sure you have v0.0.24
  2. Create a CNN Learner (see the Colab gist above for an example)
  3. Call learner.summary()

Expected behavior

A nicely formatted summary of the model should be returned.

Error with full stack trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-bc39e9e85f86> in <module>
----> 1 learn.summary()

~/Projects/fastai2/fastai2/callback/hook.py in summary(self)
    185     "Print a summary of the model, optimizer and loss function."
    186     xb = self.dls.train.one_batch()[:self.dls.train.n_inp]
--> 187     res = module_summary(self.model, *xb)
    188     res += f"Optimizer used: {self.opt_func}\nLoss function: {self.loss_func}\n\n"
    189     if self.opt is not None:

~/Projects/fastai2/fastai2/callback/hook.py in module_summary(self, *xb)
    160 def module_summary(self, *xb):
    161     "Print a summary of `self` using `xb`"
--> 162     sample_inputs,infos = layer_info(self, *xb)
    163     n,bs = 64,find_bs(xb)
    164     inp_sz = _print_shapes(apply(lambda x:x.shape, xb), bs)

~/Projects/fastai2/fastai2/callback/hook.py in layer_info(model, *xb)
    149     layers = [m for m in flatten_model(model)]
    150     with Hooks(layers, _track) as h:
--> 151         _ = model.eval()(*apply(lambda o:o[:1], xb))
    152         return xb,h.stored
    153 

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    417 
    418     def forward(self, input: Tensor) -> Tensor:
--> 419         return self._conv_forward(input, self.weight)
    420 
    421 class Conv3d(_ConvNd):

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    413                             weight, self.bias, self.stride,
    414                             _pair(0), self.dilation, self.groups)
--> 415         return F.conv2d(input, weight, self.bias, self.stride,
    416                         self.padding, self.dilation, self.groups)
    417 

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Additional context

Learner training seems to work fine despite this issue.

The issue can be sidestepped by explicitly pushing the model to CUDA, e.g.

learn.model.cuda()

learn.summary()

This wasn't necessary in earlier versions.

Ideas for fix
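The traceback shows the mismatch: the batch from dls.train.one_batch() is on CUDA while the model's weights are still on CPU when layer_info runs the forward pass. One possible direction (purely a sketch, not the actual patch) would be to move the sample batch onto the same device as the model's weights before calling forward. The classes below are hypothetical torch-free stand-ins that illustrate the idea, not fastai's real code:

```python
# Hypothetical device-alignment sketch.  FakeTensor/FakeModel are
# torch-free stand-ins that only track which device data lives on.

class FakeTensor:
    """Stand-in for torch.Tensor: remembers its device, supports .to()."""
    def __init__(self, device):
        self.device = device
    def to(self, device):
        return FakeTensor(device)

class FakeModel:
    """Stand-in for nn.Module: its weights live on a single device."""
    def __init__(self, device):
        self.weight = FakeTensor(device)
    def parameter_device(self):
        return self.weight.device

def align_batch_to_model(model, *xb):
    "Move every input tensor to the device the model's weights are on."
    dev = model.parameter_device()
    return tuple(x.to(dev) for x in xb)

model = FakeModel("cpu")        # weights still on CPU, as in the bug
batch = (FakeTensor("cuda"),)   # dls delivered a CUDA batch
xb = align_batch_to_model(model, *batch)
print(xb[0].device)             # prints "cpu": no more device mismatch
```

In real code the same one-liner shape would be something like moving each input to `next(model.parameters()).device` before the forward pass, but which side to move (batch to model, or model to dls.device) is a design choice for the maintainers.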

jph00 commented 4 years ago

Thanks for the helpful bug report. We call forward so we can find out the size of the activations at each stage of the model.

This is now fixed in master. Will do a new release soon.
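The shape-tracking jph00 describes, where fastai's layer_info attaches hooks and runs a forward pass because output shapes are only known once each layer actually executes, can be sketched with a torch-free stand-in. All names below are illustrative (mimicking torch's register_forward_hook mechanism), not fastai's actual implementation:

```python
# Torch-free sketch of hook-based shape tracking: "tensors" are nested
# lists, and each Layer fires its forward hooks after computing output.

def shape(t):
    "Shape of a nested list, e.g. [[1, 2, 3]] -> (1, 3)."
    s = []
    while isinstance(t, list):
        s.append(len(t))
        t = t[0]
    return tuple(s)

class Layer:
    "Stand-in module: wraps a function and supports forward hooks."
    def __init__(self, fn):
        self.fn, self.hooks = fn, []
    def __call__(self, x):
        out = self.fn(x)
        for h in self.hooks:          # torch-style (module, input, output)
            h(self, x, out)
        return out

def layer_info(layers, x):
    "Run one forward pass, recording each layer's output shape via hooks."
    infos = []
    for l in layers:
        l.hooks.append(lambda m, i, o: infos.append(shape(o)))
    for l in layers:
        x = l(x)
    return infos

# A toy "model": one layer that drops the last element of each row.
layers = [Layer(lambda x: [row[:-1] for row in x])]
print(layer_info(layers, [[1, 2, 3], [4, 5, 6]]))  # prints [(2, 2)]
```

This is why the device mismatch surfaces in summary() even though training works: summary() performs a real forward pass on a sample batch just to observe these shapes, so model and batch must already agree on device at that point.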