nengo / nengo-extras

Extra utilities and add-ons for Nengo
https://www.nengo.ai/nengo-extras

CudaConvnetNetwork on PyTorch model #95

Open MariamAlz opened 1 year ago

MariamAlz commented 1 year ago

I'm trying to convert a pretrained PyTorch GAN generator to a spiking network using Nengo and NengoDL. I'm following the "CIFAR-10 classifier with a spiking CNN" tutorial at https://www.nengo.ai/nengo-extras/v0.5.0/examples/cuda_convnet/cifar10_spiking_cnn.html, substituting my own model for the one used in the tutorial.

After running this line of code:

ccnet = CudaConvnetNetwork(netG_A2B_pkl, synapse=nengo.synapses.Alpha(0.005))

I get this error:

TypeError: 'Generator' object is not subscriptable

I'm using the netG_A2B model from here: https://github.com/yz-wang/Cycle-SNSPGAN

Any help or recommendations will be appreciated!

hunse commented 1 year ago

CudaConvnetNetwork is (a) pretty old code and (b) designed for importing networks created with the now-defunct CudaConvnet library, so it won't work for something from PyTorch. However, you could take inspiration from its design to write your own class that imports PyTorch networks into Nengo. Unfortunately, we don't currently have much support for PyTorch in the Nengo ecosystem (NengoDL does have a "Keras converter", but it imports from TensorFlow/Keras).
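
The TypeError in the original post is consistent with an importer that indexes its argument like a dictionary, as it would a loaded CudaConvnet checkpoint. That is an inference from the error message, not a confirmed reading of the nengo-extras source. Here is a minimal sketch of the mismatch in plain Python, with no PyTorch or Nengo required; the `Generator` stand-in, the `read_layers` and `try_import` helpers, and the `"layers"` key are all hypothetical names chosen for illustration:

```python
class Generator:
    """Hypothetical stand-in for a PyTorch module; it defines no __getitem__."""

    def __init__(self):
        self.blocks = ["conv", "relu", "conv"]


def read_layers(model):
    """Mimic an importer that expects a dict-like checkpoint.

    The "layers" key is an assumption made for illustration only.
    """
    return model["layers"]


def try_import(model):
    """Return the layer list, or the error text if the input is not dict-like."""
    try:
        return read_layers(model)
    except TypeError as err:
        return str(err)


dict_checkpoint = {"layers": ["conv", "relu", "conv"]}
print(try_import(dict_checkpoint))  # dict-like input is indexed without error
print(try_import(Generator()))      # plain object: indexing raises TypeError
```

Wrapping the PyTorch model in a dictionary would not help, because the importer would still expect CudaConvnet-style layer descriptions inside it. A PyTorch importer would instead have to walk the module's layers and build the corresponding Nengo objects itself.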