Open CZXIANGOvO opened 1 month ago
Is there a MVC? This code doesn't run for me
As for where it fails to run: final_device at the beginning needs to be set by yourself. You can delete the following:
    if "CONTEXT_DEVICE_TARGET" in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
        devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        device = devices[-2]
        final_device = "cuda:" + device
    else:
        final_device = 'cpu'
Hi @CZXIANGOvO – it's going to be hard to help with specifics here absent an MVC (also known as a minimal reproducible example). If you're able to re-work your example so that others can run it and see the same errors you are seeing, then we could offer specific guidance.
Absent that, though, in general it's not surprising to see NaN outputs for inputs without NaNs: it just means that you're calling some function in your model in a way whose result is undefined in floating-point arithmetic. Here's a simple example of this:
>>> import jax.numpy as jnp
>>> def f(x, y):
...     return x * jnp.exp(y)
...
>>> f(1.0, 1.0)
Array(2.7182817, dtype=float32, weak_type=True)
>>> f(0.0, 100.0)
Array(nan, dtype=float32, weak_type=True)
More than likely, somewhere in your model you have an expression that is evaluating to NaN for reasons like this.
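To make the mechanism in the example above concrete, here is a small sketch that inspects the intermediate value. It assumes JAX's default 32-bit floats (the `jax_enable_x64` option left off); the variable names `big` and `result` are only for illustration.

```python
import jax.numpy as jnp

def f(x, y):
    return x * jnp.exp(y)

# exp(100) is roughly 2.7e43, which exceeds the largest finite float32
# (roughly 3.4e38), so the intermediate value overflows to inf.
big = jnp.exp(jnp.float32(100.0))
print(jnp.isinf(big))

# 0 * inf has no defined value in IEEE 754 arithmetic, so the product is NaN
# even though neither input was NaN.
result = f(0.0, 100.0)
print(jnp.isnan(result))
```

Checking intermediates with `jnp.isinf` and `jnp.isnan` like this is one way to narrow down which sub-expression of a larger model first goes wrong.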
The best way to debug this is to start digging into your model to figure out exactly where this is coming from. One way to do this is to enable the jax_debug_nans flag, as described here: https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager
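As a rough sketch of what enabling that flag looks like, applied to the toy function `f` from the example above rather than to your model: with `jax_debug_nans` turned on, JAX should raise a `FloatingPointError` at the operation that produces the NaN instead of silently returning it. The `raised` variable is only there for illustration.

```python
import jax
import jax.numpy as jnp

# Turn on NaN checking globally; this adds overhead, so it is meant for debugging.
jax.config.update("jax_debug_nans", True)

def f(x, y):
    return x * jnp.exp(y)

raised = False
try:
    f(0.0, 100.0)  # 0 * inf produces NaN in the multiply
except FloatingPointError as err:
    raised = True
    print("NaN detected:", err)
```

The error message and traceback point at the operation that generated the NaN, which is usually the quickest way to locate the offending expression in a larger model.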
I hope that helps get you on the right path!
Description
Please specify cuda:0 at the very beginning.
System info (python version, jaxlib version, accelerator, etc.)
Code and data links: https://drive.google.com/file/d/1-edrk7_sxSgdu7cmXQXf6JsT57xiG1Hb/view?usp=sharing