fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Mistake in SHD dataset event to frame integration #507

Closed cby120 closed 3 months ago

cby120 commented 3 months ago


SpikingJelly version

0.0.0.0.15

Description

In module spikingjelly.datasets.shd, there is a small logic mistake in the function integrate_events_by_fixed_duration_shd (line 72) when integrating spike events into frames. A simple test is provided in the reproduction section below: when an event set of length T is integrated with a fixed duration dt, the frame count n should be floor(T / dt) + 1.

When determining the index span of one frame, the start time t_l is taken from the timestamp of the event at the left index. This makes t_l of each frame depend on the next spike event rather than on where the previous frame ended. In other words, there are gaps between frames, and the size of each gap is the time between the last spike of the previous frame and the first spike of the next frame.

[Figure: shd frame integration]

A possible fix is to initialize t_l to 0 and increase it by duration in each loop iteration, which ensures that frames are integrated on a fixed-duration grid.
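To make the fixed-grid idea concrete, here is a minimal sketch in which frame k covers the window [k * duration, (k + 1) * duration), so the frame count is floor(T / dt) + 1 and silent windows show up as all-zero frames. It is not the library's implementation: the function name integrate_fixed_grid is hypothetical, np.add.at is used in place of the library helper, and timestamps are assumed to already be in milliseconds.

```python
import numpy as np

def integrate_fixed_grid(x, t_ms, duration, W):
    """Bin events on a fixed grid: frame k covers [k*duration, (k+1)*duration)."""
    frame_index = (t_ms // duration).astype(int)
    n_frames = int(frame_index.max()) + 1  # floor(T / dt) + 1
    frames = np.zeros((n_frames, W))
    # np.add.at accumulates repeated (frame, channel) pairs correctly
    np.add.at(frames, (frame_index, x), 1)
    return frames

x = np.array([0, 2, 1, 1, 3, 0])                     # channel of each event
t_ms = np.array([0.0, 3.0, 18.0, 19.0, 35.0, 52.0])  # event times in ms
frames = integrate_fixed_grid(x, t_ms, duration=10, W=4)
print(frames.shape)        # (6, 4)
print(frames.sum(axis=1))  # [2. 2. 0. 1. 0. 1.]
```

Frames 2 and 4 stay empty because no event falls into 20-30 ms or 40-50 ms, which is exactly the spacing that an event-anchored t_l would lose.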

Minimal code to reproduce the error/bug

from spikingjelly.datasets.shd import SpikingHeidelbergDigits
dt = 10
shd_event = SpikingHeidelbergDigits("./data/origin/")
shd_frame = SpikingHeidelbergDigits("./data/origin/", data_type="frame", duration=dt)
T = shd_event[0][0]['t'].max() * 1000  # cast to unit ms
n = shd_frame[0][0].shape[0]  # frame count
print(n, int(T / dt) + 1)
print(n == int(T / dt) + 1)

>>> 84 74
>>> False
fangwei123456 commented 3 months ago

Hi, do you think the following code is a correct implementation?

def integrate_events_by_fixed_duration_shd(events: Dict, duration: int, W: int) -> np.ndarray:

    x = events['x']
    t = 1000*events['t']
    t = t - t[0]

    N = t.size

    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration

    frames = []
    left = 0
    right = 0
    while True:
        while True:
            if right == N or frame_index[right] != frame_index[left]:
                break
            else:
                right += 1
        # integrate from index [left, right)
        frames.append(np.expand_dims(integrate_events_segment_to_frame_shd(x, W, left, right), 0))

        left = right

        if right == N:
            return frames
cby120 commented 3 months ago

Thank you for responding. This version fixes exactly the issue. I have a suggestion you might consider, which replaces the loop part:

...

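frame_index = (t // duration).astype(int)

n_frames = frame_index.max() + 1  # frame indices run from 0 to max inclusive
frames = list()
for k in range(n_frames):
    ind = np.where(frame_index == k)[0]
    if ind.size == 0:
        # no events in this window: emit an all-zero frame
        frames.append(np.zeros([1, W]))
        continue
    # the segment is half-open, [left, right), so the right bound is ind[-1] + 1
    frames.append(np.expand_dims(integrate_events_segment_to_frame_shd(x, W, ind[0], ind[-1] + 1), 0))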

...

My reasoning is that using vectorized NumPy functions simplifies the code structure and could probably improve efficiency.

One more thing: I noticed that the silence before the first spike is trimmed (t = t - t[0]), i.e. each frame sequence will always start with a non-zero frame. I wonder if this is necessary, since the spacing might carry information or certain features of the data.

Looking forward to your feedback!

fangwei123456 commented 3 months ago

My reasoning is that using vectorized NumPy functions simplifies the code structure and could probably improve efficiency.

Thanks, I will use this method.

each frame sequence will always start with non-zero frame.

I think zero frames occur when the camera has just been powered on, and zero frames do not contain information (in most cases, the network will output zero when the input is zero).
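For reference, a small sketch of what the t = t - t[0] shift does to the frame grid. The helper name frame_indices and the sample timestamps are made up for illustration; only the shift and the t // duration binning come from the thread.

```python
import numpy as np

def frame_indices(t_ms, duration, trim_leading_silence):
    """Frame index of each event, with or without the t - t[0] shift."""
    t = t_ms - t_ms[0] if trim_leading_silence else t_ms
    return (t // duration).astype(int)

t_ms = np.array([25.0, 27.0, 41.0])  # first spike arrives after 25 ms of silence
kept = frame_indices(t_ms, 10, trim_leading_silence=False)
trimmed = frame_indices(t_ms, 10, trim_leading_silence=True)
print(kept)     # [2 2 4]
print(trimmed)  # [0 0 1]
```

Without the shift the sequence has 5 frames, two of them leading zero frames. With the shift it has 2 frames, and because 25 ms is not a multiple of the duration, the window boundaries also move relative to the events.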

fangwei123456 commented 3 months ago
def integrate_events_by_fixed_duration_shd(events: Dict, duration: int, W: int) -> np.ndarray:

    x = events['x']
    t = 1000*events['t']
    t = t - t[0]

    N = t.size

    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration
    left = 0

    for i in range(frames_num):
        if i + 1 == frames_num:  # last frame: take all remaining events
            right = N
        else:
            right = np.searchsorted(frame_index, i + 1, side='left')
        frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
        left = right

    return frames

I think a binary search is faster?
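As a related sketch, np.searchsorted also accepts an array of search values, so all frame boundaries can be found in one call instead of one call per frame. This is an alternative arrangement, not the code adopted in the thread. The helper segment_to_frame is a hypothetical stand-in for integrate_events_segment_to_frame_shd, assumed here to count events per channel in the half-open range [left, right), and t is assumed to be sorted in ascending order.

```python
import math
import numpy as np

def segment_to_frame(x, W, left, right):
    # Hypothetical stand-in for integrate_events_segment_to_frame_shd:
    # count the events per channel in the half-open range x[left:right].
    return np.bincount(x[left:right], minlength=W).astype(float)

def integrate_all_boundaries(x, t_ms, duration, W):
    t = t_ms - t_ms[0]           # same origin shift as in the thread
    frame_index = t // duration  # non-decreasing because t is sorted
    frames_num = max(int(math.ceil(t[-1] / duration)), 1)
    # One vectorized binary search for the start of frames 1 .. frames_num - 1.
    inner = np.searchsorted(frame_index, np.arange(1, frames_num), side="left")
    # The last frame always ends at N, so it takes all remaining events.
    bounds = np.concatenate(([0], inner, [t.size]))
    frames = np.zeros([frames_num, W])
    for i in range(frames_num):
        frames[i] = segment_to_frame(x, W, bounds[i], bounds[i + 1])
    return frames

x = np.array([0, 2, 1, 1, 3, 0])
t_ms = np.array([0.0, 3.0, 18.0, 19.0, 35.0, 52.0])
frames = integrate_all_boundaries(x, t_ms, duration=10, W=4)
print(frames.shape)        # (6, 4)
print(frames.sum(axis=1))  # [2. 2. 0. 1. 0. 1.]
```

Whether this is faster than the per-frame search depends on the number of frames and events, so it would need to be measured on real data.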

cby120 commented 3 months ago

I think a binary search is faster?

I agree. I'll do some benchmark for more solid support.

cby120 commented 3 months ago

The results support binary search (bin_count below is the np.searchsorted version).

>>> ===> light data load
>>> origin    per call:  0.20ms
>>> where     per call:  0.24ms
>>> bin_count per call:  0.18ms

>>> ===> heavy data load
>>> origin    per call:  2.40ms
>>> where     per call:  1.41ms
>>> bin_count per call:  1.06ms

The code is below for inspection:

from typing import Callable
import numpy as np
from time import time
import math
from spikingjelly.datasets.shd import integrate_events_segment_to_frame_shd

def origin(events: dict, duration: int, W: int) -> np.ndarray:

    x = events["x"]
    t = 1000 * events["t"]
    t = t - t[0]

    N = t.size

    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration

    frames = []
    left = 0
    right = 0
    while True:
        while True:
            if right == N or frame_index[right] != frame_index[left]:
                break
            else:
                right += 1
        frames.append(
            np.expand_dims(integrate_events_segment_to_frame_shd(x, W, left, right), 0)
        )

        left = right

        if right == N:
            return frames

def where(events: dict, duration: int, W: int) -> np.ndarray:

    x = events["x"]
    t = 1000 * events["t"]
    t = t - t[0]

    N = t.size

    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration

    frames = []
    for k in range(frames_num):
        ind = np.where(frame_index == k)[0]
        frames.append(
            np.expand_dims(
                integrate_events_segment_to_frame_shd(x, W, ind[0], ind[-1]), 0
            )
        )

    return frames

def bin_count(events: dict, duration: int, W: int) -> np.ndarray:

    x = events["x"]
    t = 1000 * events["t"]
    t = t - t[0]

    N = t.size

    frames_num = int(math.ceil(t[-1] / duration))
    frames = np.zeros([frames_num, W])
    frame_index = t // duration
    left = 0

    for i in range(frames_num):
        if i + 1 == frames_num:  # last frame: take all remaining events
            right = N
        else:
            right = np.searchsorted(frame_index, i + 1, side="left")
        frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
        left = right

    return frames

def benchmark(func: Callable, data: dict, n_loops: int, duration: int, W: int):
    t0 = time()
    for i in range(n_loops):
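        # pass a copy of the dict because the ndarrays are not read-only;
        # an exception was raised occasionally, possibly due to an in-place operation.
        # Note: dict.copy() is shallow, so the arrays themselves are still shared.
        # The copy makes no difference to the result.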
        func(data.copy(), duration, W)
    t = time() - t0
    print(f"{func.__name__:<9s} per call: {t / n_loops * 1000: 5.2f}ms")

def build_task(n_loops: int, length: int, W: int, duration: int):
    t = np.random.rand(length)
    t.sort()
    x = np.random.randint(0, W, length)
    events = {"x": x, "t": t}
    for f in (origin, where, bin_count):
        benchmark(f, events, n_loops, duration, W)

# light data load
print("===> light data load")
build_task(n_loops=10**5, length=100, W=20, duration=50)

# heavy data load
print("===> heavy data load")
build_task(n_loops=10**4, length=10000, W=1000, duration=20)
fangwei123456 commented 3 months ago

Hi, I think the binary search can be accelerated by replacing right = np.searchsorted(frame_index, i + 1, side="left") with right = left + np.searchsorted(frame_index[left:], i + 1, side="left").
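A quick sketch to confirm that the sliced search returns the same boundaries as the full search. The frame_index values are made up for illustration, and the sketch says nothing about which variant is faster.

```python
import numpy as np

frame_index = np.array([0, 0, 1, 1, 3, 5])  # sorted frame index per event
full, sliced = [], []
left = 0
for i in range(5):
    # search the whole array every time
    full.append(int(np.searchsorted(frame_index, i + 1, side="left")))
    # search only the unconsumed tail, then shift the result back by left
    right = left + int(np.searchsorted(frame_index[left:], i + 1, side="left"))
    sliced.append(right)
    left = right
print(full)    # [2, 4, 4, 5, 5]
print(sliced)  # [2, 4, 4, 5, 5]
```

The offset left must be added back because the search on frame_index[left:] returns a position relative to the slice.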

cby120 commented 3 months ago

One more small improvement: the if statement handling the last frame can be moved out of the for loop:

for i in range(frames_num - 1):
    right = np.searchsorted(frame_index, i + 1, side="left")
    frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
    left = right
frames[frames_num - 1] = integrate_events_segment_to_frame_shd(x, W, left, N)

I'll test these two modifications.

cby120 commented 3 months ago

bin_count is the new baseline version (#507-c5).
bin_count2 applies the modification from #507-c9.
bin_count3 applies the modifications from #507-c9 and #507-c8.

The result is as follows:

===> light data load
bin_count  per call:  0.1750ms
bin_count2 per call:  0.1722ms
bin_count3 per call:  0.1781ms

===> heavy data load
bin_count  per call:  1.0422ms
bin_count2 per call:  1.0391ms
bin_count3 per call:  1.0516ms

It seems the slicing operation has considerable overhead. The improvement from #507-c9 is also not significant.
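Differences this small are within the noise of a single time() measurement, so a best-of-several timing may give a steadier comparison. The sketch below uses the standard timeit module on synthetic data. The function names, array size, and repeat counts are arbitrary choices for illustration and are not the benchmark used above.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
frame_index = np.sort(rng.integers(0, 50, size=10_000))  # synthetic, sorted

def full_search():
    return [np.searchsorted(frame_index, i + 1, side="left") for i in range(49)]

def sliced_search():
    out, left = [], 0
    for i in range(49):
        left = left + np.searchsorted(frame_index[left:], i + 1, side="left")
        out.append(left)
    return out

def best_ms(func, number=200, repeat=5):
    """Best-of-`repeat` average time per call, in milliseconds."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number * 1000

timings = {f.__name__: best_ms(f) for f in (full_search, sliced_search)}
for name, ms in timings.items():
    print(f"{name:<13s} per call: {ms:.4f}ms")
```

Taking the minimum over repeats reduces the influence of background load, though the absolute numbers will still differ between machines.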