tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Tensor2Tensor can't run on Google Colab #1650

Open: zhez6 opened this issue 5 years ago

zhez6 commented 5 years ago

Description

I ran the following commands in Google Colab with a GPU runtime, and training hangs after it prints 'Saving checkpoints for 0 into /content/test/model.ckpt.' Any ideas?

!pip install -q tensor2tensor==1.13

!apt-get install -q sox

!t2t-trainer \
  --generate_data \
  --tmp_dir=/content/librispeech_clean_small/tmp/ \
  --problem=librispeech_clean_small \
  --model=transformer \
  --train_steps=50000 \
  --hparams_set=transformer_librispeech \
  --data_dir=/content/librispeech_clean_small/ \
  --output_dir=/content/test \
  --hparams='batch_size=100000,optimizer=Adafactor,learning_rate_schedule=rsqrt_decay' \
  --local_eval_frequency=10000 ...
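In the command as originally posted, the --hparams value nests double quotes inside double quotes. The shell does not pass those inner quotes through. It only toggles quoting on and off, so the pieces are glued back into one argument with a stray trailing space. The sketch below uses Python's shlex module, which approximates POSIX shell quoting, to show what t2t-trainer receives for the original form and for a single-quoted form. The split_overrides helper is hypothetical and only illustrates the comma-separated key=value layout. It is not tensor2tensor's own hparams parser.

```python
import shlex

# The --hparams argument as originally posted, with nested double quotes.
original = '--hparams="batch_size=100000,optimizer="Adafactor",learning_rate_schedule="rsqrt_decay" "'

# The same overrides wrapped in single quotes, with no inner quotes.
single_quoted = "--hparams='batch_size=100000,optimizer=Adafactor,learning_rate_schedule=rsqrt_decay'"


def hparams_value(arg):
    """Split `arg` with shell-like rules and return the text after '--hparams='."""
    (token,) = shlex.split(arg)  # both forms collapse into exactly one word
    return token[len("--hparams="):]


def split_overrides(value):
    """Hypothetical helper: naively split 'k1=v1,k2=v2' into a dict.

    Not tensor2tensor's parser; it ignores list values and quoting.
    """
    pairs = (item.split("=", 1) for item in value.split(",") if item)
    return {key.strip(): val for key, val in pairs}


if __name__ == "__main__":
    print(repr(hparams_value(original)))       # note the trailing space
    print(repr(hparams_value(single_quoted)))
    print(split_overrides(hparams_value(single_quoted)))
```

Both forms carry the same three overrides, so the quoting alone is unlikely to explain the hang, which later comments attribute to the GPU. The single-quoted form is simply easier to read and edit.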

Victor-Almeida commented 5 years ago

I'm having the same problem, and it seems to be a TensorFlow issue when using the GPU on Google Colab. If you switch to a normal (CPU-only) runtime, training proceeds without issues.
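Switching the Colab runtime type is one way to keep TensorFlow off the GPU. A related approach, which is not the one described in the comment above, is to hide the GPU from a single process by setting the CUDA_VISIBLE_DEVICES environment variable to an empty string. TensorFlow then sees no CUDA devices and runs on the CPU. Here is a minimal sketch, assuming t2t-trainer is on the PATH. The helper names cpu_only_env and run_on_cpu are hypothetical.

```python
import os
import subprocess


def cpu_only_env(base_env=None):
    """Return a copy of `base_env` (default: os.environ) with all CUDA GPUs hidden.

    An empty CUDA_VISIBLE_DEVICES makes the CUDA runtime report no devices,
    so a TensorFlow process started with this environment runs on the CPU.
    The input mapping is copied, never modified.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ""
    return env


def run_on_cpu(argv):
    """Launch a command (e.g. t2t-trainer) with the GPU hidden; raises if it fails."""
    return subprocess.run(argv, env=cpu_only_env(), check=True)
```

For example, call run_on_cpu(["t2t-trainer", "--problem=librispeech_clean_small", "--model=transformer"]) with the remaining flags from the issue appended, one list element per flag. Training then runs on the CPU, just as it would on a CPU runtime.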

DevZiegler commented 4 years ago

I can run it in the Python 3 GPU runtime with the following bugfix, which you need to apply to the add_delta_deltas function in the common_audio.py file:

def add_delta_deltas(filterbanks, name=None):
  # ... (unchanged original code that builds delta_filter_stack) ...

  # Bugfix: GPU not supported for this op here, so run this layer on the CPU.
  with tf.device('/cpu:0'):
    filterbanks = tf.nn.conv2d(
        filterbanks, delta_filter_stack, [1, 1, 1, 1], "SAME", data_format="NHWC",
        name=name)
  # Original version, kept for reference:
  # filterbanks = tf.nn.conv2d(
  #     filterbanks, delta_filter_stack, [1, 1, 1, 1], "SAME", data_format="NHWC",
  #     name=name)
  return filterbanks
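The fix above edits the installed copy of common_audio.py, so you first need to know where pip put it. Here is a minimal sketch using importlib.util.find_spec. It assumes the file is importable as tensor2tensor.layers.common_audio, but that module path is an assumption, so confirm it against your installed version. The helper name locate_module_file is hypothetical.

```python
import importlib.util


def locate_module_file(dotted_name):
    """Return the source path of an installed module, or None if it is missing.

    find_spec imports the parent packages first, so a missing top-level
    package raises ModuleNotFoundError; that is treated as "not installed".
    """
    try:
        spec = importlib.util.find_spec(dotted_name)
    except ModuleNotFoundError:
        return None
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Assumed module path for the file patched above; adjust if yours differs.
    print(locate_module_file("tensor2tensor.layers.common_audio"))
```

Because each !t2t-trainer call starts a fresh Python process, an edit to that file should be picked up the next time you launch training.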

Edit: I also installed the following:

!pip3 install -q -U tensor2tensor
!pip3 install -q matplotlib
!pip3 install -q gast==0.2.2
!apt-get install sox
!apt-get install libsox-fmt-mp3
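The installs above upgrade tensor2tensor and then pin gast to 0.2.2, so it can be worth confirming which versions actually ended up installed. Here is a small sketch using importlib.metadata (Python 3.8+). The check_pins helper is hypothetical, and the pin dictionary holds whatever versions you want to verify.

```python
from importlib import metadata


def check_pins(pins):
    """Compare installed distribution versions against exact pins.

    `pins` maps a distribution name to the expected version, e.g.
    {"gast": "0.2.2"}. Returns {name: installed_version_or_None} for every
    package whose installed version differs (None means not installed).
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    # gast==0.2.2 is the pin from the install list above.
    print(check_pins({"gast": "0.2.2"}) or "all pins satisfied")
```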