AlliedToasters / dfencoder

BSD 3-Clause "New" or "Revised" License
129 stars, 49 forks

Demo, model.fit, RuntimeError: scatter(): Expected dtype int64 for index #15

Open: sergeyburnos opened this issue 2 years ago

sergeyburnos commented 2 years ago

Hi Michael,

I started using your great library, dfencoder, but I am having trouble running the demo. As you advised, I created a virtual environment and ran the demo code there. I tried several Python 3.6 versions (3.6.2, 3.6.8 and 3.6.10). What could be the issue?

The error already occurs on line 5, `model.fit(X_train, epochs=10, val=X_val)`:

```
RuntimeError                              Traceback (most recent call last)
in
----> 1 model.fit(X_train, epochs=10, val=X_val)

C:\ProgramData\Anaconda3\lib\site-packages\dfencoder\autoencoder.py in fit(self, df, epochs, val)
    674     msg = "Validating during training.\n"
    675     msg += "Computing baseline performance..."
--> 676     baseline = self.compute_baseline_performance(val_in, val_df)
    677     if self.verbose:
    678         print(msg)

C:\ProgramData\Anaconda3\lib\site-packages\dfencoder\autoencoder.py in compute_baseline_performance(self, in_, out_)
    648     feature = list(self.categorical_fts.items())[i][1]
    649     dim = len(feature['cats']) + 1
--> 650     pred = ohe(cd, dim, device=self.device) * 5
    651     codes_pred.append(pred)
    652     mse_loss, bce_loss, cce_loss, net_loss = self.compute_loss(

C:\ProgramData\Anaconda3\lib\site-packages\dfencoder\autoencoder.py in ohe(input_vector, dim, device)
     34
     35     y_onehot.zero_()
---> 36     y_onehot.scatter_(1, y, 1)
     37
     38     return y_onehot

RuntimeError: scatter(): Expected dtype int64 for index
```
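The last frame shows that `Tensor.scatter_` received an index tensor whose dtype is not int64 (`torch.long`), which is the dtype `scatter_` expects for its index. One plausible cause, offered only as a guess, is that the category codes reaching `ohe` were built as 32-bit integers; on Windows, NumPy's default integer type was 32-bit before NumPy 2.0. A common workaround is to cast the index with `.long()` before scattering. The sketch below is a hypothetical patched one-hot helper written for illustration, not dfencoder's actual `ohe`; it assumes the input is a 1-D tensor of integer category codes.

```python
import torch


def ohe(input_vector, dim, device="cpu"):
    """One-hot encode a 1-D tensor of integer category codes.

    Hypothetical patched helper (not dfencoder's real code): the index is
    cast to int64 because Tensor.scatter_ expects an int64 index tensor.
    """
    # Cast to int64 and shape as a column of indices, one per row.
    y = input_vector.long().reshape(-1, 1).to(device)
    # Float matrix of zeros: one row per input code, one column per category.
    y_onehot = torch.zeros(y.size(0), dim, device=device)
    # Along dim 1, write a 1 into column y[i] of row i.
    y_onehot.scatter_(1, y, 1)
    return y_onehot


# Assumed input: int32 codes, the kind of dtype that triggers the error.
codes = torch.tensor([0, 2, 1], dtype=torch.int32)
print(ohe(codes, dim=3))
```

Casting inside the helper keeps every caller working regardless of which integer dtype the codes arrive in. The alternative is to make sure the codes are created as int64 upstream.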