LorenFrankLab / spyglass

Neuroscience data analysis framework for reproducible research built by Loren Frank Lab at UCSF
https://lorenfranklab.github.io/spyglass/
MIT License

GPU not working for non-local decode on zephyr #1191

Closed: MichaelCoulter closed this issue 1 day ago

MichaelCoulter commented 2 days ago

I installed non-local-detector[gpu] via pip and tried to run ClusterlessDecodingV1.populate(selection_key).

Then I got the error below. I am running on breeze.

/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.6.0 is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.4.0 because version 2.6.0-alpha is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-experimental' version 0.2.0 because version 0.3.0 is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
19-Nov-24 17:16:52 Fitting initial conditions...
19-Nov-24 17:16:52 Fitting discrete state transition
19-Nov-24 17:16:52 Fitting continuous state transition...
19-Nov-24 17:16:53 Fitting clusterless spikes...
19-Nov-24 17:16:56 Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
19-Nov-24 17:16:56 Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
E1119 17:16:56.134043 3883135 cuda_dnn.cc:503] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E1119 17:16:56.134715 3883135 cuda_dnn.cc:503] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In [4], line 21
     14 ClusterlessDecodingSelection.insert1(
     15     selection_key,
     16     skip_duplicates=True,
     17 )
     19 ClusterlessDecodingSelection & selection_key
---> 21 ClusterlessDecodingV1.populate(selection_key)

File ~/spyglass/src/spyglass/utils/dj_mixin.py:589, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    587 if use_transact:  # Pass single-process populate to super
    588     kwargs["processes"] = processes
--> 589     return super().populate(*restrictions, **kwargs)
    590 else:  # No transaction protection, use bare make
    591     for key in keys:

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:248, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    242 if processes == 1:
    243     for key in (
    244         tqdm(keys, desc=self.__class__.__name__)
    245         if display_progress
    246         else keys
    247     ):
--> 248         status = self._populate1(key, jobs, **populate_kwargs)
    249         if status is True:
    250             success_list.append(1)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:315, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    313 self.__class__._allow_insert = True
    314 try:
--> 315     make(dict(key), **(make_kwargs or {}))
    316 except (KeyboardInterrupt, SystemExit, Exception) as error:
    317     try:

File ~/spyglass/src/spyglass/decoding/v1/clusterless.py:212, in ClusterlessDecodingV1.make(self, key)
    200 VALID_FIT_KWARGS = [
    201     "is_training",
    202     "encoding_group_labels",
    203     "environment_labels",
    204     "discrete_transition_covariate_data",
    205 ]
    207 fit_kwargs = {
    208     key: value
    209     for key, value in decoding_kwargs.items()
    210     if key in VALID_FIT_KWARGS
    211 }
--> 212 classifier.fit(
    213     position_time=position_info.index.to_numpy(),
    214     position=position_info[position_variable_names].to_numpy(),
    215     spike_times=spike_times,
    216     spike_waveform_features=spike_waveform_features,
    217     **fit_kwargs,
    218 )
    219 VALID_PREDICT_KWARGS = [
    220     "is_missing",
    221     "discrete_transition_covariate_data",
    222     "return_causal_posterior",
    223 ]
    224 predict_kwargs = {
    225     key: value
    226     for key, value in decoding_kwargs.items()
    227     if key in VALID_PREDICT_KWARGS
    228 }

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1475, in ClusterlessDetector.fit(self, position_time, position, spike_times, spike_waveform_features, is_training, encoding_group_labels, environment_labels, discrete_transition_covariate_data)
   1441 """
   1442 Fit the detector to the data.
   1443 
   (...)
   1466     Fitted detector instance.
   1467 """
   1468 self._fit(
   1469     position,
   1470     is_training,
   (...)
   1473     discrete_transition_covariate_data,
   1474 )
-> 1475 self.fit_encoding_model(
   1476     position_time,
   1477     position,
   1478     spike_times,
   1479     spike_waveform_features,
   1480     is_training,
   1481     encoding_group_labels,
   1482     environment_labels,
   1483 )
   1484 return self

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1420, in ClusterlessDetector.fit_encoding_model(self, position_time, position, spike_times, spike_waveform_features, is_training, encoding_group_labels, environment_labels)
   1413 is_group = is_training & is_encoding & is_environment
   1414 (
   1415     group_spike_times,
   1416     group_spike_waveform_features,
   1417 ) = self._get_group_spike_data(
   1418     spike_times, spike_waveform_features, is_group, position_time
   1419 )
-> 1420 self.encoding_model_[likelihood_name] = encoding_algorithm(
   1421     position_time[is_group],
   1422     position[is_group],
   1423     group_spike_times,
   1424     group_spike_waveform_features,
   1425     environment,
   1426     self.sampling_frequency,
   1427     **kwargs,
   1428 )

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/clusterless_kde.py:195, in fit_clusterless_kde_encoding_model(position_time, position, spike_times, spike_waveform_features, environment, sampling_frequency, position_std, waveform_std, block_size, disable_progress_bar)
    193 if isinstance(position_std, (int, float)):
    194     if environment.track_graph is not None and position.shape[1] > 1:
--> 195         position_std = jnp.array([position_std])
    196     else:
    197         position_std = jnp.array([position_std] * position.shape[1])

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:3214, in array(object, dtype, copy, order, ndmin)
   3211 else:
   3212   raise TypeError(f"Unexpected input type for array: {type(object)}")
-> 3214 out_array: Array = lax_internal._convert_element_type(
   3215     out, dtype, weak_type=weak_type)
   3216 if ndmin > ndim(out_array):
   3217   out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/lax/lax.py:559, in _convert_element_type(operand, new_dtype, weak_type)
    557   return type_cast(Array, operand)
    558 else:
--> 559   return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    560                                      weak_type=bool(weak_type))

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/core.py:416, in Primitive.bind(self, *args, **params)
    413 def bind(self, *args, **params):
    414   assert (not config.enable_checks.value or
    415           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/core.py:420, in Primitive.bind_with_trace(self, trace, args, params)
    418 def bind_with_trace(self, trace, args, params):
    419   with pop_level(trace.level):
--> 420     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    421   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/core.py:921, in EvalTrace.process_primitive(self, primitive, tracers, params)
    919   return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    920 else:
--> 921   return primitive.impl(*tracers, **params)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/dispatch.py:87, in apply_primitive(prim, *args, **params)
     85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86 try:
---> 87   outs = fun(*args)
     88 finally:
     89   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 15 frame]

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Any help would be appreciated. Thank you.
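
For triaging a report like this, the exact jax/jaxlib versions and the driver-reported CUDA version can be captured in one call (a minimal sketch, assuming a recent jax, 0.4+, which provides this helper):

import jax

# Prints Python/jax/jaxlib/numpy versions, the detected devices, and the
# output of nvidia-smi (which includes the driver's CUDA version) --
# exactly the information needed to diagnose cudaErrorInsufficientDriver.
jax.print_environment_info()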

MichaelCoulter commented 2 days ago

I think zephyr is the only system that has CUDA 12, so I think this is not expected to work on breeze.
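
A quick way to confirm that on each host is to check which backend jax actually initialized (a minimal sketch; on a machine whose driver cannot support the installed CUDA runtime, jax typically falls back to CPU or errors at first use):

import jax

print(jax.default_backend())  # expected "gpu" on zephyr, likely "cpu" on breeze
print(jax.devices())          # the devices XLA managed to initialize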

MichaelCoulter commented 2 days ago

Now I get this error running on zephyr (same populate command). Is this an expected problem with memory usage? Thanks.

/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.6.0 is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.4.0 because version 2.6.0-alpha is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-experimental' version 0.2.0 because version 0.3.0 is already loaded.
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
19-Nov-24 17:38:35 Fitting initial conditions...
19-Nov-24 17:38:35 Fitting discrete state transition
19-Nov-24 17:38:35 Fitting continuous state transition...
19-Nov-24 17:38:36 Fitting clusterless spikes...
19-Nov-24 17:38:39 Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
19-Nov-24 17:38:39 Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-11-19 17:38:40.027503: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.77). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Encoding models: 100% 39/39 [00:24<00:00, 1.84electrode/s]
19-Nov-24 17:39:11 Computing posterior...
19-Nov-24 17:39:11 Computing log likelihood...
2024-11-19 17:39:17.678492: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 10.38GiB (11144145813 bytes) by rematerialization; only reduced to 48.43GiB (52000000000 bytes), down from 48.43GiB (52000000000 bytes) originally
2024-11-19 17:39:17.753873: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 10.37GiB (11139167813 bytes) by rematerialization; only reduced to 48.43GiB (52000000004 bytes), down from 48.43GiB (52000000004 bytes) originally
2024-11-19 17:39:27.797967: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 48.43GiB (rounded to 52000000000)requested by op 
2024-11-19 17:39:27.798525: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] ***************************************************************************************_____________
E1119 17:39:27.798570 2086973 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 52000000000 bytes.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In [4], line 21
     14 ClusterlessDecodingSelection.insert1(
     15     selection_key,
     16     skip_duplicates=True,
     17 )
     19 ClusterlessDecodingSelection & selection_key
---> 21 ClusterlessDecodingV1.populate(selection_key)

File ~/spyglass/src/spyglass/utils/dj_mixin.py:589, in SpyglassMixin.populate(self, *restrictions, **kwargs)
    587 if use_transact:  # Pass single-process populate to super
    588     kwargs["processes"] = processes
--> 589     return super().populate(*restrictions, **kwargs)
    590 else:  # No transaction protection, use bare make
    591     for key in keys:

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:248, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
    242 if processes == 1:
    243     for key in (
    244         tqdm(keys, desc=self.__class__.__name__)
    245         if display_progress
    246         else keys
    247     ):
--> 248         status = self._populate1(key, jobs, **populate_kwargs)
    249         if status is True:
    250             success_list.append(1)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:315, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
    313 self.__class__._allow_insert = True
    314 try:
--> 315     make(dict(key), **(make_kwargs or {}))
    316 except (KeyboardInterrupt, SystemExit, Exception) as error:
    317     try:

File ~/spyglass/src/spyglass/decoding/v1/clusterless.py:243, in ClusterlessDecodingV1.make(self, key)
    238             logger.warning(
    239                 f"Interval {interval_start}:{interval_end} is empty"
    240             )
    241             continue
    242         results.append(
--> 243             classifier.predict(
    244                 position_time=interval_time,
    245                 position=position_info.loc[interval_start:interval_end][
    246                     position_variable_names
    247                 ].to_numpy(),
    248                 spike_times=spike_times,
    249                 spike_waveform_features=spike_waveform_features,
    250                 time=interval_time,
    251                 **predict_kwargs,
    252             )
    253         )
    254     results = xr.concat(results, dim="intervals")
    256 # Save discrete transition and initial conditions

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1645, in ClusterlessDetector.predict(self, spike_times, spike_waveform_features, time, position, position_time, is_missing, discrete_transition_covariate_data, cache_likelihood, n_chunks)
   1632 if discrete_transition_covariate_data is not None:
   1633     self.discrete_state_transitions_ = predict_discrete_state_transitions(
   1634         self.discrete_transition_design_matrix_,
   1635         self.discrete_transition_coefficients_,
   1636         discrete_transition_covariate_data,
   1637     )
   1639 (
   1640     acausal_posterior,
   1641     acausal_state_probabilities,
   1642     marginal_log_likelihood,
   1643     _,
   1644     _,
-> 1645 ) = self._predict(
   1646     time=time,
   1647     log_likelihood_args=(
   1648         position_time,
   1649         position,
   1650         spike_times,
   1651         spike_waveform_features,
   1652     ),
   1653     is_missing=is_missing,
   1654     cache_likelihood=cache_likelihood,
   1655     n_chunks=n_chunks,
   1656 )
   1658 return self._convert_results_to_xarray(
   1659     time,
   1660     acausal_posterior,
   1661     acausal_state_probabilities,
   1662     marginal_log_likelihood,
   1663 )

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:742, in _DetectorBase._predict(self, time, log_likelihood_args, is_missing, log_likelihoods, cache_likelihood, n_chunks)
    739 state_ind = self.state_ind_[is_track_interior]
    741 if self.discrete_state_transitions_.ndim == 2:
--> 742     return chunked_filter_smoother(
    743         time=time,
    744         state_ind=state_ind,
    745         initial_distribution=self.initial_conditions_[is_track_interior],
    746         transition_matrix=(
    747             self.continuous_state_transitions_[cross_is_track_interior]
    748             * self.discrete_state_transitions_[np.ix_(state_ind, state_ind)]
    749         ),
    750         log_likelihood_func=self.compute_log_likelihood,
    751         log_likelihood_args=log_likelihood_args,
    752         is_missing=is_missing,
    753         n_chunks=n_chunks,
    754         log_likelihoods=log_likelihoods,
    755         cache_log_likelihoods=cache_likelihood,
    756     )
    757 else:
    758     return chunked_filter_smoother_covariate_dependent(
    759         time=time,
    760         state_ind=state_ind,
   (...)
    771         cache_log_likelihoods=cache_likelihood,
    772     )

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/core.py:242, in chunked_filter_smoother(time, state_ind, initial_distribution, transition_matrix, log_likelihood_func, log_likelihood_args, is_missing, n_chunks, log_likelihoods, cache_log_likelihoods)
    240 else:
    241     is_missing_chunk = is_missing[time_inds] if is_missing is not None else None
--> 242     log_likelihood_chunk = log_likelihood_func(
    243         time[time_inds],
    244         *log_likelihood_args,
    245         is_missing=is_missing_chunk,
    246     )
    248 (
    249     (marginal_likelihood_chunk, predicted_probs_next),
    250     (causal_posterior_chunk, predicted_probs_chunk),
   (...)
    256     log_likelihoods=log_likelihood_chunk,
    257 )
    259 causal_posterior_chunk = np.asarray(causal_posterior_chunk)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1554, in ClusterlessDetector.compute_log_likelihood(self, time, position_time, position, spike_times, spike_waveform_features, is_missing)
   1547     log_likelihood = log_likelihood.at[:, is_state_bin].set(
   1548         predict_no_spike_log_likelihood(
   1549             time, spike_times, self.no_spike_rate
   1550         )
   1551     )
   1552 elif likelihood_name not in computed_likelihoods:
   1553     log_likelihood = log_likelihood.at[:, is_state_bin].set(
-> 1554         likelihood_func(
   1555             time,
   1556             position_time,
   1557             position,
   1558             spike_times,
   1559             spike_waveform_features,
   1560             **self.encoding_model_[likelihood_name[:2]],
   1561             is_local=obs.is_local,
   1562         )
   1563     )
   1564 else:
   1565     # Use previously computed likelihoods
   1566     previously_computed_bins = self.state_ind_[
   1567         self.is_track_interior_state_bins_
   1568     ] == computed_likelihoods.index(likelihood_name)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/clusterless_kde.py:349, in predict_clusterless_kde_log_likelihood(time, position_time, position, spike_times, spike_waveform_features, occupancy, occupancy_model, gpi_models, encoding_spike_waveform_features, encoding_positions, environment, mean_rates, summed_ground_process_intensity, position_std, waveform_std, is_local, block_size, disable_progress_bar)
    346 n_time = len(time)
    348 if is_local:
--> 349     log_likelihood = compute_local_log_likelihood(
    350         time,
    351         position_time,
    352         position,
    353         spike_times,
    354         spike_waveform_features,
    355         occupancy_model,
    356         gpi_models,
    357         encoding_spike_waveform_features,
    358         encoding_positions,
    359         environment,
    360         mean_rates,
    361         position_std,
    362         waveform_std,
    363         block_size,
    364         disable_progress_bar,
    365     )
    366 else:
    367     is_track_interior = environment.is_track_interior_.ravel()

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/clusterless_kde.py:482, in compute_local_log_likelihood(time, position_time, position, spike_times, spike_waveform_features, occupancy_model, gpi_models, encoding_spike_waveform_features, encoding_positions, environment, mean_rates, position_std, waveform_std, block_size, disable_progress_bar)
    478 # Need to interpolate position
    479 interpolated_position = get_position_at_time(
    480     position_time, position, time, environment
    481 )
--> 482 occupancy = occupancy_model.predict(interpolated_position)
    484 n_time = len(time)
    485 log_likelihood = jnp.zeros((n_time,))

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:208, in KDEModel.predict(self, eval_points)
    199 std = (
    200     jnp.array([self.std] * eval_points.shape[1])
    201     if isinstance(self.std, (int, float))
    202     else self.std
    203 )
    204 block_size = (
    205     eval_points.shape[0] if self.block_size is None else self.block_size
    206 )
--> 208 return block_kde(eval_points, self.samples_, std, block_size)

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:155, in block_kde(eval_points, samples, std, block_size)
    151 for start_ind in range(0, n_eval_points, block_size):
    152     block_inds = slice(start_ind, start_ind + block_size)
    153     density = jax.lax.dynamic_update_slice(
    154         density,
--> 155         kde(eval_points[block_inds], samples, std),
    156         (start_ind,),
    157     )
    159 return density

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:118, in kde(eval_points, samples, std)
    115 distance = jnp.ones((samples.shape[0], eval_points.shape[0]))
    117 for dim_eval_points, dim_samples, dim_std in zip(eval_points.T, samples.T, std):
--> 118     distance *= gaussian_pdf(
    119         jnp.expand_dims(dim_eval_points, axis=0),
    120         jnp.expand_dims(dim_samples, axis=1),
    121         dim_std,
    122     )
    123 return jnp.mean(distance, axis=0)

    [... skipping hidden 10 frame]

File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1253, in ExecuteReplicated.__call__(self, *args)
   1251   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1252 else:
-> 1253   results = self.xla_executable.execute_sharded(input_bufs)
   1255 if dispatch.needs_check_special():
   1256   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 52000000000 bytes.
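
The allocator is being asked for a single ~48 GiB buffer (52000000000 bytes), so this is a genuine memory-footprint problem rather than fragmentation alone; note also that, per the spyglass traceback above, only is_missing, discrete_transition_covariate_data, and return_causal_posterior are forwarded to predict, so non_local_detector's n_chunks argument is not reachable through decoding_kwargs here. Two documented JAX/XLA environment variables are still worth trying first; they must be set before jax is imported (a sketch of standard JAX memory settings, not spyglass-specific ones):

import os

# By default JAX preallocates ~75% of GPU memory at startup; disable that...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or cap the fraction of GPU memory JAX may claim.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.90"

import jax  # must be imported only after these variables are set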

MichaelCoulter commented 2 days ago

This is the key that triggers the error:

nwb_file_name = 'CH65_20211204_.nwb'
selection_key = {
    "waveform_features_group_name": "CH65_12_04_all_tet",
    "position_group_name": "CH65_12_04",
    "decoding_param_name": 'CH65_1204_nonlocal',
    "nwb_file_name": nwb_file_name,
    "encoding_interval": "CH65_12_04_01",
    "decoding_interval": "CH65_12_04_01",
    "estimate_decoding_params": False,
}

ClusterlessDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

ClusterlessDecodingSelection & selection_key

ClusterlessDecodingV1.populate(selection_key)

edeno commented 1 day ago

If you are running out of GPU memory, you can switch to a different GPU, e.g.:

import jax

# Index of the GPU to use (pick one with free memory)
device_id = 2

# Make it the default device before any arrays are created
device = jax.devices()[device_id]
jax.config.update("jax_default_device", device)
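
Enumerating the visible devices first makes it easier to pick a device_id with free memory (a small usage sketch; device 2 above is just an example):

import jax

# List the GPUs XLA can see; cross-check with nvidia-smi to find one
# with enough free memory, then use its index as device_id.
for i, d in enumerate(jax.devices()):
    print(i, d)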