google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Could not run Trax on the GPU of a local computer #1778

Open LiuZhenshun opened 1 year ago

LiuZhenshun commented 1 year ago

Description

Hi, I would like to install Trax locally. First, I found that the jax I had installed did not support the GPU, so I followed the jax GitHub instructions to install the CUDA version of jax. I then verified that jax could detect the GPU on my local computer, but I still could not run the sample code, such as the Transformer and Fast Math examples.

Environment information

OS: Pop!_OS (based on Ubuntu 22.04)

$ pip freeze | grep trax
# trax==1.4.1

$ pip freeze | grep tensor
# tensorboard==2.12.3
# tensorboard-data-server==0.7.1
# tensorflow==2.12.0
# tensorflow-datasets==4.9.2
# tensorflow-estimator==2.12.0
# tensorflow-hub==0.13.0
# tensorflow-io-gcs-filesystem==0.32.0
# tensorflow-metadata==1.13.1
# tensorflow-text==2.12.1

$ pip freeze | grep jax
# jax==0.4.12
# jaxlib==0.4.12+cuda11.cudnn86

$ python -V
# Python 3.11.3
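The `+cuda11.cudnn86` suffix on the `jaxlib` version above is the wheel's local version tag: it records that the wheel was built for CUDA 11 and cuDNN 8.6, which is what has to match the installed NVIDIA driver and cuDNN. A minimal sketch of pulling those numbers out follows. `parse_jaxlib_version` and `JaxlibBuild` are hypothetical names (not part of JAX), and the parser assumes the tag always has the two-digit `cudnnXY` form seen here.

```python
import re
from typing import NamedTuple, Optional


class JaxlibBuild(NamedTuple):
    """Components encoded in a jaxlib wheel version string."""
    version: str               # release, e.g. "0.4.12"
    cuda_major: Optional[int]  # e.g. 11, or None for a CPU-only wheel
    cudnn: Optional[str]       # e.g. "8.6", or None for a CPU-only wheel


# Assumed pattern: "<release>+cuda<major>.cudnn<major><minor>",
# as in the "jaxlib==0.4.12+cuda11.cudnn86" line from `pip freeze`.
_TAG = re.compile(
    r"^(?P<ver>[\d.]+)(?:\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d+))?$"
)


def parse_jaxlib_version(text: str) -> JaxlibBuild:
    """Split a jaxlib version string into release, CUDA major and cuDNN version."""
    match = _TAG.match(text.strip())
    if match is None:
        raise ValueError(f"unrecognised jaxlib version: {text!r}")
    cuda = match.group("cuda")
    cudnn = match.group("cudnn")
    # "86" -> "8.6": first digit is the cuDNN major version, the rest the minor.
    cudnn_str = f"{cudnn[0]}.{cudnn[1:]}" if cudnn else None
    return JaxlibBuild(match.group("ver"), int(cuda) if cuda else None, cudnn_str)


if __name__ == "__main__":
    from importlib import metadata

    # Version string reported in this issue.
    print(parse_jaxlib_version("0.4.12+cuda11.cudnn86"))
    # Whatever jaxlib is installed in the current environment, if any.
    try:
        print(parse_jaxlib_version(metadata.version("jaxlib")))
    except (metadata.PackageNotFoundError, ValueError) as err:
        print(f"could not inspect installed jaxlib: {err!r}")
```

For the version in this issue it prints `JaxlibBuild(version='0.4.12', cuda_major=11, cudnn='8.6')`.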

For bugs: reproduction and error logs

# Steps to reproduce:
1) Install trax

- pip install trax
- pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

2) Use jax to detect the GPU
- code:
    import jax
    print(jax.devices()) 
- output:
    [gpu(id=0)]
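`jax.devices()` returning `[gpu(id=0)]` shows that the CUDA backend was found, but it does not run any kernel. Both failures below occur when JAX first compiles or runs a computation (`jit_PRNGKey` in the Transformer example, `fastnp.ones(3)` in the Fast Math example). A slightly stronger check is sketched here as a hypothetical helper, `gpu_smoke_test` (not part of JAX or Trax), that actually executes a small computation on the first GPU JAX reports.

```python
def gpu_smoke_test() -> tuple[bool, str]:
    """Try to run a tiny computation on the first GPU that JAX reports.

    Returns (ok, message). Listing a device only shows that the backend
    loaded; running a kernel also exercises compilation, memory allocation
    and library initialisation.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        return False, f"jax is not importable: {err}"

    try:
        gpu = jax.devices("gpu")[0]
    except Exception as err:  # raised when no GPU backend is available
        return False, f"no GPU backend: {err}"

    try:
        x = jax.device_put(jnp.arange(3.0), gpu)
        y = jnp.tanh(x @ x)   # dot product + elementwise op, like the Fast Math sample
        y.block_until_ready()  # force execution so any failure surfaces here
    except Exception as err:  # e.g. jaxlib.xla_extension.XlaRuntimeError
        return False, f"GPU computation failed: {type(err).__name__}: {err}"
    return True, f"ok on {gpu}: {float(y)}"


if __name__ == "__main__":
    ok, message = gpu_smoke_test()
    print("PASS" if ok else "FAIL", message)
```

On the machine in this issue, one would expect this to report `FAIL` with the same `XlaRuntimeError` as the samples, which would separate a JAX/CUDA setup problem from anything Trax-specific.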
# Error logs:
1) Run the pre-trained Transformer sample code from your README
- code:
      import os
      import numpy as np

      import trax

      # Create a Transformer model.
      # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
      model = trax.models.Transformer(
          input_vocab_size=33300,
          d_model=512, d_ff=2048,
          n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
          max_len=64, mode='predict')

      # Initialize using pre-trained weights.
      model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                           weights_only=True)
                          #  input_signature=input_signature)

      # Tokenize a sentence.
      sentence = 'It is nice to learn new things today!'
      tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                          vocab_dir='gs://trax-ml/vocabs/',
                                          vocab_file='ende_32k.subword'))[0]

      # Decode from the Transformer.
      tokenized = tokenized[None, :]  # Add batch dimension.
      tokenized_translation = trax.supervised.decoding.autoregressive_sample(
          model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

      # De-tokenize.
      tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
      translation = trax.data.detokenize(tokenized_translation,
                                         vocab_dir='gs://trax-ml/vocabs/',
                                         vocab_file='ende_32k.subword')
      print(translation)
- Error Output:
      2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
      2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
      INTERNAL: Failed to get stream's capture status: out of memory
      2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/tryTrax.py", line 22, in <module>
          model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 349, in init_from_file
          self.init(input_signature)
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 310, in init
          raise LayerError(name, 'init', self._caller,
      trax.layers.base.LayerError: Exception passing through layer Serial (in init):
        layer created in file [...]/trax/models/transformer.py, line 371
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})

        File [...]/trax/layers/combinators.py, line 108, in init_weights_and_state
          outputs, _ = sublayer._forward_abstract(inputs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File [...]/trax/layers/base.py, line 641, in _forward_abstract

        layer created in file [...]/trax/models/transformer.py, line 372
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})

      jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.

2) Run the Fast Math sample code:
- code:
      import trax
      from trax.fastmath import numpy as fastnp
      trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

      matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      print(f'matrix =\n{matrix}')
      vector = fastnp.ones(3)
      print(f'vector = {vector}')
      product = fastnp.dot(vector, matrix)
      print(f'product = {product}')
      tanh = fastnp.tanh(product)
      print(f'tanh(product) = {tanh}')
- Error Output:
      matrix =
      [[1 2 3]
       [4 5 6]
       [7 8 9]]
      2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
      2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.
      2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 525.85.5
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/fastnumpy.py", line 7, in <module>
          vector = fastnp.ones(3)
                   ^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2161, in ones
          return lax.full(shape, 1, _jnp_dtype(dtype))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1205, in full
          return broadcast(fill_value, shape)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
          return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
          return broadcast_in_dim_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
          return self.bind_with_trace(find_top_trace(args), args, params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
          out = trace.process_primitive(self, map(trace.full_raise, args), params)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
          return primitive.impl(*tracers, **params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
          compiled_fun = xla_primitive_callable(
                         ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
          return cached(config._trace_context(), *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
          return f(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
          compiled = _xla_callable_uncached(
                     ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
          return computation.compile().unsafe_call
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
          executable = UnloadedMeshExecutable.from_hlo(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
          xla_executable, compile_options = _cached_compilation(
                                            ^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
          xla_executable = dispatch.compile_or_get_cached(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
          return backend_compile(backend, computation, compile_options,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
          return func(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
          return backend.compile(built_c, compile_options=options)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
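The `Memory usage` line logged just before the cuDNN failure is worth reading numerically: 36175872 bytes free out of 4093902848 bytes total is 34.5 MiB out of about 3904 MiB, so under 1% of the card's memory was free when cuDNN tried to create its handle. A small stdlib parser makes that arithmetic explicit; `parse_memory_usage` is an illustrative helper, not part of JAX or XLA.

```python
import re

# The cuDNN memory log line from the Fast Math run, copied from the output above.
LOG_LINE = (
    "2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/"
    "cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total."
)

_MEMORY = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")
MIB = 1024 ** 2  # bytes per mebibyte


def parse_memory_usage(line: str) -> tuple[float, float, float]:
    """Return (free MiB, total MiB, free fraction) from a cuda_dnn memory log line."""
    match = _MEMORY.search(line)
    if match is None:
        raise ValueError("no 'Memory usage' record in line")
    free, total = (int(g) for g in match.groups())
    return free / MIB, total / MIB, free / total


if __name__ == "__main__":
    free_mib, total_mib, fraction = parse_memory_usage(LOG_LINE)
    print(f"{free_mib:.2f} MiB free of {total_mib:.2f} MiB ({fraction:.2%})")
```

Running it prints `34.50 MiB free of 3904.25 MiB (0.88%)`. One possible reading, not confirmed in this issue, is that JAX's default GPU memory preallocation, together with whatever else is using this 4 GB card, leaves too little room for cuDNN. JAX documents the `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables for adjusting that behaviour; they need to be set before jax is imported.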