worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Error on mps device #78

Closed qzhu2017 closed 2 months ago

qzhu2017 commented 3 months ago

Hi,

I tried to run it on my Mac using the mps device (Apple's Metal Performance Shaders for GPU acceleration). However, I got the following error message:

dtype: float64
Sensitivity threshold: 0.01891369047619046 qt_max: 0.05
Map: 100%|███████████████████████████████████████████████████████| 1023/1023 [00:00<00:00, 1494.44 examples/s]
/Users/qzhu8/miniconda3/lib/python3.11/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
{'loss': 0.6446, 'grad_norm': 1.218327283859253, 'learning_rate': 4.8437500000000005e-05, 'epoch': 3.12}      
{'train_runtime': 207.8982, 'train_samples_per_second': 24.603, 'train_steps_per_second': 0.77, 'train_loss': 0.43941685259342195, 'epoch': 5.0}
100%|███████████████████████████████████████████████████████████████████████| 160/160 [03:27<00:00,  1.30s/it]
  0%|                                                                                 | 0/504 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/qzhu8/GitHub/REaLTabFormer/0.py", line 19, in <module>
    rtf_model.fit(df)
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/realtabformer/realtabformer.py", line 458, in fit
    trainer = self._train_with_sensitivity(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/realtabformer/realtabformer.py", line 714, in _train_with_sensitivity
    gen_df = self.sample(n_samples=gen_rounds * gen_total, device=device)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/realtabformer/realtabformer.py", line 1244, in sample
    synth_df = tabular_sampler.sample_tabular(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/realtabformer/rtf_sampler.py", line 633, in sample_tabular
    sample_outputs = self._generate(
                     ^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/realtabformer/rtf_sampler.py", line 247, in _generate
    _samples = self.model.generate(**generate_kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/transformers/generation/utils.py", line 1591, in generate
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/qzhu8/miniconda3/lib/python3.11/site-packages/transformers/generation/utils.py", line 468, in _prepare_attention_mask_for_generation
    raise ValueError(
ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

It looks like it is complaining about a missing attention mask, but I have no idea how to fix this since I am new to this area. Can someone help?

avsolatorio commented 2 months ago

Hello @qzhu2017, unfortunately this is a currently open issue with PyTorch on `mps` devices. See: https://github.com/pytorch/pytorch/issues/77764 and https://github.com/coqui-ai/TTS/issues/3758.
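
Until the upstream PyTorch issue is resolved, one possible workaround is to fall back to the CPU when this specific error appears. The traceback above shows that `REaLTabFormer.sample()` accepts a `device` argument, so a small wrapper can catch the MPS-specific `ValueError` and retry on `"cpu"`. This is only a sketch: `sample_with_cpu_fallback` is a hypothetical helper, not part of the library, and note that the error in this report fires *inside* `fit()` (during the sensitivity check), where the sampling device may not be directly overridable, in which case training on CPU may be the only option.

```python
def sample_with_cpu_fallback(model, n_samples):
    """Hypothetical helper: try sampling on the model's default device;
    if the known MPS attention-mask ValueError surfaces, retry on CPU.

    `model` is assumed to be a fitted REaLTabFormer instance whose
    `sample()` accepts `n_samples` and `device` keyword arguments,
    as seen in the traceback above.
    """
    try:
        return model.sample(n_samples=n_samples)
    except ValueError as err:
        # Match the specific upstream error, re-raise anything else.
        if "attention mask" in str(err) and "mps" in str(err):
            return model.sample(n_samples=n_samples, device="cpu")
        raise
```

Sampling on CPU is slower, but generation for tabular models of this size is usually tolerable; the alternative is to wait for the PyTorch MPS fix or run on a CUDA machine.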