automl / TabPFN

Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package.
http://priorlabs.ai
Apache License 2.0

Update transformer_prediction_interface.py #70

Closed: liuquangao closed this pull request 11 months ago

liuquangao commented 11 months ago

Hello, and first of all, thank you so much for your excellent work! I've noticed that although 'models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt' is already available locally, running TabPFN = TabPFNClassifier(device='cpu', N_ensemble_configurations=32) makes the code first look for 'prior_diff_real_checkpoint_n_0_epoch_100.cpkt', which does not exist locally, and this triggers a download. Unfortunately, the download does not succeed, possibly because I am located in China.

To work around this, I changed the 'e' (epoch) parameter in the load_model_workflow() function to 42, so that the existing local epoch-42 checkpoint is loaded and no download is needed.
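Hard-coding e=42 works for this setup, but a more general approach could prefer the requested checkpoint and fall back to whatever compatible checkpoint already exists locally. The following is a minimal sketch of that idea, not TabPFN's actual loading code: the helper resolve_local_checkpoint and the filename pattern are hypothetical, inferred only from the checkpoint names quoted in this PR. It returns None when nothing suitable is on disk, which is the point where a download would still be required.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Hypothetical pattern based on the file names mentioned in this PR,
# e.g. "prior_diff_real_checkpoint_n_0_epoch_42.cpkt".
CHECKPOINT_RE = re.compile(r"prior_diff_real_checkpoint_n_(\d+)_epoch_(\d+)\.cpkt$")


def resolve_local_checkpoint(model_dir: Path, wanted_epoch: int, n: int = 0) -> Optional[Path]:
    """Return the requested checkpoint if it exists locally; otherwise the
    highest-epoch local checkpoint with the same n; otherwise None
    (meaning the caller would still have to download one)."""
    wanted = model_dir / f"prior_diff_real_checkpoint_n_{n}_epoch_{wanted_epoch}.cpkt"
    if wanted.is_file():
        return wanted

    # Collect (epoch, path) pairs for local checkpoints with the same n.
    candidates = []
    for path in model_dir.glob("prior_diff_real_checkpoint_n_*_epoch_*.cpkt"):
        match = CHECKPOINT_RE.search(path.name)
        if match and int(match.group(1)) == n:
            candidates.append((int(match.group(2)), path))

    if not candidates:
        return None
    # Pick the checkpoint with the highest epoch number.
    return max(candidates)[1]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        models = Path(tmp)
        (models / "prior_diff_real_checkpoint_n_0_epoch_42.cpkt").touch()
        # The epoch-100 file is missing, so the local epoch-42 file is chosen.
        print(resolve_local_checkpoint(models, wanted_epoch=100).name)
```

Falling back to the highest local epoch keeps the preferred (epoch-100) checkpoint as the default whenever it is present, while still avoiding a network request when only an older checkpoint is available.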