I think zephyr is the only system that has CUDA 12, so I don't think this is expected to work on breeze.
Now I get this error running on zephyr (same populate command). Is this an expected problem with memory usage? Thanks.
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.6.0 is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.4.0 because version 2.6.0-alpha is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/mcoulter/anaconda3/envs/spyglass2/lib/python3.9/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-experimental' version 0.2.0 because version 0.3.0 is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
19-Nov-24 17:38:35 Fitting initial conditions...
19-Nov-24 17:38:35 Fitting discrete state transition
19-Nov-24 17:38:35 Fitting continuous state transition...
19-Nov-24 17:38:36 Fitting clusterless spikes...
19-Nov-24 17:38:39 Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
19-Nov-24 17:38:39 Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-11-19 17:38:40.027503: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.77). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Encoding models: 100% 39/39 [00:24<00:00, 1.84electrode/s]
19-Nov-24 17:39:11 Computing posterior...
19-Nov-24 17:39:11 Computing log likelihood...
2024-11-19 17:39:17.678492: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 10.38GiB (11144145813 bytes) by rematerialization; only reduced to 48.43GiB (52000000000 bytes), down from 48.43GiB (52000000000 bytes) originally
2024-11-19 17:39:17.753873: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 10.37GiB (11139167813 bytes) by rematerialization; only reduced to 48.43GiB (52000000004 bytes), down from 48.43GiB (52000000004 bytes) originally
2024-11-19 17:39:27.797967: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 48.43GiB (rounded to 52000000000)requested by op
2024-11-19 17:39:27.798525: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] ***************************************************************************************_____________
E1119 17:39:27.798570 2086973 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 52000000000 bytes.
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In [4], line 21
14 ClusterlessDecodingSelection.insert1(
15 selection_key,
16 skip_duplicates=True,
17 )
19 ClusterlessDecodingSelection & selection_key
---> 21 ClusterlessDecodingV1.populate(selection_key)
File ~/spyglass/src/spyglass/utils/dj_mixin.py:589, in SpyglassMixin.populate(self, *restrictions, **kwargs)
587 if use_transact: # Pass single-process populate to super
588 kwargs["processes"] = processes
--> 589 return super().populate(*restrictions, **kwargs)
590 else: # No transaction protection, use bare make
591 for key in keys:
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:248, in AutoPopulate.populate(self, suppress_errors, return_exception_objects, reserve_jobs, order, limit, max_calls, display_progress, processes, make_kwargs, *restrictions)
242 if processes == 1:
243 for key in (
244 tqdm(keys, desc=self.__class__.__name__)
245 if display_progress
246 else keys
247 ):
--> 248 status = self._populate1(key, jobs, **populate_kwargs)
249 if status is True:
250 success_list.append(1)
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/datajoint/autopopulate.py:315, in AutoPopulate._populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs)
313 self.__class__._allow_insert = True
314 try:
--> 315 make(dict(key), **(make_kwargs or {}))
316 except (KeyboardInterrupt, SystemExit, Exception) as error:
317 try:
File ~/spyglass/src/spyglass/decoding/v1/clusterless.py:243, in ClusterlessDecodingV1.make(self, key)
238 logger.warning(
239 f"Interval {interval_start}:{interval_end} is empty"
240 )
241 continue
242 results.append(
--> 243 classifier.predict(
244 position_time=interval_time,
245 position=position_info.loc[interval_start:interval_end][
246 position_variable_names
247 ].to_numpy(),
248 spike_times=spike_times,
249 spike_waveform_features=spike_waveform_features,
250 time=interval_time,
251 **predict_kwargs,
252 )
253 )
254 results = xr.concat(results, dim="intervals")
256 # Save discrete transition and initial conditions
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1645, in ClusterlessDetector.predict(self, spike_times, spike_waveform_features, time, position, position_time, is_missing, discrete_transition_covariate_data, cache_likelihood, n_chunks)
1632 if discrete_transition_covariate_data is not None:
1633 self.discrete_state_transitions_ = predict_discrete_state_transitions(
1634 self.discrete_transition_design_matrix_,
1635 self.discrete_transition_coefficients_,
1636 discrete_transition_covariate_data,
1637 )
1639 (
1640 acausal_posterior,
1641 acausal_state_probabilities,
1642 marginal_log_likelihood,
1643 _,
1644 _,
-> 1645 ) = self._predict(
1646 time=time,
1647 log_likelihood_args=(
1648 position_time,
1649 position,
1650 spike_times,
1651 spike_waveform_features,
1652 ),
1653 is_missing=is_missing,
1654 cache_likelihood=cache_likelihood,
1655 n_chunks=n_chunks,
1656 )
1658 return self._convert_results_to_xarray(
1659 time,
1660 acausal_posterior,
1661 acausal_state_probabilities,
1662 marginal_log_likelihood,
1663 )
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:742, in _DetectorBase._predict(self, time, log_likelihood_args, is_missing, log_likelihoods, cache_likelihood, n_chunks)
739 state_ind = self.state_ind_[is_track_interior]
741 if self.discrete_state_transitions_.ndim == 2:
--> 742 return chunked_filter_smoother(
743 time=time,
744 state_ind=state_ind,
745 initial_distribution=self.initial_conditions_[is_track_interior],
746 transition_matrix=(
747 self.continuous_state_transitions_[cross_is_track_interior]
748 * self.discrete_state_transitions_[np.ix_(state_ind, state_ind)]
749 ),
750 log_likelihood_func=self.compute_log_likelihood,
751 log_likelihood_args=log_likelihood_args,
752 is_missing=is_missing,
753 n_chunks=n_chunks,
754 log_likelihoods=log_likelihoods,
755 cache_log_likelihoods=cache_likelihood,
756 )
757 else:
758 return chunked_filter_smoother_covariate_dependent(
759 time=time,
760 state_ind=state_ind,
(...)
771 cache_log_likelihoods=cache_likelihood,
772 )
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/core.py:242, in chunked_filter_smoother(time, state_ind, initial_distribution, transition_matrix, log_likelihood_func, log_likelihood_args, is_missing, n_chunks, log_likelihoods, cache_log_likelihoods)
240 else:
241 is_missing_chunk = is_missing[time_inds] if is_missing is not None else None
--> 242 log_likelihood_chunk = log_likelihood_func(
243 time[time_inds],
244 *log_likelihood_args,
245 is_missing=is_missing_chunk,
246 )
248 (
249 (marginal_likelihood_chunk, predicted_probs_next),
250 (causal_posterior_chunk, predicted_probs_chunk),
(...)
256 log_likelihoods=log_likelihood_chunk,
257 )
259 causal_posterior_chunk = np.asarray(causal_posterior_chunk)
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/models/base.py:1554, in ClusterlessDetector.compute_log_likelihood(self, time, position_time, position, spike_times, spike_waveform_features, is_missing)
1547 log_likelihood = log_likelihood.at[:, is_state_bin].set(
1548 predict_no_spike_log_likelihood(
1549 time, spike_times, self.no_spike_rate
1550 )
1551 )
1552 elif likelihood_name not in computed_likelihoods:
1553 log_likelihood = log_likelihood.at[:, is_state_bin].set(
-> 1554 likelihood_func(
1555 time,
1556 position_time,
1557 position,
1558 spike_times,
1559 spike_waveform_features,
1560 **self.encoding_model_[likelihood_name[:2]],
1561 is_local=obs.is_local,
1562 )
1563 )
1564 else:
1565 # Use previously computed likelihoods
1566 previously_computed_bins = self.state_ind_[
1567 self.is_track_interior_state_bins_
1568 ] == computed_likelihoods.index(likelihood_name)
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/clusterless_kde.py:349, in predict_clusterless_kde_log_likelihood(time, position_time, position, spike_times, spike_waveform_features, occupancy, occupancy_model, gpi_models, encoding_spike_waveform_features, encoding_positions, environment, mean_rates, summed_ground_process_intensity, position_std, waveform_std, is_local, block_size, disable_progress_bar)
346 n_time = len(time)
348 if is_local:
--> 349 log_likelihood = compute_local_log_likelihood(
350 time,
351 position_time,
352 position,
353 spike_times,
354 spike_waveform_features,
355 occupancy_model,
356 gpi_models,
357 encoding_spike_waveform_features,
358 encoding_positions,
359 environment,
360 mean_rates,
361 position_std,
362 waveform_std,
363 block_size,
364 disable_progress_bar,
365 )
366 else:
367 is_track_interior = environment.is_track_interior_.ravel()
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/clusterless_kde.py:482, in compute_local_log_likelihood(time, position_time, position, spike_times, spike_waveform_features, occupancy_model, gpi_models, encoding_spike_waveform_features, encoding_positions, environment, mean_rates, position_std, waveform_std, block_size, disable_progress_bar)
478 # Need to interpolate position
479 interpolated_position = get_position_at_time(
480 position_time, position, time, environment
481 )
--> 482 occupancy = occupancy_model.predict(interpolated_position)
484 n_time = len(time)
485 log_likelihood = jnp.zeros((n_time,))
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:208, in KDEModel.predict(self, eval_points)
199 std = (
200 jnp.array([self.std] * eval_points.shape[1])
201 if isinstance(self.std, (int, float))
202 else self.std
203 )
204 block_size = (
205 eval_points.shape[0] if self.block_size is None else self.block_size
206 )
--> 208 return block_kde(eval_points, self.samples_, std, block_size)
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:155, in block_kde(eval_points, samples, std, block_size)
151 for start_ind in range(0, n_eval_points, block_size):
152 block_inds = slice(start_ind, start_ind + block_size)
153 density = jax.lax.dynamic_update_slice(
154 density,
--> 155 kde(eval_points[block_inds], samples, std),
156 (start_ind,),
157 )
159 return density
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/non_local_detector/likelihoods/common.py:118, in kde(eval_points, samples, std)
115 distance = jnp.ones((samples.shape[0], eval_points.shape[0]))
117 for dim_eval_points, dim_samples, dim_std in zip(eval_points.T, samples.T, std):
--> 118 distance *= gaussian_pdf(
119 jnp.expand_dims(dim_eval_points, axis=0),
120 jnp.expand_dims(dim_samples, axis=1),
121 dim_std,
122 )
123 return jnp.mean(distance, axis=0)
[... skipping hidden 10 frame]
File ~/anaconda3/envs/spyglass2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py:1253, in ExecuteReplicated.__call__(self, *args)
1251 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1252 else:
-> 1253 results = self.xla_executable.execute_sharded(input_bufs)
1255 if dispatch.needs_check_special():
1256 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 52000000000 bytes.
This is the key that triggers the error:
nwb_file_name = "CH65_20211204_.nwb"
selection_key = {
    "waveform_features_group_name": "CH65_12_04_all_tet",
    "position_group_name": "CH65_12_04",
    "decoding_param_name": "CH65_1204_nonlocal",
    "nwb_file_name": nwb_file_name,
    "encoding_interval": "CH65_12_04_01",
    "decoding_interval": "CH65_12_04_01",
    "estimate_decoding_params": False,
}

ClusterlessDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

ClusterlessDecodingSelection & selection_key

ClusterlessDecodingV1.populate(selection_key)
If you are running out of resources, you should switch to a different GPU, e.g.:
import jax

# Pick a specific GPU by index; run this before any JAX computations
# so that new arrays are placed on the chosen device.
device_id = 2
device = jax.devices()[device_id]
jax.config.update("jax_default_device", device)
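Switching devices only helps if the other card has enough free memory, and here the failing allocation (the 52000000000-byte buffer, ~48 GiB) is larger than a typical single GPU. A possibly more effective lever is chunking: the traceback shows that ClusterlessDetector.predict accepts an n_chunks argument, which chunked_filter_smoother uses to compute the log likelihood over the decoding interval in pieces. A minimal sketch, assuming classifier and the inputs are the same objects spyglass passes in the make method shown in the traceback (n_chunks=8 is an arbitrary starting value, not a spyglass default):

# Hypothetical sketch: chunk the posterior computation over time so the
# per-chunk likelihood arrays fit in GPU memory. n_chunks is a real
# parameter of predict() (see the traceback above); 8 is a guess to tune.
results = classifier.predict(
    position_time=interval_time,
    position=position_info.loc[interval_start:interval_end][
        position_variable_names
    ].to_numpy(),
    spike_times=spike_times,
    spike_waveform_features=spike_waveform_features,
    time=interval_time,
    n_chunks=8,
)

Larger n_chunks lowers peak memory at the cost of more per-chunk overhead; increasing it until the per-chunk allocation fits is the usual approach.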
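Independently of spyglass, JAX by default preallocates most of the GPU's memory at startup, which can make a tight situation worse when other processes share the card. This is controlled by documented JAX environment variables that must be set before jax is imported; a minimal sketch:

import os

# Must run before importing jax: XLA reads these at backend initialization.
# Either disable upfront preallocation entirely...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or keep preallocation but shrink its share (default is 0.75):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax

Note this only changes how memory is reserved; it will not make a single ~48 GiB buffer fit on a smaller card, so chunking is still the more direct fix here.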
I installed non-local-detector[gpu] via pip and tried to run ClusterlessDecodingV1.populate(selection_key); then I got this error. I am using breeze.
Any help would be appreciated. Thank you.