deep-symbolic-mathematics / Multimodal-Symbolic-Regression

[ICLR 2024 Spotlight] SNIP on Symbolic Regression: Deep Symbolic Regression with Multimodal Pretraining

Transformer function generate_from_latent() uses wrong datatype on line 1306 #1

Open fabienmorgan opened 3 months ago

fabienmorgan commented 3 months ago

While trying to run latent space optimization (LSO_eval.py), I got the following error:

Traceback (most recent call last):
...
File ".../Multimodal-Symbolic-Regression/symbolicregression/model/transformer.py", line 1306, in generate_from_latent
    generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
RuntimeError: masked_fill only supports boolean masks, but got dtype Byte 

The error can be resolved by changing the datatype on line 1306 from byte to bool.

Current version: generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)

My version: generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index)
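
For context, recent PyTorch releases reject non-boolean masks in masked_fill_ outright (as far as I can tell, older releases only emitted a deprecation warning), which is why the cast is needed. A minimal standalone reproduction, independent of the repository code (the fill value 2 is just a placeholder for eos_index):

import torch

# Placeholder tensors standing in for generated[-1] and unfinished_sents.
generated_step = torch.zeros(5, dtype=torch.long)
unfinished_sents = torch.tensor([1, 0, 1, 0, 1], dtype=torch.uint8)

# Recent PyTorch versions reject the byte mask with exactly this error:
# RuntimeError: masked_fill only supports boolean masks, but got dtype Byte
# generated_step.masked_fill_(unfinished_sents, 2)

# Casting the mask to bool satisfies the API; positions where the mask is True get filled.
generated_step.masked_fill_(unfinished_sents.bool(), 2)
print(generated_step)  # tensor([2, 0, 2, 0, 2])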

The model was downloaded from the readme; the PMLB benchmark was cloned from GitHub with LFS, and the Feynman benchmark was downloaded from the linked site.

I hope you can look into the bug and confirm whether my correction is right and doesn't just treat the symptoms.

Configuration:

python LSO_eval.py --reload_model ./weights/snip-e2e-sr.pth \
                    --eval_lso_on_pmlb True \
                    --pmlb_data_type strogatz \
                    --target_noise 0.0 \
                    --max_input_points 200 \
                    --lso_optimizer gwo \
                    --lso_pop_size 50 \
                    --lso_max_iteration 80 \
                    --lso_stop_r2 0.99 \
                    --beam_size 2
RohanPhadnis commented 2 weeks ago

That's right. I got the same issue and made exactly the same modification to fix it. Thanks @fabienmorgan

I'm not a contributor to this repository, so I'm not entirely sure whether this solution just papers over the actual problem, but it worked for me too.
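
For reference, the call also works if the mask tensor is boolean from creation, which suggests the .bool() cast is the proper migration rather than a workaround. A minimal sketch below (variable names mirror the traceback; I haven't checked how unfinished_sents is actually built in generate_from_latent, so this is illustrative only):

import torch

batch_size, eos_index = 4, 2  # placeholder values

# Keep the "still generating" mask boolean from the start instead of a ByteTensor.
unfinished_sents = torch.ones(batch_size, dtype=torch.bool)

generated_step = torch.zeros(batch_size, dtype=torch.long)
generated_step.masked_fill_(unfinished_sents, eos_index)  # no cast needed at the call site
print(generated_step)  # tensor([2, 2, 2, 2])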