dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Kernel crashing during initialization for large datasets #177

Open JSignoretGenest opened 3 weeks ago

JSignoretGenest commented 3 weeks ago

Hello,

We've been successfully using keypoint-moseq on several datasets, but we are encountering a consistent issue with larger datasets, where the kernel crashes during model initialization.

Steps Taken to Isolate the Issue:

  1. Tested with smaller datasets: the model initializes and runs fine.
  2. Simulated a larger dataset: we artificially inflated previously working datasets to match the size of the problematic dataset, and the kernel crash occurred, confirming that the issue is size-related and not specific to the data itself (a sketch of this inflation follows this list).
  3. Isolated the problematic function: we traced the issue to init_states in jax_moseq/models/arhmm/initialize.py, and more specifically to the call to resample_discrete_stateseqs.
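
Roughly, the inflation looked like the sketch below; the helper name and the repetition factor are just illustrative, and coordinates/confidences are dicts in the format returned by kpms.load_keypoints:

```python
import numpy as np

def inflate_dataset(coordinates, confidences, factor=4):
    """Duplicate every recording `factor` times to mimic a larger dataset.

    `coordinates` maps recording names to (frames, keypoints, 2) arrays and
    `confidences` maps the same names to (frames, keypoints) arrays. Copies
    get suffixed names so recording names stay unique.
    """
    big_coords, big_confs = {}, {}
    for name, coords in coordinates.items():
        for i in range(factor):
            big_coords[f"{name}_copy{i}"] = np.array(coords, copy=True)
            big_confs[f"{name}_copy{i}"] = np.array(confidences[name], copy=True)
    return big_coords, big_confs
```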

Behavior Observed:

  1. Without @jax.jit on init_states: the kernel dies after log_likelihoods is computed correctly but before it is returned to init_states. We added print statements for debugging, and the crash occurs between the successful computation and the return.
  2. With @jax.jit: the issue is deferred to the return of z. Again, z is computed with the correct array dimensions (as confirmed by logs), but the kernel crashes when the value is returned to init_states. Skipping the return prevents the crash, and init_states continues to execute normally right after, which suggests the crash happens during the return itself.
  3. JAX debugging: for working datasets, jax.debug.print successfully outputs e.g. array shapes, but for datasets that trigger the crash, jax.debug.print produces no output even though execution proceeds to the subsequent steps.
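
The instrumentation mentioned in point 3 looked roughly like this; the body here is only a stand-in for the real resample_discrete_stateseqs computation:

```python
import jax
import jax.numpy as jnp

@jax.jit
def instrumented_step(log_likelihoods):
    # Shapes are static, so a plain print fires once at trace time.
    print("traced log_likelihoods shape:", log_likelihoods.shape)
    # Stand-in for the actual state-sequence sampling.
    z = jnp.argmax(log_likelihoods, axis=-1)
    # jax.debug.print fires at run time; on the crashing datasets this line
    # produced no output even though later steps still executed.
    jax.debug.print("z computed, min={a} max={b}", a=z.min(), b=z.max())
    return z
```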

Attempts to Work Around the Issue:

  1. Manual batching and vmap/lax batching: we tried processing z manually in smaller batches, or with vmap/lax, to reduce the size of the returned array. The kernel still crashes, now when init_states returns z (see the chunked lax.map sketch after this list).
  2. System Monitoring: we did not observe any VRAM or RAM saturation, so memory exhaustion does not seem to be the cause.
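
The chunked variant mentioned in point 1 was along these lines; per_sequence is again only a placeholder for the real per-sequence resampling:

```python
import jax
import jax.numpy as jnp

def resample_sequentially(log_likelihoods):
    """Process sequences one at a time with lax.map instead of a single big
    vmap, trading speed for a smaller peak memory footprint.

    `log_likelihoods` is assumed to have shape (n_sequences, n_frames,
    n_states); per_sequence stands in for the real per-sequence sampler.
    """
    def per_sequence(ll):
        return jnp.argmax(ll, axis=-1)  # placeholder for the actual sampler

    return jax.lax.map(per_sequence, log_likelihoods)
```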

We would be grateful for any workaround for that step, even if that translates to a slower initialization!

Thank you for your help!

calebweinreb commented 3 weeks ago

Thanks for the thorough report! Given that the crash is strictly a function of dataset size, it does point strongly in the direction of VRAM saturation. What OS are you using? I have found that OOM errors on Windows cause the kernel to crash without an explicit OOM error.
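
One way to probe this is to disable JAX's default VRAM preallocation using the standard JAX environment variables (they must be set before anything JAX-related is imported), so that actual usage shows up incrementally in nvidia-smi and a spike right before the crash is easier to catch:

```python
# Must be set before jax (or keypoint_moseq) is imported.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # slower, but exact accounting

import jax
print(jax.devices())
```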

JSignoretGenest commented 3 weeks ago

Thanks for the quick reply!

We are indeed on Windows, and the behavior does look a lot like what we get during fitting when the set_mixed_map_iters value is too small. Here we only checked the resources visually and with frequent nvidia-smi calls, and there was no obvious trend or peak, though a spike could be very brief at the moment the values are returned, so I wouldn't exclude it entirely.
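
In case it helps, this is the kind of polling loop we can run in a separate process during initialization to catch a short-lived spike (assumes nvidia-smi is on the PATH, sampling every 200 ms):

```python
import subprocess
import time

# Log GPU memory use a few times per second; run in a separate terminal
# while the model initialization executes.
while True:
    used = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True, text=True,
    ).stdout.strip()
    print(f"{time.strftime('%H:%M:%S')}  {used} MiB")
    time.sleep(0.2)
```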