brian-team / brian2modelfitting

Model fitting toolbox for the Brian 2 simulator
https://brian2modelfitting.readthedocs.io

enable gpu support #61

Closed akapet00 closed 3 years ago

akapet00 commented 3 years ago

Resolves #62.

This PR allows the user to pass the device, which can be either cpu or gpu (passing cuda is identical to passing gpu), to the infer method and subsequently to the init_inference method in Inferencer.

This feature targets advanced users who won't use the default embedding network for automatic feature extraction, but will instead customize and pass their own embedding networks that take advantage of convolutions or other operations that run much faster on a GPU. For the regular user, changing the device on which torch operates won't make much of a difference.
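
For instance, the device could be chosen based on availability before calling infer (just a sketch; the fallback logic shown here is illustrative and not part of the PR):

import torch

# use the GPU only if CUDA is actually available; 'gpu' and 'cuda'
# are treated identically by infer
device = 'gpu' if torch.cuda.is_available() else 'cpu'

The chosen value is then simply passed as device=device to infer, as in the full example below.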

Example: automatic feature extraction using a custom CNN embedding network, cf. Rodrigues, P. L. C. and Gramfort, A., Learning summary features of time series for likelihood-free inference, in Proceedings of the Third Workshop on Machine Learning and the Physical Sciences (NeurIPS 2020).

Data traces have been generated with the simulator defined in https://github.com/brian-team/brian2/blob/master/examples/advanced/modelfitting_sbi.py#L33-L91:

from brian2 import *
from brian2modelfitting import *
from torch import nn
import numpy as np
import torch
import time

# pre-simulated input current and output voltage traces
inp_trace = load('../data/input_traces_sim.npy')
out_trace = load('../data/output_traces_sim.npy')

Parameters and model:

dt = 0.05*ms
t = arange(0, out_trace.size*dt/ms, dt/ms)
t_start, t_end = t[np.where(inp_trace != 0)[0][[0, -1]]]

gL = 10*nS
EL = -70*mV
VT = -60.0*mV
C = 200*pF
ENa = 53*mV
EK = -107*mV
ground_truth_params = {'gNa': 30*uS, 'gK': 1*uS}
init_conds = {'Vm': 'EL',
              'm': '1/(1 + betam/alpham)',
              'h': '1/(1 + betah/alphah)',
              'n': '1/(1 + betan/alphan)'}

# Define a model with 2 free parameters: gNa and gK
eqs = '''
     dVm/dt = -(gNa*m**3*h*(Vm - ENa) + gK*n**4*(Vm - EK) + gL*(Vm - EL) - I) / C : volt
     dm/dt = alpham*(1-m) - betam*m : 1
     dn/dt = alphan*(1-n) - betan*n : 1
     dh/dt = alphah*(1-h) - betah*h : 1

     alpham = (-0.32/mV) * (Vm - VT - 13.*mV) / (exp((-(Vm - VT - 13.*mV))/(4.*mV)) - 1)/ms : Hz
     betam = (0.28/mV) * (Vm - VT - 40.*mV) / (exp((Vm - VT - 40.*mV)/(5.*mV)) - 1)/ms : Hz
     alphah = 0.128 * exp(-(Vm - VT - 17.*mV) / (18.*mV))/ms : Hz
     betah = 4/(1 + exp((-(Vm - VT - 40.*mV)) / (5.*mV)))/ms : Hz
     alphan = (-0.032/mV) * (Vm - VT - 15.*mV) / (exp((-(Vm - VT - 15.*mV)) / (5.*mV)) - 1)/ms : Hz
     betan = 0.5*exp(-(Vm - VT - 10.*mV) / (40.*mV))/ms : Hz

     # The parameters to fit
     gNa : siemens (constant)
     gK : siemens (constant)
     '''
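
Before running the inference, the model can be sanity-checked by simulating it forward with the ground-truth parameters. This is just a sketch (not part of the example script); I_inj, eqs_fwd, neuron and rec are hypothetical names, and the input current is injected through an additional TimedArray equation:

defaultclock.dt = dt
I_inj = TimedArray(inp_trace.ravel()*amp, dt=dt)
eqs_fwd = eqs + '\nI = I_inj(t) : amp'
neuron = NeuronGroup(1, eqs_fwd, method='exponential_euler',
                     threshold='m>0.5', refractory='m>0.5')
# set the initial conditions first (m, h and n depend on Vm),
# then the ground-truth conductances
for var, value in {**init_conds, **ground_truth_params}.items():
    setattr(neuron, var, value)
rec = StateMonitor(neuron, 'Vm', record=0)
run(out_trace.size*dt)
# rec.Vm[0]/mV can now be compared against out_trace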

Inferencer instantiation without the features argument -- automatic feature extraction will be performed:

inferencer = Inferencer(dt=dt, model=eqs,
                        input={'I': inp_trace.reshape(1, -1)*amp},
                        output={'Vm': out_trace.reshape(1, -1)*mV},
                        method='exponential_euler',
                        threshold='m>0.5', refractory='m>0.5',
                        param_init=init_conds)
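
For comparison, hand-crafted summary features could be used instead of the automatic extraction. A hedged sketch, assuming the features argument takes a dictionary mapping each output variable to a list of callables (as in the package's other examples); the features themselves are purely illustrative:

# illustrative summary features computed on the stimulation window
vm_features = [
    lambda x: np.max(x[(t > t_start) & (t < t_end)]),   # peak depolarization
    lambda x: np.mean(x[(t > t_start) & (t < t_end)]),  # mean level
    lambda x: np.std(x[(t > t_start) & (t < t_end)]),   # variability
]
inferencer_hc = Inferencer(dt=dt, model=eqs,
                           input={'I': inp_trace.reshape(1, -1)*amp},
                           output={'Vm': out_trace.reshape(1, -1)*mV},
                           features={'Vm': vm_features},
                           method='exponential_euler',
                           threshold='m>0.5', refractory='m>0.5',
                           param_init=init_conds)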

Custom embedding network as defined in the reference outlined above:

# Supercool embedding net
class YuleNet(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=64,
                               stride=1, padding=32, bias=True)
        self.relu1 = nn.ReLU()
        pooling1 = 16
        self.pool1 = nn.AvgPool1d(kernel_size=pooling1)

        self.conv2 = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=64,
                               stride=1, padding=32, bias=True)
        self.relu2 = nn.ReLU()
        pooling2 = int((in_features // pooling1) // 16)
        self.pool2 = nn.AvgPool1d(kernel_size=pooling2)

        self.dropout = nn.Dropout(p=0.50)

        # the readout's in_features depends on the input trace length,
        # roughly 8 * in_features // (pooling1 * pooling2); 128 matches
        # the traces used in this example
        self.linear = nn.Linear(in_features=128,
                                out_features=out_features)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        if x.ndim == 1:
            x = x.view(1, 1, -1)
        else:
            x = x.view(len(x), 1, -1)
        x_conv1 = self.conv1(x)
        x_relu1 = self.relu1(x_conv1)
        x_pool1 = self.pool1(x_relu1)

        x_conv2 = self.conv2(x_pool1)
        x_relu2 = self.relu2(x_conv2)
        x_pool2 = self.pool2(x_relu2)

        x_flatten = x_pool2.view(len(x), 1, -1)
        x_dropout = self.dropout(x_flatten)

        x = self.relu3(self.linear(x_dropout))
        return x.view(len(x), -1)
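
A quick shape sanity check for the embedding net; the 4096-sample trace length here is hypothetical, chosen so that the hard-coded in_features=128 of the linear layer works out (the actual traces used below may have a different length):

net = YuleNet(in_features=4096, out_features=8)
dummy_batch = torch.rand(10, 4096)  # batch of 10 dummy traces
print(net(dummy_batch).shape)       # torch.Size([10, 8])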

Comparison between GPU and CPU training time. Let's start with GPU:

start_time = time.time()
posterior_gpu = inferencer.infer(n_samples=5_000,
                                 inference_method='SNPE',
                                 density_estimator_model='mdn',
                                 inference_kwargs={'embedding_net': YuleNet(out_trace.size, 8)},
                                 train_kwargs={'max_num_epochs': 30, 
                                               'num_atoms': 10,
                                               'training_batch_size': 100,
                                               'use_combined_loss': True,
                                               'discard_prior_samples': True},
                                 device='gpu',
                                 gNa=[.5*uS, 80.*uS],
                                 gK=[1e-4*uS, 15.*uS])
end_time = time.time()
print(f'Elapsed training time (GPU): {end_time - start_time:.4f}s')
# Elapsed training time (GPU): 68.4958s

And now, let's do the same on the CPU:

start_time = time.time()
posterior_cpu = inferencer.infer(n_samples=5_000,
                                 inference_method='SNPE',
                                 density_estimator_model='mdn',
                                 inference_kwargs={'embedding_net': YuleNet(out_trace.size, 8)},
                                 train_kwargs={'max_num_epochs': 30, 
                                               'num_atoms': 10,
                                               'training_batch_size': 100,
                                               'use_combined_loss': True,
                                               'discard_prior_samples': True},
                                 restart=True,
                                 device='cpu',
                                 gNa=[.5*uS, 80.*uS],
                                 gK=[1e-4*uS, 15.*uS])
end_time = time.time()
print(f'Elapsed training time (CPU): {end_time - start_time:.4f}s')
# Elapsed training time (CPU): 197.0259s

So, in this case the GPU clearly wins, with roughly a 3x speed-up (68.5 s vs. 197.0 s).
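
After training, the posterior can be used as usual, e.g. by drawing samples and inspecting the joint distribution over the two free parameters (assuming the sample and pairplot methods of Inferencer, as used in the package's documented examples):

# draw samples from the trained posterior and visualize the marginals
samples = inferencer.sample((10_000, ))
inferencer.pairplot(labels=[r'$\overline{g}_{Na}$', r'$\overline{g}_{K}$'])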

akapet00 commented 3 years ago

I am merging this now. We can discuss possible changes and/or additional examples.