paninski-lab / yass

YASS: Yet Another Spike Sorter
https://github.com/paninski-lab/yass/wiki
Apache License 2.0

NN detection: size mismatch for out.weight: copying a param with shape torch.Size([1, 56]) from checkpoint, the shape in current model is torch.Size([1, 128]). #298

Closed: llobetv closed this issue 4 years ago.

llobetv commented 4 years ago

Hi,

The reason I didn't use the NN was that I run into this bug when I use it on my data:

  File "/users/nsr/llobet/anaconda3/envs/yass/bin/yass", line 10, in <module>
    sys.exit(cli())
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/command_line.py", line 71, in sort
    calculate_rf=calculate_rf, visualize=visualize)#,
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/pipeline.py", line 137, in run
    run_chunk_sec = CONFIG.clustering_chunk)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/pipeline.py", line 232, in initial_block
    run_chunk_sec=run_chunk_sec)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/detect/run.py", line 107, in run
    run_chunk_sec=run_chunk_sec)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/detect/run.py", line 134, in run_neural_network
    detector.load(CONFIG.neuralnetwork.detect.filename)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/yass/neuralnetwork/model_detector.py", line 142, in load
    self.load_state_dict(checkpoint)
  File "/users/nsr/llobet/anaconda3/envs/yass/lib/python3.6/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Detect:
	size mismatch for out.weight: copying a param with shape torch.Size([1, 56]) from checkpoint, the shape in current model is torch.Size([1, 128]).

catubc commented 4 years ago

This crash may be caused by an incompatible NN spike width. Did you retrain the NNs with a different spike width, or change the spike width in the default CONFIG file?

llobetv commented 4 years ago

Hi, I didn't retrain the NN or change the spike width in the config file (it is still 3 ms).

NN detection works on the traces in sample/10chan regardless of the spike width (the default is 3 ms).

Note that the spike width is 4 ms in examples/config_sample.yaml and 3 ms in sample/10chan/config.yaml.

DradeAW commented 4 years ago

Hi,

I am working with llobetv, and I dug deeper into this issue.

By digging into the yass and torch code, I found that the [1, 56] shape comes from "detect.pt", which is the detect filename in config.yaml.
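
To see which parameters disagree before `load_state_dict` raises, one can compare the shapes stored in the checkpoint with those of the freshly built model. This is a minimal sketch with a hypothetical helper (`find_shape_mismatches` is not part of yass or PyTorch); it accepts any mapping from parameter names to objects with a `.shape`, or to plain shape tuples, so it can be tried without PyTorch installed.

```python
def _shape(value):
    """Shape of a tensor/array (anything with .shape) or of a plain shape tuple."""
    return tuple(value.shape) if hasattr(value, "shape") else tuple(value)


def find_shape_mismatches(checkpoint, model_state):
    """List (name, checkpoint_shape, model_shape) for every parameter present in
    both state dicts whose shapes differ; these are the entries that make
    load_state_dict report "size mismatch"."""
    mismatches = []
    for name, ckpt_value in checkpoint.items():
        if name in model_state:
            ckpt_shape, model_shape = _shape(ckpt_value), _shape(model_state[name])
            if ckpt_shape != model_shape:
                mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Plain tuples mirroring the shapes in the error message above.
checkpoint = {"out.weight": (1, 56)}
model_state = {"out.weight": (1, 128)}
print(find_shape_mismatches(checkpoint, model_state))
# [('out.weight', (1, 56), (1, 128))]
```

With PyTorch available, a call such as `find_shape_mismatches(torch.load(CONFIG.neuralnetwork.detect.filename, map_location="cpu"), detector.state_dict())` would be expected to report the same `out.weight` entry, assuming `detector` is the `Detect` module built from the current config and that the file holds a plain state dict (the traceback's `self.load_state_dict(checkpoint)` suggests it does).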

On the other hand, the [1, 128] (or [1, 88] with my dataset) comes from the neighbors: in yass/neuralnetwork/model_detector.py, this 128 or 88 is feat3*n_neigh, where feat3 is the third value of n_filters in config.yaml --> neuralnetwork --> detect, and n_neigh is computed from spatial_radius (also set in config.yaml).
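
The relation width = feat3 * n_neigh can be checked numerically for a given probe. This sketch assumes n_neigh is the largest number of channels (counting the channel itself) lying within `spatial_radius` of any channel, using Euclidean distance on the probe geometry; yass's exact neighbor rule may differ, so treat it as an approximation. The helper names and the 4-channel toy geometry are hypothetical.

```python
import math


def neighbor_counts(geometry, spatial_radius):
    """For each channel, count the channels (itself included) whose Euclidean
    distance to it is <= spatial_radius. geometry is a list of (x, y) pairs."""
    return [
        sum(1 for (x, y) in geometry if math.hypot(x - x0, y - y0) <= spatial_radius)
        for (x0, y0) in geometry
    ]


def expected_out_width(geometry, spatial_radius, feat3):
    """Width of out.weight, assuming n_neigh = max neighbor count over channels."""
    n_neigh = max(neighbor_counts(geometry, spatial_radius))
    return feat3 * n_neigh


# Toy linear probe: 4 channels spaced 20 units apart (hypothetical geometry).
geometry = [(0, 0), (0, 20), (0, 40), (0, 60)]
print(neighbor_counts(geometry, 25))        # [2, 3, 3, 2]
print(expected_out_width(geometry, 25, 8))  # 24
```

Note also that 56, 88 and 128 are all multiples of 8. If feat3 were 8 (an assumption; check the third entry of n_filters in your config.yaml), the checkpoint would correspond to 7 neighbors, while the two datasets in this thread would give 16 and 11.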

So the error comes from the fact that we use a bigger electrode (64 channels), so more neighbors fall within the spatial radius, and we need to reduce the spatial radius (in my case from 70 to 40) so that feat3*n_neigh matches the checkpoint. My guess is that if we want to change the number of neighbors, we would need to retrain the neural network and save a new detect file (although I don't understand why it works like that).
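
Instead of trying radii by hand, one could search for the largest candidate radius whose neighbor count equals the count implied by the checkpoint (checkpoint width / feat3). This is a hypothetical helper using an assumed neighbor rule (a channel counts itself; Euclidean distance <= radius). It only makes the shapes agree; it says nothing about whether the pretrained network suits a different probe layout, which is the retraining question raised here. If no candidate gives exactly the target count, it returns None.

```python
import math


def max_neighbors(geometry, radius):
    """Largest number of channels (self included) within `radius` of any channel."""
    return max(
        sum(1 for (x, y) in geometry if math.hypot(x - x0, y - y0) <= radius)
        for (x0, y0) in geometry
    )


def target_neighbors(checkpoint_width, feat3):
    """n_neigh implied by a checkpoint whose out.weight is [1, checkpoint_width]."""
    if checkpoint_width % feat3 != 0:
        raise ValueError("checkpoint width is not a multiple of feat3")
    return checkpoint_width // feat3


def largest_matching_radius(geometry, n_target, candidate_radii):
    """Largest candidate radius giving exactly n_target neighbors, or None."""
    matches = [r for r in candidate_radii if max_neighbors(geometry, r) == n_target]
    return max(matches) if matches else None


geometry = [(0, 0), (0, 20), (0, 40), (0, 60)]  # hypothetical toy probe
n_target = target_neighbors(24, feat3=8)        # 3
print(largest_matching_radius(geometry, n_target, [10, 25, 30, 45]))  # 30
```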

This solved the issue, but I now have another issue:

Traceback (most recent call last):
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/parmap/parmap.py", line 105, in _func_star_single
    **func_item_args[3])
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/merge/merge.py", line 393, in merge_templates_parallel
    self.temporal_whitener).transpose(0,2,1)
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 91 is different from 61)
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/bin/yass", line 8, in <module>
    sys.exit(cli())
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/command_line.py", line 71, in sort
    calculate_rf=calculate_rf, visualize=visualize)#,
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/pipeline.py", line 164, in run
    run_chunk_sec = CONFIG.clustering_chunk)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/pipeline.py", line 437, in pre_final_deconv
    residual_dtype)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/merge/run.py", line 74, in run
    tm.get_merge_pairs()
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/yass/merge/merge.py", line 275, in get_merge_pairs
    pm_pbar=True)
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/parmap/parmap.py", line 304, in map
    return _map_or_starmap(function, iterable, args, kwargs, "map")
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/site-packages/parmap/parmap.py", line 282, in _map_or_starmap
    output = result.get()
  File "/users/nsr/wyngaard/miniconda3/envs/eyeblink3.6/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 91 is different from 61)
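
The ValueError is NumPy's shape check for matmul: with the signature (n?,k),(k,m?)->(n?,m?), the last axis of operand 0 must equal the first axis of operand 1, and here one side has 91 samples and the other 61. A mismatch like this is consistent with an array computed under an earlier configuration being reused, which fits the fix reported later in the thread (deleting tmp). The sketch below is hypothetical: the helper names are ours, and it assumes templates shaped (units, time, channels) and a square (time, time) whitener, a layout inferred from the `.transpose(0,2,1)` in the traceback but not verified against yass.

```python
import numpy as np


def whitener_matches(templates, whitener):
    """True if templates (units, time, channels) can be multiplied by a
    square (time, time) temporal whitener along the time axis."""
    return whitener.ndim == 2 and templates.shape[1] == whitener.shape[0]


def whiten_temporal(templates, whitener):
    """Apply a temporal whitener to templates shaped (units, time, channels).

    Moves time to the last axis, multiplies by the whitener, then moves it back.
    Raises a descriptive ValueError instead of NumPy's generic matmul error.
    """
    if not whitener_matches(templates, whitener):
        raise ValueError(
            f"templates have {templates.shape[1]} time samples but the whitener "
            f"has {whitener.shape[0]} rows; was it loaded from an older run?"
        )
    # (units, channels, time) @ (time, time) -> (units, channels, time)
    return np.matmul(templates.transpose(0, 2, 1), whitener).transpose(0, 2, 1)


templates = np.zeros((2, 91, 4))                 # 91 time samples
print(whitener_matches(templates, np.eye(91)))   # True
print(whitener_matches(templates, np.eye(61)))   # False
```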

I will dig deeper into this error and try to see if I can solve it.

DradeAW commented 4 years ago

Never mind, I just needed to delete the "tmp" folder.

After deleting the tmp folder and adjusting the spatial radius, it works fine!
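
For repeated runs, the cleanup can be scripted. A minimal sketch, assuming the cached results live in a folder named `tmp` directly inside the directory you pass (check where your run actually wrote it); `clear_tmp` is a hypothetical helper, not a yass command, and it deletes everything under that folder.

```python
import shutil
from pathlib import Path


def clear_tmp(output_dir):
    """Remove output_dir/tmp and everything in it, if it exists.

    Returns True if a folder was deleted, False if there was nothing to delete.
    """
    tmp = Path(output_dir) / "tmp"
    if tmp.is_dir():
        shutil.rmtree(tmp)
        return True
    return False
```

Calling it before rerunning yass after changing spatial_radius (or any other setting that changes array shapes) avoids reusing intermediate files from the previous configuration; that stale files caused the 91 vs 61 error is our reading of this thread, not something confirmed in it.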