Inference on Hodgkin-Huxley model: simple interface

You can also download and run a similar example available here: hh_sbi_simple_interface.py

Here you can download the data: input traces, output traces

from brian2 import *
from brian2modelfitting import *
import pandas as pd

To load the data, use the following:

df_inp_traces = pd.read_csv('input_traces_hh.csv')
df_out_traces = pd.read_csv('output_traces_hh.csv')
# keep the first two traces (rows) and drop the index column
inp_traces = df_inp_traces.to_numpy()
inp_traces = inp_traces[[0, 1], 1:]
out_traces = df_out_traces.to_numpy()
out_traces = out_traces[[0, 1], 1:]
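
As a quick sanity check (a minimal sketch, assuming the CSV layout used here: one row per trace, with the first column holding an index), the input and output arrays should have matching shapes:

print(inp_traces.shape, out_traces.shape)  # both (2, number_of_time_steps)
assert inp_traces.shape == out_traces.shape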

Then we have to define the model and its parameters:

area = 20_000*um**2
El = -65*mV
EK = -90*mV
ENa = 50*mV
VT = -63*mV
dt = 0.01*ms
eqs = '''
    dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
    dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
            (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
            (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
    dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
            (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
    dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1

    # free parameters
    g_na : siemens (constant)
    g_kd : siemens (constant)
    gl   : siemens (constant)
    Cm   : farad (constant)
'''

Let’s also define the time domain for more convenient plotting later, and locate the stimulus window:

# time points in ms, one per recorded sample
t = arange(0, out_traces.shape[1]*dt/ms, dt/ms)
# first and last time points at which the stimulus is non-zero
stim_start, stim_end = t[where(inp_traces[0, :] != 0)[0][[0, -1]]]
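
If you want to double-check the detected window, a quick print suffices (the exact values depend on the downloaded traces):

print(f'stimulus window: {stim_start:.2f} ms to {stim_end:.2f} ms')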

Now we have to define a list of features in order to create a summary-statistics representation of the output traces:

list_of_features = [
    lambda x: max(x[(t > stim_start) & (t < stim_end)]),   # max during stimulus
    lambda x: mean(x[(t > stim_start) & (t < stim_end)]),  # mean during stimulus
    lambda x: std(x[(t > stim_start) & (t < stim_end)]),   # std during stimulus
    lambda x: mean(x[(t > .25 * stim_start) & (t < .75 * stim_start)]),  # resting potential
]
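
To see what this representation looks like, the features can be applied to a recorded trace by hand (a minimal illustration; the Inferencer computes these internally):

# summary statistics of the first recorded trace
print([feature(out_traces[0, :]) for feature in list_of_features])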

Next, we instantiate the Inferencer class, passing in the data and the list of features:

inferencer = Inferencer(dt=dt, model=eqs,
                        input={'I': inp_traces*amp},
                        output={'v': out_traces*mV},
                        features={'v': list_of_features},
                        method='exponential_euler',
                        threshold='m > 0.5',
                        refractory='m > 0.5',
                        param_init={'v': 'VT'})

Be sure that the names of the parameters passed to the infer method correspond to the names of the unknown parameters defined as constants in the model equations. Each parameter takes a [low, high] pair that bounds its uniform prior:

posterior = inferencer.infer(n_samples=5_000,
                             n_rounds=3,
                             inference_method='SNPE',
                             density_estimator_model='mdn',
                             gl=[1e-09*siemens, 1e-07*siemens],
                             g_na=[2e-06*siemens, 2e-04*siemens],
                             g_kd=[6e-07*siemens, 6e-05*siemens],
                             Cm=[0.1*uF*cm**-2*area, 2*uF*cm**-2*area])

Once training of the neural density estimator (accessible through the returned posterior) is complete, we can draw samples from the approximated posterior distribution as follows:

samples = inferencer.sample((5_000, ))
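
The draws can be summarized directly to obtain point estimates. This is a minimal sketch: it assumes the samples convert cleanly to a NumPy array and that the columns follow the order in which the priors were specified above, which is worth verifying:

import numpy as np

# one row per draw, one (assumed) column per parameter
samples_arr = np.asarray(samples)
for name, m, s in zip(['gl', 'g_na', 'g_kd', 'Cm'],
                      samples_arr.mean(axis=0),
                      samples_arr.std(axis=0)):
    print(f'{name}: {m:.3e} +/- {s:.3e}')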

In order to analyze the sampled data further, we can use the built-in pairplot method, which visualizes the pairwise relationship between each pair of parameters:

limits = {'gl': [1e-9*siemens, 1e-07*siemens],
          'g_na': [2e-06*siemens, 2e-04*siemens],
          'g_kd': [6e-07*siemens, 6e-05*siemens],
          'Cm': [0.1*uF*cm**-2*area, 2*uF*cm**-2*area]}
labels = {'gl': r'$\overline{g}_{l}$',
          'g_na': r'$\overline{g}_{Na}$',
          'g_kd': r'$\overline{g}_{K}$',
          'Cm': r'$C_{m}$'}
inferencer.pairplot(limits=limits,
                    labels=labels,
                    ticks=limits,
                    figsize=(6, 6))
To visualize the pairwise relationships when the remaining parameters are fixed to a value drawn from the posterior, draw a single condition and use the conditional_pairplot method:

condition = inferencer.sample((1, ))
inferencer.conditional_pairplot(condition=condition,
                                limits=limits,
                                labels=labels,
                                ticks=limits,
                                figsize=(6, 6))

To obtain a simulated trace from a single sample of parameters drawn from the posterior distribution, use the following code:

inf_traces = inferencer.generate_traces(output_var='v')
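
Before plotting, you can quantify the match by comparing each simulated trace to the corresponding recording, for example via a root-mean-square error (a minimal sketch; it assumes inf_traces has the same shape as out_traces and carries units of volt, while out_traces is already in mV numerically):

# RMSE per trace, expressed in mV
rmse = sqrt(mean((out_traces - inf_traces/mV)**2, axis=1))
print(rmse)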

Let us now visualize the recordings and simulated traces:

nrows = 2
ncols = out_traces.shape[0]
fig, axs = subplots(nrows, ncols, sharex=True,
                    gridspec_kw={'height_ratios': [3, 1]}, figsize=(9, 3))
for idx in range(ncols):
    axs[0, idx].plot(t, out_traces[idx, :].T, 'C3-', lw=3, label='recordings')
    axs[0, idx].plot(t, inf_traces[idx, :].T/mV, 'k--', lw=2,
                     label='sampled traces')
    axs[1, idx].plot(t, inp_traces[idx, :].T*amp/nA, lw=3, c='k', label='stimuli')
    axs[1, idx].set_xlabel('$t$, ms')
    if idx == 0:
        axs[0, idx].set_ylabel('$V$, mV')
        axs[1, idx].set_ylabel('$I$, nA')
# merge legend entries from the voltage and the current axes
handles, labels = [h0 + h1 for h0, h1
                   in zip(axs[0, idx].get_legend_handles_labels(),
                          axs[1, idx].get_legend_handles_labels())]
fig.legend(handles, labels)
tight_layout()
show()