LBHB / NEMS0

THIS VERSION OF NEMS IS NO LONGER SUPPORTED. PLEASE USE THE NEW NEMS REPOSITORY OR INSTALL NEMS_DB TO GET NEMS0 SUPPORT.
GNU General Public License v3.0

How to train STP model on numpy arrays #254

Open · gavinmischler opened this issue 2 years ago

gavinmischler commented 2 years ago

Hi, I am trying to train a STRF+STP model (like the local STP model in this paper) on some ECoG data. I've been trying to follow along with some of the tutorial notebooks in the documentation, but I had several questions I was hoping you could help me with.

Here is a basic outline of what I want to do in a simple case:

What I have done so far

To start off with, I want to put my data into Signal form. It appears that "from nems.signal import Signal" does not work, so I used RasterizedSignal instead. I took the 9 training trials and concatenated them in time (in numpy), put the result into a RasterizedSignal with 50 channels and 9 epochs, and then called .save(filepath) to save this training data. I wasn't sure how to format the stimuli, so I also saved them as a RasterizedSignal in the same way, this time with 32 channels. I did the same with the 1 test trial for both the responses and the stimulus, though it only has 1 epoch.

So, now I have a directory called train_data/ which contains 6 files:

-- Train-00.resp.csv
-- Train-00.resp.epoch.csv
-- Train-00.resp.json
-- Train-00.stim.csv
-- Train-00.stim.epoch.csv
-- Train-00.stim.json

and a directory called test_data/ which contains 6 files:

-- Test-00.resp.csv
-- Test-00.resp.epoch.csv
-- Test-00.resp.json
-- Test-00.stim.csv
-- Test-00.stim.epoch.csv
-- Test-00.stim.json
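The concatenation step described above (stacking trials in time and recording where each trial starts and ends) can be sketched in plain numpy/pandas. This is a hypothetical helper, concat_trials is not part of NEMS, and the assumption that an epochs table with 'name', 'start' and 'end' columns in seconds is what RasterizedSignal's epochs= argument expects should be checked against the NEMS0 signal documentation:

```python
import numpy as np
import pandas as pd


def concat_trials(trials, fs, prefix='TRIAL'):
    """Concatenate (channels x time) trials along time and build an epochs table.

    Hypothetical helper. Epoch boundaries are given in seconds; the
    'name'/'start'/'end' column layout is an assumption, not taken from NEMS docs.
    """
    data = np.concatenate(trials, axis=1)
    # cumulative sample offsets: [0, len0, len0 + len1, ...]
    bounds = np.concatenate([[0], np.cumsum([t.shape[1] for t in trials])])
    epochs = pd.DataFrame({
        'name': [f'{prefix}{i}' for i in range(len(trials))],
        'start': bounds[:-1] / fs,
        'end': bounds[1:] / fs,
    })
    return data, epochs


rng = np.random.default_rng(0)
# Nine made-up training trials of different lengths, 50 response channels each
lengths = [300, 250, 400, 350, 300, 280, 320, 310, 290]
resp_trials = [rng.random((50, n)) for n in lengths]
Y_train_cat, epochs = concat_trials(resp_trials, fs=100)
print(Y_train_cat.shape)  # (50, 2800)
print(epochs.head(3))
```

The same helper would be applied to the stimulus trials, and the resulting array and table passed as data= and epochs= to RasterizedSignal.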

Defining the model spec

Now, I believe my next step is to define a model using the modelspecname and then fit it using xforms. This is where I am a little stuck.

Would this modelspecname produce a STRF + local STP model similar to the one used in this paper, but with a 400 ms kernel size (40 samples) for my 32-channel spectrogram stimuli?

modelspecname = 'stp.1-fir.40x32-lvl.1-dexp.1'

Training with xforms

Once I have the modelspecname, my next step will be to train it with xforms, similar to the demo_2p.ipynb notebook. However, that notebook seems to use a load_polley_data command. How do I train a model by simply inputting the directory "train_data/", or maybe by the stimfile and respfile names?

Thanks so much in advance for any help you can give, Gavin

jacobpennington commented 2 years ago

Hello Gavin,

I'm not familiar with the local vs. global STP details personally, so I'll let someone else answer that. But as far as fitting with xforms goes, I believe this should work:

from nems.xform_helper import generate_xforms_spec, fit_xfspec

uri = '/path_to/your/train_data'
# ld and basic specify the loader and fitter
modelname = 'ld_stp.1-fir.40x32-lvl.1-dexp.1_basic' 
xfspec = generate_xforms_spec(
    recording_uri=uri, modelname=modelname,
    autoPred=False, autoStats=False, autoPlot=False
)
# look at xfspec at this point if you want to see the individual xforms commands
# that this generates

# Model output in ctx['rec']['pred']
ctx = fit_xfspec(xfspec)

If you still run into issues, @svdavid would probably know best what to tweak.

gavinmischler commented 2 years ago

Thanks so much for the detailed code! The info about the ld_ and _basic parts of the modelname was very useful. I tried running that code but I got this error:

TypeError: fit_basic() missing 1 required positional argument: 'est'

When searching through the documentation for this issue, I found the Intro to NEMS notebook, which seemed pretty straightforward, so I tried following that. Here is the code I tried to use:

import nems.analysis.api
import nems.initializers
from nems.fitters.api import scipy_minimize
from nems.recording import Recording
from nems.signal import RasterizedSignal

# X_*_cat / Y_*_cat are (channels x time) numpy arrays; recording_name_*, chans,
# epochs, train_corrs and test_corrs are defined earlier in my script.
resp_signal_train = RasterizedSignal(recording=recording_name_train,
                                     name='resp',
                                     data=Y_train_cat,
                                     chans=chans,
                                     fs=100,
                                     epochs=epochs)

stim_signal_train = RasterizedSignal(recording=recording_name_train,
                                     name='stim',
                                     data=X_train_cat,
                                     fs=100,
                                     epochs=epochs)

resp_signal_test = RasterizedSignal(recording=recording_name_test,
                                    name='resp',
                                    data=Y_test_cat,
                                    chans=chans,
                                    fs=100)

stim_signal_test = RasterizedSignal(recording=recording_name_test,
                                    name='stim',
                                    data=X_test_cat,
                                    fs=100)
rec_train = Recording({'resp': resp_signal_train, 'stim': stim_signal_train})
rec_test = Recording({'resp': resp_signal_test, 'stim': stim_signal_test})

# create the modelspec
# (no ld_/_basic parts here: those only select the loader and fitter for xforms)
modelname = 'stp.1-fir.23x40-lvl.1-dexp.1'

meta = {'cellids': '0',
        'batch': '0',
        'modelname': modelname,
        'recording': '0'}
modelspec = nems.initializers.from_keywords(modelname, meta=meta)
modelspec = nems.analysis.api.fit_basic(rec_train, modelspec, fitter=scipy_minimize)

# predict on test data
rec_train, rec_test = nems.analysis.api.generate_prediction(rec_train, rec_test, modelspec)

# evaluate prediction accuracy
modelspec = nems.analysis.api.standard_correlation(rec_train, rec_test, modelspec)
train_corrs.append(modelspec.meta['r_fit'])
test_corrs.append(modelspec.meta['r_test'])

Initially, I was using "fir.40x23" in the modelname because I thought the order was (time x freq), but it raised an error saying something like "only found 23 channels for 40 fir filters", so I flipped the order to fir.23x40.
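That error is consistent with the first number in fir.AxB being the number of input channels and the second the number of time lags (taps), which is how the rest of this thread uses it (fir.23x40, fir.32x40). As a rough illustration, assuming the FIR module behaves like a standard causal filter bank with coefficients shaped (channels, taps), the computation can be sketched in plain numpy. apply_fir is a hypothetical name and this is not the NEMS code:

```python
import numpy as np


def apply_fir(stim, coefficients):
    """Causal FIR filter bank (illustrative sketch, not the NEMS implementation).

    stim:         (n_chans, n_times) input, e.g. a 23-channel spectrogram
    coefficients: (n_chans, n_taps), e.g. 23 x 40 for 'fir.23x40'
    Returns a (1, n_times) array: each channel is convolved with its own
    kernel and the results are summed across channels.
    """
    n_chans, n_times = stim.shape
    if coefficients.shape[0] != n_chans:
        raise ValueError(
            f'{coefficients.shape[0]} filter rows but {n_chans} input channels')
    out = np.zeros(n_times)
    for c in range(n_chans):
        # np.convolve's 'full' output truncated to n_times keeps the filter causal:
        # out[t] only depends on stim[c, t - n_taps + 1 : t + 1]
        out += np.convolve(stim[c], coefficients[c])[:n_times]
    return out[np.newaxis, :]


stim = np.random.rand(23, 1000)
coefficients = np.random.rand(23, 40)  # 40 taps at fs=100 Hz = 400 ms of history
pred = apply_fir(stim, coefficients)
print(pred.shape)  # (1, 1000)
```

Because the channels are summed, a fir.CxT filter in this sketch yields a single output channel, which matches the later advice in this thread to add a third dimension (fir.23x40x50) for a 50-channel prediction.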

The Issue

Everything seems to run fine, but the fit error never decreases, so I feel like I must be doing something wrong. If you have any suggestions, I would really appreciate it. Here is the full log output from this code:

[nems.initializers INFO] kw: stp.1
[nems.initializers INFO] kw: fir.23x40
[nems.initializers INFO] kw: lvl.1
[nems.initializers INFO] kw: dexp.1
[nems.initializers INFO] Setting modelspec[0] input to stim
[nems.utils INFO] model save destination: /share/naplab/projects/dstrf/stp_model/NEMS/results/0/DATA/0.stp.1_fir.23x40_lvl.1_dexp.1.unknown_fitter.2022-07-07T180115
[nems.modelspec INFO] Freezing fast rec at start=0
[nems.fitters.fitter INFO] options {'ftol': 1e-07, 'maxiter': 1000, 'maxfun': 10000}
0
/share/naplab/projects/dstrf/stp_model/NEMS/results/0/DATA/0.stp.1_fir.23x40_lvl.1_dexp.1.unknown_fitter.2022-07-07T180115 /share/naplab/projects/dstrf/stp_model/NEMS/results/0/DATA/0.stp.1_fir.23x40_lvl.1_dexp.1.unknown_fitter.2022-07-07T180115
[nems.fitters.fitter INFO] Start sigma: [ 0.0399  0.01    0.      0.1    -0.05    0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.1    -0.05
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.1    -0.05    0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.1    -0.05    0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.1    -0.05    0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.1    -0.05    0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.1    -0.05    0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.1    -0.05    0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.1
 -0.05    0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.1    -0.05    0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.1    -0.05
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.1    -0.05    0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.1    -0.05    0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.1    -0.05    0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.1    -0.05    0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.1    -0.05    0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.1    -0.05    0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.1
 -0.05    0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.1    -0.05    0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.1    -0.05
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.1    -0.05    0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.1    -0.05    0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.1    -0.05    0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      0.      0.      0.      0.
  0.      0.      0.      0.      0.      5.      0.      1.      0.    ]
[nems.analysis.cost_functions INFO] Eval #100. E=2.940204
[nems.analysis.cost_functions INFO] Eval #200. E=2.940204
[nems.analysis.cost_functions INFO] Eval #300. E=2.940204
[nems.analysis.cost_functions INFO] Eval #400. E=2.940204
[nems.analysis.cost_functions INFO] Eval #500. E=2.940204
[nems.analysis.cost_functions INFO] Eval #600. E=2.940204
[nems.analysis.cost_functions INFO] Eval #700. E=2.940204
[nems.analysis.cost_functions INFO] Eval #800. E=2.940204

jacobpennington commented 2 years ago

Hello Gavin,

Apologies for the issues. The 'est' error is due to some hard-coded formatting expectations that we're working on eliminating. As for the rest of the code, it does actually appear to be working (again, with the caveat that I'm not personally sure whether the modelspec name is appropriate, but it does at least evaluate without error).

The "Eval #X" logging counts how many times the model has been evaluated during the fit (up to maxfun=10000 in this case). An evaluation happens every time a new parameter set is tried, but that new set is not necessarily kept. In this case, the fitter went through 800 new parameter combinations and none of them were better than the original, so the error stayed the same. I can't say off the top of my head whether that stretch of no change is a "normal" length to expect, but if you stopped the fit at that point, you may just need to let it run longer (STP models often take several hours to run on our data).
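One likely contributor to these long flat stretches, assuming NEMS's scipy_minimize wraps SciPy's L-BFGS-B (the ftol/maxfun options and the "STOP: TOTAL NO. of f AND g EVALUATIONS EXCEEDS LIMIT" message in these logs suggest it does): without an analytic gradient, L-BFGS-B estimates the gradient by finite differences, nudging one parameter at a time by a tiny amount. A fir.23x40 filter alone has 23 × 40 = 920 coefficients, so a single gradient estimate costs roughly that many evaluations, all with nearly identical error. The toy sketch below uses plain SciPy on a made-up quadratic cost, not NEMS, to show the effect:

```python
import numpy as np
from scipy.optimize import minimize

# Toy stand-in for a model with many free parameters
n_params = 500
rng = np.random.default_rng(0)
target = rng.normal(size=n_params)
errors = []  # error at every cost-function evaluation, like "Eval #N. E=..."


def cost(sigma):
    err = float(np.mean((sigma - target) ** 2))
    errors.append(err)
    return err


# No jac is passed, so L-BFGS-B approximates the gradient by finite differences:
# about one extra evaluation per parameter, each perturbing a single value slightly.
result = minimize(cost, np.zeros(n_params), method='L-BFGS-B',
                  options={'maxfun': 3000})

first_gradient = errors[:n_params + 1]
print(f'evaluations: {len(errors)}')
print(f'spread of E over first {n_params + 1} evals: {np.ptp(first_gradient):.2e}')
print(f'start E: {errors[0]:.6f}  best E: {min(errors):.6f}')
```

In the random-data example further down this comment (fir.32x40, so 1,280 FIR coefficients), the logged error first changes between Eval #1200 and #1300, which fits this picture, though it is not proof of it.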

As an example, here's some minimal "working" code using random data, similar to your code:

import numpy as np

from nems.analysis.api import fit_basic
from nems.initializers import from_keywords
from nems.recording import Recording
from nems.fitters.api import scipy_minimize
from nems.signal import RasterizedSignal

Y_train_cat = np.random.rand(50, 10000)  # 50 channels/neurons, 10000 time bins
X_train_cat = np.random.rand(32, 10000)  # 32-channel spectrogram, 10000 time bins
resp_signal_train = RasterizedSignal(recording='train', name='resp', data=Y_train_cat, fs=100)
stim_signal_train = RasterizedSignal(recording='train', name='stim', data=X_train_cat, fs=100)
rec_train = Recording({'resp': resp_signal_train, 'stim': stim_signal_train})

modelname = 'stp.1-fir.32x40-lvl.1-dexp.1'
meta = {'cellids': '0', 'batch': '0', 'modelname': modelname, 'recording': '0'}
modelspec = from_keywords(modelname, meta=meta)

fit_kwargs = {'options': {'ftol': 1e-03, 'maxiter': 200}}  # Set coarser tolerance, fewer iterations for faster testing
new_modelspec = fit_basic(rec_train, modelspec, fitter=scipy_minimize, fit_kwargs=fit_kwargs)

rec_with_prediction = new_modelspec.evaluate(rec_train)
print(f"Pred got generated?: {'pred' in rec_with_prediction.signals}")

With the following (abbreviated) output:

[nems.configs.defaults INFO] Saving log messages to /tmp/nems/NEMS 2022-07-07 123300.log
Backend QtAgg is interactive backend. Turning interactive mode on.
[nems.initializers INFO] kw: stp.1
[nems.initializers INFO] kw: fir.32x40
# ... etc ...
[nems.fitters.fitter INFO] Start sigma: [0.0399 0.01   0.     ... 0.     1.     0.    ]
[nems.analysis.cost_functions INFO] Eval #100. E=13.608059
# ... no change for a while ...
[nems.analysis.cost_functions INFO] Eval #1200. E=13.608059
[nems.analysis.cost_functions INFO] Eval #1300. E=2.119599
# ... no change again for a while ...
[nems.analysis.cost_functions INFO] Eval #2500. E=2.119599
[nems.fitters.fitter INFO] Stopped due to: b'STOP: TOTAL NO. of f AND g EVALUATIONS EXCEEDS LIMIT'
[nems.fitters.fitter INFO] Starting error: 13.608059 -- Final error: 2.119599
# Final "sigma" (model parameters) is different
[nems.fitters.fitter INFO] Final sigma: [ 0.0405  0.0125 -0.0272 ... -0.0387  0.9589  0.0557]
[nems.analysis.fit_basic INFO] Delta error: 13.608059 - 2.119599 = -1.148846e+01

Pred got generated?: True

gavinmischler commented 2 years ago

Thank you so much! This is exactly what I needed! You were right: I just needed to wait a little longer, and the error did eventually drop. I'll let it run for several hours and see where it ends up.

gavinmischler commented 2 years ago

As a follow-up question, is there a difference between running the code above, where all 50 channels were put together into a single recording and trained on, and splitting it up into 50 recordings and training on each one individually? When I run the code above, I expected the predicted output to have 50 channels, but it only has 1:

>>> rec_with_prediction.signals['pred'].shape
(1, 10000)

So, if I want a prediction for each channel individually, it seems like I need to train a separate model for each channel. Is that true, or am I missing a way to properly predict all 50 channels at once?

svdavid commented 2 years ago

Note for @gavinmischler: I'm trying to think through your model, stp.1-fir.23x40-lvl.1-dexp.1. I'm interpreting this as applying a linear filter to a 23-channel spectrogram input. Are you trying to apply a different STP to each input channel? If that's the case, you'll need the spec to be 'stp.23.q-fir.23x40-lvl.1-dexp.1'. That will fit quite slowly with the scipy fitter, though. If you have a GPU available, you might use the tensorflow backend fitter (nems.tf.cnnlink_new.fit_tf) instead of fit_basic. Also, the q option on stp is the "quick" implementation Menoua and I came up with; it runs faster with both fit_basic and fit_tf.
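To make the stp.1 vs. stp.23 distinction concrete, here is a simplified synaptic-depression sketch in plain numpy with one (u, tau) pair per input channel. It is hypothetical and only in the spirit of the STP module: the update rule, and treating u as a depletion rate and tau as a recovery time constant in samples, are assumptions, not the NEMS implementation or its "quick" (q) variant:

```python
import numpy as np


def simple_stp(stim, u, tau):
    """Simplified per-channel synaptic depression (hypothetical, not NEMS's stp).

    stim: (n_chans, n_times) non-negative input, e.g. a spectrogram.
    u:    depletion rate per channel (a scalar is shared by all channels).
    tau:  recovery time constant per channel, in samples.
    """
    stim = np.asarray(stim, dtype=float)
    n_chans, n_times = stim.shape
    u = np.broadcast_to(np.asarray(u, dtype=float), (n_chans,))
    tau = np.broadcast_to(np.asarray(tau, dtype=float), (n_chans,))
    resources = np.ones(n_chans)  # fraction of available resource, starts full
    out = np.empty_like(stim)
    for t in range(n_times):
        out[:, t] = stim[:, t] * resources
        # recover toward 1 with time constant tau, deplete in proportion to input
        resources = resources + (1.0 - resources) / tau - u * resources * stim[:, t]
        resources = np.clip(resources, 0.0, 1.0)
    return out


# Two channels with constant input: channel 0 has no depression (u=0),
# channel 1 depresses toward a steady state.
stim = np.ones((2, 100))
out = simple_stp(stim, u=[0.0, 0.1], tau=[10.0, 10.0])
print(out[:, :3])
print(out[1, -1])
```

Passing scalars for u and tau mimics a single shared pair; passing one value per channel gives each of the input channels its own depression dynamics, which is the idea behind stp.23 for a 23-channel input.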

jacobpennington commented 2 years ago

As for fitting all 50 channels at once, you'll need to modify a bit more. Starting from the modelspec @svdavid gave, 'stp.23.q-fir.23x40-lvl.1-dexp.1' would produce a 1-channel prediction, while 'stp.23.q-fir.23x40x50-lvl.50-dexp.50' would produce a 50-channel prediction. Alternatively, 'stp.23.q-fir.23x40.R-lvl.R-dexp.R' will match the number of channels in the response.
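The keyword pattern in these examples can be summarized with a small, hypothetical helper (stp_modelname is not part of NEMS). It only reproduces the strings quoted in this thread: stp.<inputs> with an optional .q, then fir.<inputs>x<taps> with an optional x<outputs>, then lvl and dexp with one entry per output, or .R to match the response channel count:

```python
def stp_modelname(n_inputs, n_taps, n_outputs=1, quick=True, match_resp=False):
    """Hypothetical helper (not part of NEMS) that assembles keyword strings
    following the pattern used in this thread."""
    stp = f'stp.{n_inputs}' + ('.q' if quick else '')
    fir = f'fir.{n_inputs}x{n_taps}'
    if match_resp:
        # '.R' asks for as many outputs as the response has channels
        return f'{stp}-{fir}.R-lvl.R-dexp.R'
    if n_outputs > 1:
        fir += f'x{n_outputs}'
    return f'{stp}-{fir}-lvl.{n_outputs}-dexp.{n_outputs}'


print(stp_modelname(23, 40))                   # stp.23.q-fir.23x40-lvl.1-dexp.1
print(stp_modelname(23, 40, n_outputs=50))     # stp.23.q-fir.23x40x50-lvl.50-dexp.50
print(stp_modelname(23, 40, match_resp=True))  # stp.23.q-fir.23x40.R-lvl.R-dexp.R
```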

I think these (multi-neuron) models still work (slowly) with the basic fitter, but I've only used them with the tensorflow fitter that @svdavid mentioned.

gavinmischler commented 2 years ago

Thank you both for the comments. I have been trying to use the fit_tf method, simply replacing scipy_minimize with fit_tf in my code above:

new_modelspec = nems.analysis.api.fit_basic(rec_train, modelspec, fitter=fit_tf, fit_kwargs=fit_kwargs)

But this raises some errors:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_12624/3463090911.py in <module>
     79         new_modelspec = None
     80         with io.capture_output() as captured:
---> 81             new_modelspec = nems.analysis.api.fit_basic(rec_train, modelspec, fitter=fit_tf, fit_kwargs=fit_kwargs)
     82     #     modelspec = nems.analysis.api.fit_basic(rec_train, modelspec, fitter=scipy_minimize)
     83 

/share/naplab/projects/dstrf/stp_model/NEMS/nems/analysis/fit_basic.py in fit_basic(data, modelspec, fitter, cost_function, segmentor, mapper, metric, metaname, fit_kwargs, require_phi)
    104     # (might only be one in list, but still should be packaged as a list)
    105     print(fit_kwargs)
--> 106     improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    107     improved_modelspec = unpacker(improved_sigma)
    108     elapsed_time = (time.time() - start_time)

/share/naplab/projects/dstrf/stp_model/NEMS/nems/tf/cnnlink_new.py in fit_tf(modelspec, est, use_modelspec_init, optimizer, max_iter, cost_function, early_stopping_steps, early_stopping_tolerance, early_stopping_val_split, learning_rate, variable_learning_rate, batch_size, seed, initializer, filepath, freeze_layers, IsReload, epoch_name, use_tensorboard, kernel_regularizer, **context)
    203         # if job is running on slurm, need to change model checkpoint dir
    204         # keep a record of the job id
--> 205         modelspec.meta['slurm_jobid'] = job_id
    206 
    207         log_dir_root = Path('/mnt/scratch')

AttributeError: 'list' object has no attribute 'meta'

It seems like the issue is that scipy_minimize and fit_tf don't have the same API: scipy_minimize expects sigma, cost_fn, bounds, etc., but fit_tf expects modelspec, est and so on. So when fit_basic calls fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs), it fails if the fitter is fit_tf. It seems like I'm not supposed to provide fitter=fit_tf, but I can't figure out another way to do it.

The other thing I tried was going back to the xforms version of this pipeline:

        est = Recording({'resp': resp_signal_train, 'stim': stim_signal_train})
        val = Recording({'resp': resp_signal_test, 'stim': stim_signal_test})

        # create modelspec and xforms

        # ld and basic specify the loader and fitter
        modelname = 'stp.1.q-fir.23x40-lvl.1-dexp.1' #'fir.23x40-lvl.1-dexp.1' #'stp.1-fir.23x40-lvl.1-dexp.1' 

        meta = {'cellids': '0',
            'batch': '0',
            'modelname': modelname,
            'recording': '0'
           }

        xfspec = []
        xfspec.append(['nems.xforms.init_from_keywords',
               {'keywordstring': modelname,
                'meta': meta
               }])
        # init, then fit
        xfspec.append(['nems.tf.cnnlink.fit_tf_init', {'est': est}])
        xfspec.append(['nems.tf.cnnlink.fit_tf', {'est': est}])

        xfspec.append(['nems.xforms.predict', {'est': est, 'val': val}])

        xfspec.append(['nems.analysis.api.standard_correlation', {},
               ['est', 'val', 'modelspec'], ['modelspec']])
        ctx = {}
        for xfa in xfspec:
            print(xfa)
            ctx = xforms.evaluate_step(xfa, ctx)

This one, however, eventually raises an error. The traceback is very long, so I've put it in a gist here.

jacobpennington commented 2 years ago

Hello Gavin,

For the first version, you would call fit_tf directly, not as an option within fit_basic (e.g. new_modelspec = fit_tf(modelspec, est)). Regarding the last code snippet you posted (the xforms version), you would need to change nems.tf.cnnlink to nems.tf.cnnlink_new.

I'll work on some test code to check for other issues, but those should be some simple changes to get things moving for now.

jacobpennington commented 2 years ago

Here's a working version for the non-xforms way, there were just a couple extra tweaks to make:

import numpy as np

from nems.tf.cnnlink_new import fit_tf_init, fit_tf
from nems.initializers import from_keywords
from nems.recording import Recording
from nems.signal import RasterizedSignal

Y_train_cat = np.random.rand(50, 10000)  # 50 channels/neurons, 10000 time bins
# Your first post said 32 channels, second said 23, not sure which is correct.
# Switch 32 -> 23 as needed if I picked the wrong one.
X_train_cat = np.random.rand(32, 10000)  # 32-channel spectrogram, 10000 time bins

resp_signal_train = RasterizedSignal(recording='est', name='resp', data=Y_train_cat, fs=100)
stim_signal_train = RasterizedSignal(recording='est', name='stim', data=X_train_cat, fs=100)
est = Recording({'resp': resp_signal_train, 'stim': stim_signal_train})

modelname = 'stp.32.q-fir.32x40x50-lvl.50-dexp.50'
meta = {'cellids': '0', 'batch': '0', 'modelname': modelname, 'recording': '0'}
modelspec = from_keywords(modelname, meta=meta)

# Fewer iterations and a faster learning rate for quicker testing
# Defaults: maxiter 10000, learning_rate 1e-4
# Note that these test values may result in some fit terminations due to NaN
# weights; that's fine. Just set them back to the defaults when you're ready
# to try a full fit.
# 'epoch_name': None prevents an error from some unintentional
# hard-coding of our lab's data format
fit_kwargs = {'maxiter': 100, 'learning_rate': 1e-1, 'epoch_name': None}

# the tf fitters return dicts for compatibility with xforms,
# instead of returning a modelspec directly
init_modelspec = fit_tf_init(modelspec, est, **fit_kwargs)['modelspec']
final_modelspec = fit_tf(init_modelspec, est, **fit_kwargs)['modelspec']

new_est = final_modelspec.evaluate(est)
print(new_est['pred'].shape)

And for the xforms version:

import numpy as np

from nems.tf.cnnlink_new import fit_tf_init, fit_tf
from nems.initializers import from_keywords
from nems.recording import Recording
from nems.signal import RasterizedSignal
import nems.xforms as xforms

Y_train_cat = np.random.rand(50, 10000)  # 50 channels/neurons, 10000 time bins
# Your first post said 32 channels, second said 23, not sure which is correct.
# Switch 32 -> 23 as needed if I picked the wrong one.
X_train_cat = np.random.rand(32, 10000)  # 32-channel spectrogram, 10000 time bins
resp_signal_train = RasterizedSignal(recording='est', name='resp', data=Y_train_cat, fs=100)
stim_signal_train = RasterizedSignal(recording='est', name='stim', data=X_train_cat, fs=100)
est = Recording({'resp': resp_signal_train, 'stim': stim_signal_train})

modelname = 'stp.32.q-fir.32x40x50-lvl.50-dexp.50'
meta = {'cellids': '0', 'batch': '0', 'modelname': modelname, 'recording': '0'}
xfspec = []
xfspec.append(['nems.xforms.init_from_keywords',
               {'keywordstring': modelname,
                'meta': meta
                }])

# init, then fit
# (init is optional, but generally gets better results)
fit_kwargs = {'maxiter': 100, 'learning_rate': 1e-1, 'epoch_name': None}
xfspec.append(['nems.tf.cnnlink_new.fit_tf_init', {'est': est, **fit_kwargs}])
xfspec.append(['nems.tf.cnnlink_new.fit_tf', {'est': est, **fit_kwargs}])

ctx = {}
for xfa in xfspec:
    print(xfa)
    ctx = xforms.evaluate_step(xfa, ctx)

final_modelspec = ctx['modelspec']
new_est = final_modelspec.evaluate(est)
print(new_est['pred'].shape)

gavinmischler commented 2 years ago

Thank you so so so much for all your help @jacobpennington, I can now start training and it seems to be working.

For the sake of anyone in the future who comes across this thread, the only remaining thing I needed to change to get it to work was a line in cnnlink_new.fit_tf:

log_dir_root = Path('/mnt/scratch')

I had to change this to a path that actually exists on my machine:

log_dir_root = Path('/path/that/actually/exists/for/me')
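If editing that line, one option that avoids hard-coding another machine-specific path is to build it from the system temp directory. This is a hypothetical sketch: the nems_tf_logs subdirectory name is made up, and creating the directory up front is an assumption about what fit_tf needs:

```python
import tempfile
from pathlib import Path

# Hypothetical replacement for log_dir_root = Path('/mnt/scratch') in cnnlink_new.fit_tf:
# use the OS temp directory (e.g. /tmp on Linux) instead of a cluster-specific mount.
log_dir_root = Path(tempfile.gettempdir()) / 'nems_tf_logs'
log_dir_root.mkdir(parents=True, exist_ok=True)
print(log_dir_root)
```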