fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Running the docs section "Reducing memory consumption" fails #556

Open DayBeha opened 2 weeks ago

DayBeha commented 2 weeks ago

When running the following code from the docs section on reducing memory consumption locally:

import torch

def tensor_memory(x: torch.Tensor):
    # bytes occupied by the tensor's data buffer
    return x.element_size() * x.numel()

N = 1 << 10
spike = torch.randint(0, 2, [N]).float()  # binary spikes stored as float32

print('float32 size =', tensor_memory(spike))
print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

from spikingjelly.activation_based import tensor_cache

spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)

print('bool size =', tensor_memory(spike_b))

spike_recover = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)

print('spike == spike_recover?', torch.equal(spike, spike_recover))
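For intuition, the compression that `float_spike_to_bool` appears to perform (the returned `s_padding` suggests the spike train is padded and bit-packed so that eight binary spikes fit in one byte) can be sketched on the CPU with `numpy.packbits`. The helper names below are illustrative, not spikingjelly's actual implementation:

```python
import numpy as np

def pack_spikes(spike: np.ndarray):
    # pad to a multiple of 8, then store 8 binary spikes per byte
    padding = (8 - spike.size % 8) % 8
    padded = np.concatenate([spike, np.zeros(padding, dtype=spike.dtype)])
    return np.packbits(padded.astype(np.uint8)), spike.shape, padding

def unpack_spikes(packed: np.ndarray, shape, padding: int):
    bits = np.unpackbits(packed)
    if padding:
        bits = bits[:-padding]  # drop the pad bits added above
    return bits.reshape(shape).astype(np.float32)

spike = (np.random.rand(1 << 10) > 0.8).astype(np.float32)
packed, shape, padding = pack_spikes(spike)
recovered = unpack_spikes(packed, shape, padding)

assert np.array_equal(spike, recovered)
print(spike.nbytes, packed.nbytes)  # prints: 4096 128
```

Packing gives a 32x reduction over float32 (4 bytes per spike down to 1/8 byte), which is consistent with the tutorial's motivation for converting spikes before caching or compressing them.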

The following error is raised:


AssertionError                            Traceback (most recent call last)
Cell In[24], line 11
      8 N = 1 << 20
      9 spike = (torch.rand([N]) > 0.8).float()
---> 11 spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
     13 arr = spike_b.numpy()
     15 compressed_arr = zlib.compress(arr.tobytes())

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/tensor_cache.py:123, in float_spike_to_bool(spike)
    115     kernel_args = [spike, spike_b, numel]
    116     kernel = cupy.RawKernel(
    117         kernel_codes,
    118         kernel_name,
    119         options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend
    120     )
    121     kernel(
    122         (blocks,), (configure.cuda_threads,),
--> 123         cuda_utils.wrap_args_to_raw_kernel(
    124             device_id,
    125             *kernel_args
    126         )
    127     )
    128 return spike_b, s_dtype, s_shape, s_padding

File ~/anaconda3/envs/spikingjelly/lib/python3.10/site-packages/spikingjelly/activation_based/cuda_utils.py:249, in wrap_args_to_raw_kernel(device, *args)
    246     ret_list.append(item.data_ptr())
    248 elif isinstance(item, cupy.ndarray):
--> 249     assert item.device.id == device
    250     assert item.flags['C_CONTIGUOUS']
    251     ret_list.append(item)

AssertionError: 

I followed the tutorial exactly. What could cause this, and how can it be fixed?
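A likely cause, given the failing assertion `item.device.id == device` in `cuda_utils.wrap_args_to_raw_kernel`, is a device mismatch: `float_spike_to_bool` launches a CuPy raw kernel, so its input presumably must live on the CUDA device that CuPy is currently using, while the snippet above builds `spike` on the CPU. A minimal, untested sketch of a workaround (the `ensure_cuda` helper is hypothetical, not part of spikingjelly):

```python
import torch

def ensure_cuda(x: torch.Tensor, device_id: int = 0) -> torch.Tensor:
    # Hypothetical helper: move the tensor to the given CUDA device if it
    # is not already there, so the CuPy kernel arguments pass the device
    # check at cuda_utils.py:249 (item.device.id == device).
    if x.device.type != 'cuda' or x.device.index not in (None, device_id):
        x = x.to(f'cuda:{device_id}')
    return x

spike = torch.randint(0, 2, [1 << 10]).float()

if torch.cuda.is_available():
    try:
        from spikingjelly.activation_based import tensor_cache
        spike = ensure_cuda(spike)
        spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    except ImportError:
        pass  # spikingjelly not installed in this environment
```

If the tensor is already on a GPU, the assertion can still fire when it sits on a different GPU than CuPy's current device, so on multi-GPU machines it may also be worth selecting the device explicitly before calling the function. This is an assumption based on the traceback, not a confirmed diagnosis.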