a-antoniades / Neuroformer

MIT License

How to convert spike ID or times to rates #7

PPWangyc closed this issue 1 month ago

PPWangyc commented 1 month ago

Hi,

I'm currently working with the generate_spikes function in simulation.py, and I'm having trouble converting the output neuron IDs back into spike rates of shape [Neurons × Time]. Here's a brief overview of the issue:
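For reference, here is a minimal sketch of turning per-spike neuron IDs and spike times into a [Neurons × Time] matrix by binning. It rests on assumptions not taken from Neuroformer: IDs are integers in 0..n_neurons-1, times are absolute and in the same units as bin_size, and spikes_to_matrix is a hypothetical helper name.

```python
import numpy as np


def spikes_to_matrix(ids, times, n_neurons, t_start, t_stop, bin_size):
    """Hypothetical helper: bin spikes into a [n_neurons x n_bins] count matrix.

    Assumes ids are integer neuron indices in [0, n_neurons) and times are
    absolute spike times in the same units as bin_size.
    """
    ids = np.asarray(ids, dtype=int)
    times = np.asarray(times, dtype=float)
    n_bins = int(round((t_stop - t_start) / bin_size))
    counts = np.zeros((n_neurons, n_bins), dtype=np.int64)
    # Time-bin index of every spike
    bin_idx = np.floor((times - t_start) / bin_size).astype(int)
    # Drop spikes outside [t_start, t_stop) or with out-of-range IDs
    keep = (bin_idx >= 0) & (bin_idx < n_bins) & (ids >= 0) & (ids < n_neurons)
    # np.add.at accumulates repeated (neuron, bin) pairs instead of overwriting
    np.add.at(counts, (ids[keep], bin_idx[keep]), 1)
    return counts


# Toy example: 3 neurons, 0.3 s split into 0.1 s bins
counts = spikes_to_matrix([0, 1, 1, 2], [0.01, 0.02, 0.03, 0.25],
                          n_neurons=3, t_start=0.0, t_stop=0.3, bin_size=0.1)
rates_hz = counts / 0.1  # spikes per second in each bin
```

Dividing the counts by the bin width turns them into rates; leaving them as counts is equally valid if only relative activity matters.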

After executing the function, I receive a data dictionary structured as follows:

```python
# data is the dictionary returned by generate_spikes
print(data.keys())
print(len(data['ID']))
print(len(data['time']))
print(len(data['dt']))
print(len(data['Trial']))
print(len(data['Interval']))
print(len(data['true']))
```

Output:

```
4923
7920
4923
4923
4923
7920
```

As you can see, the fields do not all have the same length (some have 7920 entries, the others 4923), which makes it difficult to use the get_rates function in analysis.py. It is defined as follows:

```python
def get_rates(df, ids, intervals, interval='Interval'): ...
```

Could you please assist me in aligning these data fields appropriately for analysis? Thank you very much for your help!

a-antoniades commented 1 month ago

This is likely because the model predicts no spikes in some of the intervals that are present in true. There may be a better way to handle this inside generate_spikes itself.
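One quick way to check this is to list the intervals that contain true spikes but no predicted ones. The sketch assumes, hypothetically, that both the true and the predicted spikes are available as pandas DataFrames with an Interval column; missing_intervals is an illustrative helper, not part of Neuroformer.

```python
import numpy as np
import pandas as pd


def missing_intervals(true_df, pred_df, interval='Interval'):
    """Hypothetical helper: intervals with true spikes but no predicted spikes."""
    return np.setdiff1d(np.unique(true_df[interval]), np.unique(pred_df[interval]))


# Toy data: the model predicted nothing in interval 1
true_df = pd.DataFrame({'ID': [0, 0, 1, 1], 'Interval': [0, 1, 1, 2]})
pred_df = pd.DataFrame({'ID': [0, 1], 'Interval': [0, 2]})
print(missing_intervals(true_df, pred_df))  # [1]
```

Any interval returned here has no rows in the predictions, so per-interval arrays built from the predictions would come out shorter than those built from true.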

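One suggestion would be to update get_rates to gracefully handle intervals in which no spikes were predicted. Maybe try something similar to this:

```python
import numpy as np


def get_rates(df, ids, intervals, interval='Interval'):
    intervals = np.array(intervals)
    # Spike counts per (ID, interval); unstack/stack with fill_value=0 inserts
    # zeros for intervals that exist in df but have no spikes for a given ID
    df = df.groupby(['ID', interval]).count().unstack(fill_value=0).stack()['Time']

    def set_rates(df, neuron_id, intervals):
        # Default to 0 so every neuron gets one entry per requested interval
        rates = np.zeros_like(intervals, dtype=np.float32)
        if neuron_id not in df.index:
            # No predicted spikes at all for this neuron
            return rates
        counts = df[neuron_id]  # counts indexed by interval
        for n, i in enumerate(intervals):
            # Intervals with no spikes are absent from counts and stay at 0
            if i in counts.index:
                rates[n] = counts.loc[i]
        return rates

    rates = dict()
    for neuron_id in ids:
        rates[neuron_id] = set_rates(df, neuron_id, intervals)
    return rates
```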
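A possibly simpler variant of the same idea, sketched as an alternative rather than the repository's code: pandas.crosstab counts spikes per (ID, interval) pair, and reindex fills any neuron or interval with no predicted spikes with 0, giving a [Neurons × Intervals] array directly. get_rate_matrix is a hypothetical name; it returns spike counts per interval, so divide by the interval length if you need rates in Hz.

```python
import numpy as np
import pandas as pd


def get_rate_matrix(df, ids, intervals, interval='Interval'):
    """Hypothetical helper: [len(ids) x len(intervals)] array of spike counts.

    Neurons or intervals with no spikes in df are filled with 0.
    """
    # Count rows for every (ID, interval) combination present in df
    counts = pd.crosstab(df['ID'], df[interval])
    # Add missing neurons/intervals as zero rows/columns, in the requested order
    counts = counts.reindex(index=ids, columns=intervals, fill_value=0)
    return counts.to_numpy(dtype=np.float32)


# Toy predictions: neuron 1 never spikes and interval 1.0 is empty
pred_df = pd.DataFrame({'ID': [0, 0, 2], 'Interval': [0.5, 0.5, 1.5]})
rates = get_rate_matrix(pred_df, ids=[0, 1, 2], intervals=[0.5, 1.0, 1.5])
print(rates.shape)  # (3, 3)
```

Because reindex matches interval labels exactly, float-valued intervals should be rounded the same way in df and in intervals.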

It also seems like the model is under-predicting by a good amount here (4923 entries in ID vs. 7920 in true). You could look at your sampling settings, in particular the temperature, top_p and top_k parameters, to improve the behavior of the generation.
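To illustrate what those parameters do, here is a generic sketch of temperature, top-k and top-p (nucleus) filtering applied to a vector of next-token logits. It is a standard NumPy illustration, not Neuroformer's implementation, and sample_token and its defaults are hypothetical. Temperature above 1 flattens the distribution, while small top_k or top_p values restrict sampling to the most likely tokens at each step.

```python
import numpy as np


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=None):
    """Generic temperature / top-k / top-p sampling (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # Temperature < 1 sharpens the distribution, > 1 flattens it
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Keep logits >= the k-th largest (ties may keep a few extra)
        kth = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p < 1.0:
        # Keep the smallest high-probability set whose cumulative mass >= top_p
        order = np.argsort(probs)[::-1]
        cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
        keep = np.zeros(probs.shape, dtype=bool)
        keep[order[:cutoff]] = True
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()
    return int(rng.choice(probs.size, p=probs))


rng = np.random.default_rng(0)
logits = [0.1, 2.0, 0.5]
print(sample_token(logits, top_k=1, rng=rng))  # always 1, the argmax
print(sample_token(logits, temperature=1.5, top_p=0.9, rng=rng))
```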

PPWangyc commented 1 month ago

Thanks for your help!