layer6ai-labs / TabDPT

TabDPT: Scaling Tabular Foundation Models
https://arxiv.org/abs/2410.18164

[interface] Fix a regression numpy/torch error and default value #2

Closed: LennartPurucker closed this pull request 2 days ago.

LennartPurucker commented 1 week ago

Heyho,

This PR includes two small changes to improve the regression interface.

First, I added a default value for context_size in the regressor, matching the default already used for classification.
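A minimal sketch of the pattern, assuming a single shared constant feeds the default for both estimators. The class names (ClassifierSketch, RegressorSketch), the constant DEFAULT_CONTEXT_SIZE, and its value are hypothetical placeholders for illustration, not TabDPT's actual API or default.

```python
# Hypothetical shared default. The value is a placeholder for illustration only;
# the PR does not state the actual TabDPT default.
DEFAULT_CONTEXT_SIZE = 128


class ClassifierSketch:
    """Hypothetical stand-in for the classifier's predict interface."""

    def predict(self, X, context_size: int = DEFAULT_CONTEXT_SIZE) -> int:
        # Returns the context size it would use, to make the default visible.
        return context_size


class RegressorSketch:
    """Hypothetical stand-in for the regressor after the change.

    Before the change, context_size had no default here, so every caller
    had to pass it explicitly.
    """

    def predict(self, X, context_size: int = DEFAULT_CONTEXT_SIZE) -> int:
        return context_size
```

Reading both defaults from one constant keeps the two interfaces from drifting apart if the value is tuned later.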

Second, I fixed an error that occurred when the context was larger than the training data. Previously, this case raised `TypeError: unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor'`, because the predictions were converted to NumPy before being multiplied by a torch tensor.
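A minimal sketch of the failure mode and one way to avoid it, assuming the multiplication is a per-target scaling of the predictions. The function names and the scale argument are hypothetical, since the PR does not show the affected TabDPT code. The fix keeps both operands as tensors for the arithmetic and converts to NumPy only once, at the end.

```python
import numpy as np
import torch


def scale_predictions_buggy(pred: torch.Tensor, scale: torch.Tensor):
    # Converting to NumPy first leaves one ndarray operand and one Tensor operand.
    # NumPy defers to the Tensor, the Tensor rejects the ndarray, and Python raises
    # TypeError: unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor'.
    return pred.detach().cpu().numpy() * scale


def scale_predictions_fixed(pred: torch.Tensor, scale: torch.Tensor) -> np.ndarray:
    # Do the arithmetic while both operands are tensors (same device and dtype),
    # then convert the result to NumPy once for the caller.
    return (pred * scale).detach().cpu().numpy()


if __name__ == "__main__":
    pred = torch.tensor([1.0, 2.0])
    scale = torch.tensor([2.0, 3.0])
    print(scale_predictions_fixed(pred, scale))
```

Converting at a single exit point also avoids silently mixing CPU arrays with tensors that may live on a GPU.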