rasbt / machine-learning-book

Code Repository for Machine Learning with PyTorch and Scikit-Learn
https://sebastianraschka.com/books/#machine-learning-with-pytorch-and-scikit-learn
MIT License

Chapter 2 Part 2 pred.long() throws error: "log_softmax_lastdim_kernel_impl" not implemented for 'Long' #144

Closed TyTorch closed 1 year ago

TyTorch commented 1 year ago

The model created to perform classification on the iris dataset uses a softmax function as the activation function of its final layer. I am not able to call the long() function on its output. It does work when I use float instead.

The following is the error log I get when copying and pasting the code into Google Colab:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-4b6724b323fe> in <cell line: 57>()
     59     for x_batch, y_batch in train_dl:
     60         pred = model(x_batch)
---> 61         loss = loss_fn(pred.long(), y_batch)
     62         loss.backward()
     63         optimizer.step()

2 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027     if size_average is not None or reduce is not None:
   3028         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3030 
   3031 

RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'

rasbt commented 1 year ago

Thanks for the note. I am not sure why there is a .long() call. Like you said, it should be a float.

rasbt commented 1 year ago

Just fixed this.

For some reason, the ch12.py script file had a

    loss = loss_fn(pred.long(), y_batch)

instead of

    loss = loss_fn(pred, y_batch.long())

The ch12.ipynb file was ok.
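
For anyone hitting the same error, here is a minimal sketch of why the cast belongs on the labels rather than on the predictions. It is not the book's training loop: the tensors `pred` and `y_batch` below are hypothetical stand-ins for one mini-batch (4 samples, 3 classes). The loss computes a log-softmax over the predictions, which needs a floating-point tensor, while class-index targets are expected as 64-bit integers.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

# Hypothetical stand-ins for one mini-batch (not the book's data):
# 4 samples, 3 classes
pred = torch.randn(4, 3, requires_grad=True)              # float32 model outputs
y_batch = torch.tensor([0, 2, 1, 2], dtype=torch.int32)   # integer class labels

loss_fn = nn.CrossEntropyLoss()

# Buggy variant: casting the predictions to long makes the internal
# log-softmax fail, because it is not implemented for integer tensors.
try:
    loss_fn(pred.long(), y_batch.long())
    cast_pred_failed = False
except RuntimeError as err:
    cast_pred_failed = True
    print("Casting pred fails:", err)

# Fixed variant: keep the predictions as floats, cast the labels to long.
loss = loss_fn(pred, y_batch.long())
loss.backward()  # gradients flow back into the float predictions

print("loss dtype:", loss.dtype)
print("grad shape:", tuple(pred.grad.shape))
```

Note that casting the predictions to an integer type would also truncate their values and cut them off from the autograd graph, so even if the call did not raise, `loss.backward()` could not update the model.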