google-research / xmcgan_image_generation


About jax.PRNGKey: error reported when running #8

Open euyy opened 2 years ago

euyy commented 2 years ago

Excuse me. When I tried to run this code, I ran into a problem with this line: https://github.com/google-research/xmcgan_image_generation/blob/22a7ef2914787904949fe1fc3f5e560f1e75db29/xmcgan/train_utils.py#L167. The error is flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to jax.PRNGKey. In fact, g_rng is an array of shape [2,]. Can anyone help me solve this problem?

By the way, I have configured CUDA, but JAX still reports that CUDA cannot be found: xla_bridge.py:232] Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host. I even used TensorFlow to test the GPU, and it works there. I don't know what the problem is.
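A quick way to see which backend JAX actually picked up is sketched below. The assumption (not confirmed in this thread) is that a message listing only "Interpreter Host" usually means the installed jaxlib has no CUDA platform registered, for example because a CPU-only jaxlib was installed with a plain `pip install jax`. TensorFlow working on the same machine does not rule this out, because TensorFlow ships its own GPU support.

```python
import jax


def describe_backend():
    """Return JAX's default backend name and the list of devices it can see."""
    return jax.default_backend(), jax.devices()


backend, devices = describe_backend()
# "gpu" appears here only if a CUDA-enabled jaxlib is installed and working;
# with a CPU-only jaxlib this reports "cpu" and CPU devices.
print("default backend:", backend)
print("devices:", devices)
```

If this prints only CPU devices, reinstalling a CUDA-enabled jaxlib that matches the local CUDA/cuDNN versions (per the JAX installation guide) is the usual next step.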

woctezuma commented 2 years ago

No idea about the second issue, but others have reported the same problem online: https://github.com/kuixu/alphafold/issues/8

kohjingyu commented 2 years ago

The first issue might be caused by a recent change to Flax. Which version are you using? Can you try changing the line to:

generator_variables = generator(train=False).init({'params': g_rng}, (inputs, z))

(ref: https://github.com/google/flax/blob/main/examples/imagenet/train.py#L74)

The second issue is likely due to a problem during setup. Perhaps try asking at https://github.com/google/jax.

hyeonjinXZ commented 2 years ago

The first issue was fixed for me by upgrading to the latest version of Flax: pip install --upgrade git+https://github.com/google/flax.git (ref: https://pythonrepo.com/repo/google-flax-python-deep-learning)
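Since both answers depend on which versions are installed, here is a small sketch for reporting them, using only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical; running it after the upgrade confirms which Flax version the environment actually picks up.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Map each package name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```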

euyy commented 2 years ago

@woctezuma @Hyeonjin1989 @kohjingyu Thanks for your help, but now I have a new problem: UNKNOWN: Failed to determine best cudnn convolution algorithm: UNKNOWN: GetConvolveAlgorithms failed. Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. I don't know whether this error is caused by my device, so I'd like to know whether there are minimum hardware requirements for training. If anyone knows, please tell me. Thanks.
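The workaround named in the error message can be applied from Python as sketched below, assuming the script sets it before JAX initializes its backends (so before importing jax). As the message itself says, this only falls back to a possibly slower convolution algorithm; it does not address the root cause of the failed autotuning.

```python
import os

FLAG = "--xla_gpu_strict_conv_algorithm_picker=false"

# Append the flag suggested by the error message to any existing XLA_FLAGS,
# without adding it twice if this code runs more than once.
existing = os.environ.get("XLA_FLAGS", "")
if FLAG not in existing.split():
    os.environ["XLA_FLAGS"] = f"{existing} {FLAG}".strip()

# import jax  # import only after XLA_FLAGS has been set
```

The equivalent from a shell is to prefix the training command with `XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false`.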

adambot806 commented 2 years ago

Try adding the following before JAX is imported:

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = '.7'

For an explanation, refer to the JAX documentation on GPU memory allocation.
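By default JAX preallocates a large fraction of GPU memory when it first initializes the GPU backend, which can leave too little for cuDNN or for another framework such as TensorFlow in the same process. The sketch below sets the documented environment variables before importing jax; the value `.7` comes from the suggestion above, and whether this resolves the cuDNN error in this particular setup is not confirmed in the thread.

```python
import os

# JAX reads these when it initializes the GPU backend, so set them first.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".7"  # preallocate 70% of GPU memory

# Alternative: disable preallocation and allocate memory on demand instead.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402  (deliberately imported after the variables are set)

print(jax.devices())
```

Disabling preallocation avoids reserving memory up front but can increase fragmentation, so lowering the fraction is often tried first.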