google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Quickstart code fails with an RNG-related error #3376

Closed: dinhanhx closed this issue 1 year ago

dinhanhx commented 1 year ago


System information

Name: flax
Version: 0.7.2
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: 
Author: 
Author-email: Flax team <flax-dev@google.com>
License: 
Location: /opt/conda/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
---
Name: jax
Version: 0.4.16
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint, tensorflow
---
Name: jaxlib
Version: 0.4.16+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint

Problem you have encountered:

When I simply run the Quickstart code, it fails at Step 9.

What you expected to happen:

The code should run without errors.

Logs, error messages, etc:

---------------------------------------------------------------------------
InvalidRngError                           Traceback (most recent call last)
Cell In[34], line 1
----> 1 state = create_train_state(cnn, init_rng, learning_rate, momentum)
      2 del init_rng  # Must not be used anymore.

Cell In[33], line 6, in create_train_state(module, rng, learning_rate, momentum)
      4 def create_train_state(module, rng, learning_rate, momentum):
      5     """Creates an initial `TrainState`."""
----> 6     params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
      7     tx = optax.sgd(learning_rate, momentum)
      8     return TrainState.create(
      9       apply_fn=module.apply, params=params, tx=tx,
     10       metrics=Metrics.empty())

    [... skipping hidden 3 frame]

File /opt/conda/lib/python3.10/site-packages/flax/linen/module.py:1729, in Module.init_with_output(self, rngs, method, mutable, capture_intermediates, *args, **kwargs)
   1727 if not isinstance(rngs, dict):
   1728   if not core.scope._is_valid_rng(rngs):
-> 1729     raise errors.InvalidRngError(
   1730         'RNGs should be of shape (2,) or KeyArray in Module '
   1731         f'{self.__class__.__name__}, but rngs are: {rngs}'
   1732     )
   1733   rngs = {'params': rngs}
   1735 if isinstance(method, str):

InvalidRngError: RNGs should be of shape (2,) or KeyArray in Module CNN, but rngs are: Array((), dtype=key<fry>) overlaying:
[0 0] (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)

Steps to reproduce:

https://www.kaggle.com/inhanhv/jax-flax-hello-world
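For context, this InvalidRngError typically shows up when a new-style typed key from jax.random.key() is passed to a Flax version whose RNG validation only accepts the legacy uint32 keys. A minimal sketch of the difference, assuming jax 0.4.16:

import jax

legacy_key = jax.random.PRNGKey(0)  # shape (2,), dtype uint32; accepted by older Flax
typed_key = jax.random.key(0)       # shape (), dtype key<fry>; needs a newer Flax

print(legacy_key.shape, legacy_key.dtype)  # (2,) uint32
print(typed_key.shape, typed_key.dtype)    # () key<fry>

# Possible workarounds on an older Flax: pass the legacy key to module.init,
# or unwrap a typed key with jax.random.key_data(typed_key).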

8bitmp3 commented 1 year ago

Hi @dinhanhx, can you check whether you are running the latest version of Flax? Optax should be installed along with Flax via pip.

I reproduced the code in Kaggle with a P100 GPU with no errors. You may need to click Run -> Restart & clear cell outputs after you upgrade with !pip install -U flax jax jaxlib

jax 0.4.16, jaxlib 0.4.16, flax 0.7.4, optax 0.1.7

!pip show flax jax jaxlib optax

Name: flax
Version: 0.7.4
...
Required-by: clu
---
Name: jax
Version: 0.4.16
...
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint, tensorflow
---
Name: jaxlib
Version: 0.4.16
...
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
---
Name: optax
Version: 0.1.7
...
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by: flax
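A quick way to double-check the installed versions from inside the notebook (a small sketch; flax, optax, and jax all expose __version__, and jax can print its full environment):

import flax
import jax
import optax

print("flax", flax.__version__)
print("optax", optax.__version__)
jax.print_environment_info()  # prints jax/jaxlib versions plus the detected devices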
dinhanhx commented 1 year ago

@8bitmp3 I have just upgraded everything and tried Run -> Restart & clear cell outputs. However, new strange things happen:

https://www.kaggle.com/code/inhanhv/jax-flax-hello-world

The new error happens at the same RNG line.

Why do XLA and TPU messages appear in a GPU-only environment?

2023-09-30 11:37:56.723066: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:188] failed to create cublas handle: the resource allocation failed
2023-09-30 11:37:56.723102: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:191] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[20], line 1
----> 1 state = create_train_state(cnn, init_rng, learning_rate, momentum)
      2 del init_rng  # Must not be used anymore.

Cell In[13], line 6, in create_train_state(module, rng, learning_rate, momentum)
      4 def create_train_state(module, rng, learning_rate, momentum):
      5     """Creates an initial `TrainState`."""
----> 6     params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
      7     tx = optax.sgd(learning_rate, momentum)
      8     return TrainState.create(
      9       apply_fn=module.apply, params=params, tx=tx,
     10       metrics=Metrics.empty())

    [... skipping hidden 9 frame]

Cell In[11], line 7, in CNN.__call__(self, x)
      5 @nn.compact
      6 def __call__(self, x):
----> 7     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
      8     x = nn.relu(x)
      9     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/flax/linen/linear.py:524, in _Conv.__call__(self, inputs)
    522   else:
    523     conv_general_dilated = self.conv_general_dilated
--> 524   y = conv_general_dilated(
    525       inputs,
    526       kernel,
    527       strides,
    528       padding_lax,
    529       lhs_dilation=input_dilation,
    530       rhs_dilation=kernel_dilation,
    531       dimension_numbers=dimension_numbers,
    532       feature_group_count=self.feature_group_count,
    533       precision=self.precision,
    534   )
    535 else:
    536   y = lax.conv_general_dilated_local(
    537       lhs=inputs,
    538       rhs=kernel,
   (...)
    545       precision=self.precision,
    546   )

    [... skipping hidden 13 frame]

File /opt/conda/lib/python3.10/site-packages/jax/_src/compiler.py:251, in backend_compile(backend, module, options, host_callbacks)
    246   return backend.compile(built_c, compile_options=options,
    247                          host_callbacks=host_callbacks)
    248 # Some backends don't have `host_callbacks` option yet
    249 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    250 # to take in `host_callbacks`
--> 251 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv.1 = (f32[1,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %bitcast.3, f32[32,1,3,3]{3,2,1,0} %transpose), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_263/2052264800.py" source_line=7}, backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}

Original error: INTERNAL: All algorithms tried for (f32[1,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0}, f32[32,1,3,3]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0} failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng28{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=1,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng2{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng2{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=1,k4=3,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=0,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng42{k2=1,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng4{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng28{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=1,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng2{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng2{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=1,k4=3,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng34{k2=0,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng42{k2=1,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng4{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng1{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng42{k2=2,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng42{k2=2,k4=2,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng42{k2=0,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'
  Profiling failure on cuDNN engine eng3{k11=2}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6784): 'status'

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
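For what it's worth, the repeated CUDNN_STATUS_ALLOC_FAILED entries above usually mean GPU memory was already claimed before cuBLAS/cuDNN could initialize (this environment also has TensorFlow installed, which preallocates aggressively). A common mitigation, assuming no other framework needs the memory, is to change XLA's allocator behavior before importing JAX:

import os

# These must be set before jax (or anything that imports jax) is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# Alternatively, cap the fraction of GPU memory XLA grabs up front:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax
print(jax.devices())  # should list the GPU once the memory pressure is relieved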
dinhanhx commented 1 year ago

This is what happened when I used XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false:

2023-09-30 11:48:53.530029: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:188] failed to create cublas handle: the resource allocation failed
2023-09-30 11:48:53.530067: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:191] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2023-09-30 11:48:53.988144: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6342): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'; current tracing scope: cudnn-conv.1; current profiling annotation: XlaModule:#hlo_module=jit_conv_general_dilated,program_id=11#.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[18], line 1
----> 1 state = create_train_state(cnn, init_rng, learning_rate, momentum)
      2 del init_rng  # Must not be used anymore.

Cell In[11], line 6, in create_train_state(module, rng, learning_rate, momentum)
      4 def create_train_state(module, rng, learning_rate, momentum):
      5     """Creates an initial `TrainState`."""
----> 6     params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
      7     tx = optax.sgd(learning_rate, momentum)
      8     return TrainState.create(
      9       apply_fn=module.apply, params=params, tx=tx,
     10       metrics=Metrics.empty())

    [... skipping hidden 9 frame]

Cell In[9], line 7, in CNN.__call__(self, x)
      5 @nn.compact
      6 def __call__(self, x):
----> 7     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
      8     x = nn.relu(x)
      9     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/flax/linen/linear.py:524, in _Conv.__call__(self, inputs)
    522   else:
    523     conv_general_dilated = self.conv_general_dilated
--> 524   y = conv_general_dilated(
    525       inputs,
    526       kernel,
    527       strides,
    528       padding_lax,
    529       lhs_dilation=input_dilation,
    530       rhs_dilation=kernel_dilation,
    531       dimension_numbers=dimension_numbers,
    532       feature_group_count=self.feature_group_count,
    533       precision=self.precision,
    534   )
    535 else:
    536   y = lax.conv_general_dilated_local(
    537       lhs=inputs,
    538       rhs=kernel,
   (...)
    545       precision=self.precision,
    546   )

    [... skipping hidden 8 frame]

File /opt/conda/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1149, in ExecuteReplicated.__call__(self, *args)
   1147   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1148 else:
-> 1149   results = self.xla_executable.execute_sharded(input_bufs)
   1150 if dispatch.needs_check_special():
   1151   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6342): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'; current tracing scope: cudnn-conv.1; current profiling annotation: XlaModule:#hlo_module=jit_conv_general_dilated,program_id=11#.
8bitmp3 commented 1 year ago

@dinhanhx I was able to start training the model in Kaggle with a P100 GPU here. Can you try opening a new Kaggle notebook with a P100? Then start with !pip install -U flax jax jaxlib, followed by the code cells copied from https://flax.readthedocs.io/en/latest/getting_started.html. Try running them line by line and let us know if you are still experiencing any issues.

dinhanhx commented 1 year ago

@8bitmp3 the quickstart code is executing now, but it's slow and doesn't use the GPU (see the attached screenshot).

Do I have to move something to the GPU explicitly, like in PyTorch?

Source code: https://www.kaggle.com/inhanhv/flax-quickstart
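For reference, JAX does not require explicit .to(device) moves the way PyTorch does: arrays and jitted computations land on the default backend automatically, so the usual diagnosis is to check which backend jaxlib actually detected. A minimal sketch:

import jax
import jax.numpy as jnp

print(jax.default_backend())  # "gpu" with a CUDA-enabled jaxlib, otherwise "cpu"
print(jax.devices())

x = jnp.ones((1024, 1024))
print(x.devices())  # committed to the default device with no explicit transfer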

dinhanhx commented 1 year ago

Okay, my mistake: I need to install JAX with CUDA support and then restart the runtime to use the GPU:

pip install -U flax jax jaxlib
pip install -q --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
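After restarting the kernel, a quick sanity check that the CUDA wheel took effect (a sketch):

import jax

assert jax.default_backend() == "gpu", "still on CPU; restart the kernel after installing"
print(jax.devices())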
dinhanhx commented 1 year ago

Thanks for your help,

Name: flax
Version: 0.7.4
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: 
Author: 
Author-email: Flax team <flax-dev@google.com>
License: 
Location: /opt/conda/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
---
Name: jax
Version: 0.4.18
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint, tensorflow
---
Name: jaxlib
Version: 0.4.18+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
---
Name: optax
Version: 0.1.7
Summary: A gradient processing and optimisation library in JAX.
Home-page: 
Author: 
Author-email: DeepMind <optax-dev@google.com>
License: 
Location: /opt/conda/lib/python3.10/site-packages
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by: flax
8bitmp3 commented 1 year ago

👍 You are welcome @dinhanhx 👍 If you have any more questions, feel free to create a new GitHub Issue or Discussion, or reopen this Issue.