automl / TabPFN

Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package.
http://priorlabs.ai
Apache License 2.0

Synthetic Data Generator Issue #53

moonrabbit12 closed this issue 1 year ago

moonrabbit12 commented 1 year ago

I'm trying to experiment with the synthetic data generator, but running train.py fails with the following error:

Traceback (most recent call last):
  File "train.py", line 382, in <module>
    y_encoder_generator=y_encoder_generator, pos_encoder_generator=pos_encoder_generator, **args.__dict__)
  File "train.py", line 216, in train
    train_epoch()
  File "train.py", line 135, in train_epoch
    for batch, (data, targets, single_eval_pos) in enumerate(dl):
  File "TabPFN/tabpfn/priors/utils.py", line 46, in <genexpr>
    return iter(self.gbm(**self.get_batch_kwargs, epoch=self.epoch_count - 1, model=self.model) for _ in range(self.num_steps))
  File "/TabPFN/tabpfn/priors/utils.py", line 33, in gbm
    batch = get_batch_method_(*args, **kwargs)
  File "anaconda3/envs/tabpfn/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
TypeError: get_batch() missing 1 required positional argument: 'get_batch'

When I tried to fix this by passing in the get_batch method, I got a recursion error saying the maximum recursion depth was exceeded. Can anyone help me here?
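
For readers who hit the same pair of errors, here is a minimal sketch of one plausible way they can arise. It assumes a prior whose get_batch wraps another prior and therefore requires an inner get_batch callable, which is what the TypeError above suggests. Every function name below (base_get_batch, wrapping_get_batch, self_referential, outcome) is hypothetical and chosen for illustration; none of this is TabPFN's actual code.

```python
import functools


def base_get_batch(batch_size, seq_len, num_features, **kwargs):
    """Hypothetical innermost prior: returns zero-filled placeholder data."""
    x = [[[0.0] * num_features for _ in range(batch_size)] for _ in range(seq_len)]
    y = [[0.0] * batch_size for _ in range(seq_len)]
    return x, y


def wrapping_get_batch(batch_size, seq_len, num_features, get_batch, **kwargs):
    """Hypothetical wrapping prior: delegates sampling to an inner get_batch."""
    return get_batch(batch_size, seq_len, num_features, **kwargs)


def self_referential(*args, **kwargs):
    # Mistaken fix: the wrapper is handed itself as the inner sampler,
    # so every call re-enters the wrapper and never reaches a base prior.
    return wrapping_get_batch(*args, get_batch=self_referential, **kwargs)


def outcome(fn):
    """Call fn with small sizes and report the exception type, or 'ok'."""
    try:
        fn(2, 3, 4)
        return "ok"
    except (TypeError, RecursionError) as exc:
        return type(exc).__name__


# Correct wiring: bind a non-recursive base sampler as the inner get_batch.
wired = functools.partial(wrapping_get_batch, get_batch=base_get_batch)

results = {
    "unbound wrapper": outcome(wrapping_get_batch),
    "self-referential": outcome(self_referential),
    "bound to base prior": outcome(wired),
}
print(results)
```

In this sketch, calling the wrapper without an inner sampler raises the missing-argument TypeError, handing the wrapper itself as the inner sampler recurses until Python's recursion limit is hit, and binding a separate base sampler with functools.partial runs cleanly.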

noahho commented 1 year ago

The easiest way to get started with data generation is with this notebook: https://github.com/automl/TabPFN/blob/main/tabpfn/PriorFittingCustomPrior.ipynb

It already includes some plotting code and shows the correct way to call main. Does this answer your question?

moonrabbit12 commented 1 year ago

Yes, thank you!