vzyrianov / LidarDM

https://www.zyrianov.org/lidardm

Failed to get a Waymax scenario in the Waymax Jupyter notebook demo #4

Closed ryan-utopia closed 4 months ago

ryan-utopia commented 4 months ago

Hello! I have set up the environment and downloaded the Waymo weights and the motion file tfrecord-00000-of-01000. [image] This cell runs successfully; however, when I run the "get scenario" cell, a serious error occurs: W tensorflow/core/framework/op_kernel.cc:1830] OP_REQUIRES failed at example_parsing_ops.cc:94 : INVALID_ARGUMENT: Key: roadgraph_samples/xyz. Can't parse serialized Example.

I have redownloaded the tfrecord and checked the environment several times, but that did not help.
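One thing that can be ruled out cheaply after a re-download is a truncated file. The sketch below walks the standard TFRecord framing (an 8-byte length, a 4-byte length CRC, the payload, and a 4-byte payload CRC) using only the standard library. The function name `count_tfrecord_records` is hypothetical, the CRCs are skipped rather than verified, and the check says nothing about whether the records match the schema the loader expects, so it cannot detect a dataset-version mismatch.

```python
import struct
from pathlib import Path


def count_tfrecord_records(path):
    """Count records by walking the TFRecord framing.

    Raises ValueError if the file ends in the middle of a record.
    CRC fields are skipped, not verified, so only truncation is detected.
    """
    size = Path(path).stat().st_size
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # 8-byte little-endian length + 4-byte length CRC
            if not header:
                return count  # clean end of file
            if len(header) < 12:
                raise ValueError(f"truncated header after {count} records")
            (length,) = struct.unpack("<Q", header[:8])
            end = f.tell() + length + 4  # payload + 4-byte payload CRC
            if end > size:
                raise ValueError(f"truncated payload in record {count}")
            f.seek(end)
            count += 1
```

A call such as `count_tfrecord_records("_datasets/womd/training_tfexample.tfrecord-00000-of-01000")` should return without raising if the download is complete; the path here is only an illustration based on the notebook instructions.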

ryan-utopia commented 4 months ago

Full traceback below:

`---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[9], line 5
      3 seed = 78 # change this to get different scenario
      4 for _ in range(seed):
----> 5   scenario = next(data_iter)
      7 # visualize the maps
      8 maps = []

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/waymax/dataloader/dataloader_utils.py:229, in get_data_generator(config, preprocess_fn, postprocess_fn)
    227 else:
    228   postprocess_fn = jax.jit(postprocess_fn)
--> 229 for example in dataset.as_numpy_iterator():
    230   yield postprocess_fn(example)

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py:4770, in _NumpyIterator.__next__(self)
   4767   numpy.setflags(write=False)
   4768   return numpy
-> 4770 return nest.map_structure(to_numpy, next(self._iterator))

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:787, in OwnedIterator.__next__(self)
    785 def __next__(self):
    786   try:
--> 787     return self._next_internal()
    788   except errors.OutOfRangeError:
    789     raise StopIteration

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py:770, in OwnedIterator._next_internal(self)
    767 # TODO(b/77291417): This runs in sync mode as iterators use an error status
    768 # to communicate that there is no more data to iterate over.
    769 with context.execution_mode(context.SYNC):
--> 770   ret = gen_dataset_ops.iterator_get_next(
    771       self._iterator_resource,
    772       output_types=self._flat_output_types,
    773       output_shapes=self._flat_output_shapes)
    775   try:
    776     # Fast path for the case self._structure is not a nested structure.
    777     return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py:3017, in iterator_get_next(iterator, output_types, output_shapes, name)
   3015   return _result
   3016 except _core._NotOkStatusException as e:
-> 3017   _ops.raise_from_not_ok_status(e, name)
   3018 except _core._FallbackException:
   3019   pass

File ~/.conda/envs/lidardm/lib/python3.10/site-packages/tensorflow/python/framework/ops.py:7215, in raise_from_not_ok_status(e, name)
   7213 def raise_from_not_ok_status(e, name):
   7214   e.message += (" name: " + name if name is not None else "")
-> 7215   raise core._status_to_exception(e) from None

InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_31_device_/job:localhost/replica:0/task:0/device:CPU:0}} Key: roadgraph_samples/xyz. Can't parse serialized Example. [[{{node ParseExample/ParseExampleV2}}]] [Op:IteratorGetNext]`

hungdche commented 4 months ago

Thank you for your interest in our work. Can you send me the link to the tfrecord you downloaded so I can test on my end?

ryan-utopia commented 4 months ago

Thanks, you can get it from this gs link: gs://waymo_open_dataset_motion_v_1_2_1/uncompressed/tf_example/training/training_tfexample.tfrecord-00000-of-01000

I followed the guide in "notebook_waymax.ipynb" and downloaded it:

# download WOMD tfrecord
# 1. register with Waymo Open Motion Dataset
# 2. in the Google Cloud, download waymo_open_dataset_motion_v_1_2_1/uncompressed/
#                 tf_example/training/training_tfexample.tfrecord-00000-of-01000
# 3. move the downloaded file into _datasets/womd

hungdche commented 4 months ago

Could you try WOMD v1.1.0 instead? I might have mistyped the version in the instructions. My apologies for that.

ryan-utopia commented 4 months ago

Thanks a lot. WOMD v1.1.0 works well now. However, I am running this demo on an RTX 4090, which has only 24GB of memory, and while generating the scene at self.mesh = self.generate_mesh(), it raises OutOfMemoryError: CUDA out of memory. [image] Is there a way to run this demo on one or more RTX 4090s, or does it require GPUs with more memory?

hungdche commented 4 months ago

May I know which line throws the CUDA OOM and which seed you use? (The seed is the number set in the second or third notebook cell; it defaults to 68.)

ryan-utopia commented 4 months ago

[image] It seems to be this line (LidarDM/lidardm/waymax/waymax_compositor.py:156): current_mesh = sample_from_map(torch.from_numpy(decoded_map).cuda().float(), self.model.cuda()). I have tried seeds 68, 11, 16, and others.

hungdche commented 4 months ago

That is a bit weird, as everything works fine on my RTX 4080 with 8GB of GPU memory. Do you have any background processes that are eating up GPU memory?

ryan-utopia commented 4 months ago

Exactly. After running the following code, 22GB of memory was occupied and was not freed:

# get scenario
data_iter = dataloader.simulator_state_generator(config=config)
seed = 68 # change this to get different scenario
for _ in range(seed):
    scenario = next(data_iter)
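To put numbers on an observation like this, GPU memory can be sampled before and after the cell with `nvidia-smi`. The following is a minimal sketch that assumes `nvidia-smi` is on the PATH; the helper names `parse_gpu_memory` and `gpu_memory` are hypothetical, and the values are in MiB as reported by the tool.

```python
import subprocess

# Query index, used memory, and total memory as bare CSV rows (MiB).
QUERY = [
    "nvidia-smi",
    "--query-gpu=index,memory.used,memory.total",
    "--format=csv,noheader,nounits",
]


def parse_gpu_memory(csv_text):
    """Parse 'index, used, total' rows into a list of dicts."""
    rows = []
    for line in csv_text.strip().splitlines():
        index, used, total = (int(x) for x in line.split(","))
        rows.append({"gpu": index, "used_mib": used, "total_mib": total})
    return rows


def gpu_memory():
    """Run nvidia-smi and return per-GPU memory usage."""
    out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
    return parse_gpu_memory(out)
```

Calling `gpu_memory()` once before and once after the "get scenario" cell shows how much memory that cell alone claims, independent of which framework claimed it.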

hungdche commented 4 months ago

I think you might be facing an issue similar to this one: https://github.com/waymo-research/waymax/issues/55. More than likely, your JAX setup is putting the whole tfrecord on the GPU, causing the OOM.
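If JAX were the process holding the GPU memory, two JAX environment variables would be the usual levers: one turns off up-front memory preallocation and the other restricts JAX to the CPU backend. This is only a sketch of that hypothesis; whether it applies to this setup is an assumption, and both variables must be set before JAX is first imported in the process.

```python
import os

# Must run before the first `import jax` in the process to take effect.
# Allocate GPU memory on demand instead of reserving a large block up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Optional: keep JAX off the GPU entirely.
os.environ["JAX_PLATFORMS"] = "cpu"
```

In a notebook, these lines belong in the very first cell, followed by a kernel restart so that no earlier import of JAX is still live.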

hungdche commented 4 months ago

I guess my env is not configured for GPU-based JAX, as I get this warning:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
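
A quick way to confirm what this warning implies is to ask JAX which backend it selected. The sketch below uses `jax.default_backend()` and `jax.devices()`; the wrapper name `jax_backend_summary` is hypothetical, and it returns None when JAX is not installed.

```python
def jax_backend_summary():
    """Return (backend_name, device_strings), or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend(), [str(d) for d in jax.devices()]


print(jax_backend_summary())
```

A backend name of "cpu" means JAX itself is not placing arrays on the GPU, so any GPU memory in use would have to come from another library in the same process.
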
ryan-utopia commented 4 months ago

I have this warning too. Let me take a closer look.

ryan-utopia commented 4 months ago

I tried jax.config.update('jax_transfer_guard', 'allow'), but it didn't work. May I know which JAX version you use in your env?

hungdche commented 4 months ago

Mine is 0.4.21:

jax                       0.4.21                   pypi_0    pypi
jaxlib                    0.4.21                   pypi_0    pypi
ryan-utopia commented 4 months ago

Thanks a lot! I just solved the problem.

  1. The OOM error is caused not by JAX but by TensorFlow. To fix it, add tf.config.set_visible_devices([], 'GPU') to the file womd_dataloader.py.

  2. If your machine has multiple GPUs for inference, change
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") to a specific GPU, such as self.device = torch.device("cuda:0"); otherwise it may cause RuntimeError: CUDA error: an illegal memory access was encountered.

  3. The variable max_deccel should be max_decel in the file waymax_agent.py.
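
The first two fixes can be sketched as small helpers. The function names `hide_gpus_from_tensorflow` and `pick_torch_device` are hypothetical, and where exactly the TensorFlow call should go inside womd_dataloader.py is an assumption; the only requirement stated by TensorFlow is that it runs before TensorFlow initializes the GPU.

```python
def hide_gpus_from_tensorflow():
    """Stop TensorFlow from claiming GPU memory.

    Call this before TensorFlow touches the GPU, for example right
    after importing it in the data-loading module.
    """
    import tensorflow as tf

    tf.config.set_visible_devices([], "GPU")


def pick_torch_device(cuda_available, index=0):
    """Return an explicit device string such as 'cuda:0' instead of bare 'cuda'."""
    return f"cuda:{index}" if cuda_available else "cpu"


# Example usage (requires tensorflow and torch to be installed):
# hide_gpus_from_tensorflow()
# device = torch.device(pick_torch_device(torch.cuda.is_available()))
```

Hiding the GPUs only affects TensorFlow, which this demo uses for reading the tfrecord; PyTorch still sees the GPU and can use the memory that TensorFlow would otherwise have reserved.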

hungdche commented 4 months ago

Sounds good. If you want, can you create a PR with your changes (and the fix to the WOMD version)? @ryan-utopia

ryan-utopia commented 4 months ago

Okay, I'll finish it later.