google-research / sam

Apache License 2.0

Error running examples #9

Open le000043 opened 3 years ago

le000043 commented 3 years ago

Hello, I encountered the error below when running this command:

python3 -m sam.sam_jax.train --dataset cifar10 --model_name WideResnet28x10 --output_dir /tmp/my_experiment --image_level_augmentations autoaugment --num_epochs 1 --sam_rho 0.05

Any help would be greatly appreciated.


Traceback (most recent call last):
  File "/home/dat/sam/sam/sam_jax/train.py", line 160, in main
    model, state = load_imagenet_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/imagenet_models/load_model.py", line 129, in get_model
    raise ModelNameError('Unrecognized model name.')
sam.sam_jax.imagenet_models.load_model.ModelNameError: Unrecognized model name.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/dat/sam/sam/sam_jax/train.py", line 177, in <module>
    app.run(main)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/dat/sam/sam/sam_jax/train.py", line 164, in main
    model, state = load_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 118, in get_model
    model, init_state = create_image_model(prng_key, batch_size, image_size,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 57, in create_image_model
    _, initial_params = module.init_by_shape(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 493, in init_by_shape
    stochastic_rng = stochastic.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 86, in make_rng
    return rng_frame.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 47, in make_rng
    return random.fold_in(self.base_rng, self.counter)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/random.py", line 289, in fold_in
    return _fold_in(key, jnp.uint32(data))
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 143, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/api.py", line 426, in cache_miss
    out_flat = xla.xla_call(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1565, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1556, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 1568, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/core.py", line 609, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 874, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/dat/sam/sam/sam_jax/train.py", line 177, in <module>
    app.run(main)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/dat/.local/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/dat/sam/sam/sam_jax/train.py", line 164, in main
    model, state = load_model.get_model(FLAGS.model_name,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 118, in get_model
    model, init_state = create_image_model(prng_key, batch_size, image_size,
  File "/home/dat/sam/sam/sam_jax/models/load_model.py", line 57, in create_image_model
    _, initial_params = module.init_by_shape(
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 220, in wrapper
    return super_fn(*args, **kwargs)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/base.py", line 493, in init_by_shape
    stochastic_rng = stochastic.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 86, in make_rng
    return rng_frame.make_rng()
  File "/home/dat/sam/venv/lib/python3.8/site-packages/flax/nn/stochastic.py", line 47, in make_rng
    return random.fold_in(self.base_rng, self.counter)
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/_src/random.py", line 289, in fold_in
    return _fold_in(key, jnp.uint32(data))
  File "/home/dat/sam/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 874, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found.

ssbin4 commented 1 year ago
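The traceback suggests that the first error ("Unrecognized model name") is not the real failure. train.py first calls the ImageNet loader (line 160), and when that raises ModelNameError it falls back to the CIFAR-style loader in models/load_model.py (line 164). Any error raised during the fallback is chained to the handled exception, which is why Python prints "During handling of the above exception, another exception occurred". The sketch below reproduces only that control flow. The registries, the "ResNet50" entry, and the returned strings are hypothetical placeholders, not sam's actual code, which builds Flax models.

```python
class ModelNameError(Exception):
    """Raised when a loader does not recognize the requested model name."""


# Hypothetical registries standing in for sam_jax/imagenet_models/load_model.py
# and sam_jax/models/load_model.py (the real modules build Flax models).
IMAGENET_MODELS = {"ResNet50": "imagenet-resnet50"}
CIFAR_MODELS = {"WideResnet28x10": "cifar-wrn28x10"}


def get_imagenet_model(name: str) -> str:
    if name not in IMAGENET_MODELS:
        raise ModelNameError("Unrecognized model name.")
    return IMAGENET_MODELS[name]


def get_cifar_model(name: str) -> str:
    if name not in CIFAR_MODELS:
        raise ModelNameError("Unrecognized model name.")
    return CIFAR_MODELS[name]


def get_model(name: str) -> str:
    # Try the ImageNet loader first; on a name miss, fall back to the
    # CIFAR-style loader. An exception raised inside the except block gets
    # the ModelNameError attached as its __context__, producing the
    # "During handling of the above exception" message.
    try:
        return get_imagenet_model(name)
    except ModelNameError:
        return get_cifar_model(name)


if __name__ == "__main__":
    print(get_model("WideResnet28x10"))
```

For WideResnet28x10 the ImageNet lookup misses and the CIFAR loader succeeds, so the ModelNameError is expected. In the report above, the exception that actually stops the run is the later CUDA RuntimeError raised while the CIFAR model is being initialized.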

Hi, I am getting a similar error message. Did you manage to resolve the problem?
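The failing frame is jax.random.fold_in, called by Flax while drawing the initialization RNG, so it is probably the first computation sent to the GPU. One way to narrow this down is to run that same call outside sam, in the same virtualenv. The check below assumes only that jax may be installed. If it also fails with "CUDA operation failed", the problem likely lies in the JAX/CUDA installation (for example, a jaxlib build that does not match the installed CUDA libraries) rather than in sam. This is a diagnostic sketch, not a fix.

```python
import importlib.util


def check_jax_backend() -> str:
    """Run jax.random.fold_in, the op that fails in the traceback, in isolation.

    Returns a status string instead of raising, so it is safe to run
    in the same environment as sam.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax

    try:
        # Platform of the default device, e.g. "gpu" or "cpu".
        platform = jax.devices()[0].platform
        key = jax.random.PRNGKey(0)
        # Same call as flax/nn/stochastic.py line 47 in the traceback.
        folded = jax.random.fold_in(key, 0)
        folded.block_until_ready()  # force the computation to actually execute
    except RuntimeError as err:
        return f"JAX failed on this backend: {err}"
    return f"ok: fold_in ran on platform {platform!r}"


if __name__ == "__main__":
    print(check_jax_backend())
```

Catching RuntimeError lets the script report the same error shown in the issue instead of crashing. If the check prints "ok" with platform 'cpu' on a GPU machine, JAX is not seeing the GPU at all, which is a separate installation problem.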