# NEURON CLASS
import numpy as np
import models
from scipy import optimize
import inspect
import matplotlib.pyplot as plt
class Neuron(object):
    """
    A Neuron object is a specific instance of a neuron model,
    which can exist on its own or as a member of a network.

    A Neuron is initialized with ``myneuron = Neuron(params)``, where
    params is an optional dict of parameter values for myneuron.
    Once initialized, the primary method of a neuron object is to calculate
    its evolution given an input.

    Parameters
    ----------
    params
        A dictionary of parameter values, elements described below:

        model
            name of the evolution function :math:`\\dot{y}=f(x,y)`
            implemented by neuron and described in ``models.py``.
            defaults to "identity"
        solver
            name of ODE solver used to update neuron state and
            solve its dynamics ("Euler" or "RK4"), default is "Euler"
        dt
            time-step to use in ODE solver, defaults to 1.e-4
        hist_len
            length of list of previous neuron states to store.
            default is 10
        y0
            initial state of neuron, defaults to zeros
        mpar
            a dictionary of model specific parameters describing
            a specific neuron instance. See ``models.py`` for a
            given model's parameters

    Attributes
    ----------
    y
        current neuron state
    y0
        initial state of neuron
    hist
        list of previous neuron states:
        hist[0] = y(t); hist[1] = y(t-dt); and so on.
        For multidimensional neurons, history only stores
        the first (output/state) variable.
    hist_len
        length of hist list,
        number of previous states neuron stores
    dt
        time step
    model
        name of the evolution function :math:`\\dot{y}=f(x,y)`
    mpars
        dict of parameters of evolution function
    dim
        dimensionality of neuron phase space
    """

    def __init__(self, params=None):
        # A fresh dict per call: a ``{}`` default would be shared
        # between every Neuron constructed without arguments.
        if params is None:
            params = {}

        self.model = params.get("model", "identity")
        self.solver = params.get("solver", "Euler")

        # set model (defines self.dim and self.fun) ...
        self.set_model()

        # set initial state, default: all zeros
        self.set_initial_state(params.get("y0", np.zeros(self.dim)))

        # set history, do it after creating y0 and model
        self.set_history(params.get("hist_len", 10))

        # for the identity model this also builds self.f (see set_dt)
        self.set_dt(params.get("dt", 1.0e-4))

        # read model specific parameters such as tau
        if self.model != 'identity':
            self.set_model_params(params.get('mpar', {}))

        # set solver ...
        if self.solver == 'Euler':
            self.step = self.step_Euler
        elif self.solver == 'RK4':
            self.step = self.step_RK4
        else:
            raise ValueError("Not implemented")

    def __repr__(self):
        return "Neuron of type {0:s}".format(self.model)

    ##### CONSTRUCTOR HELPER FUNCTIONS #####

    def set_dt(self, dt):
        """
        Constructor helper function, sets time-step for neuron

        Parameters
        ----------
        dt
            time-step for neuron ODE solver
        """
        self.dt = dt
        # for the identity, the step parameter h should be the same as dt
        if self.model == "identity":
            self.set_model_params({'h': self.dt})

    def set_model(self):
        """
        Constructor helper function, sets Neuron model and model's key
        properties: ``dim`` (phase-space dimension) and ``fun``
        (the evolution function looked up in ``models``).

        Raises
        ------
        ValueError
            if ``self.model`` is not one of the known model names
        """
        if self.model == 'identity':
            self.dim = 1
            self.fun = models.identity
        elif self.model == 'FitzHughNagumo':
            self.dim = 2
            # NOTE(review): the attribute is spelled "FitzHughNagamo" while
            # the model key is "FitzHughNagumo"; confirm against models.py.
            self.fun = models.FitzHughNagamo
        elif self.model == 'Yamada_0':
            self.dim = 2
            self.fun = models.Yamada_0
        elif self.model == 'Yamada_1':
            self.dim = 3
            self.fun = models.Yamada_1
        elif self.model == 'Yamada_2':
            self.dim = 3
            self.fun = models.Yamada_2
        else:
            raise ValueError("Not implemented")

    def set_initial_state(self, y0=None):
        """
        Sets neuron initial state and clears neuron history

        Parameters
        ----------
        y0
            new initial state for neuron: a scalar (1-d models) or any
            1-d sequence/array of length ``dim``. Defaults to zeros.

        Raises
        ------
        ValueError
            if the length of y0 does not match the model dimension
        """
        if y0 is None:
            y0 = np.zeros(self.dim)

        # Always work on an independent 1-d ndarray so that list/tuple
        # initial states behave like arrays in the stepper arithmetic.
        state = np.atleast_1d(np.array(y0))

        # validate before touching any attribute, so a bad y0 does not
        # leave the neuron half-updated
        if len(state) != self.dim:
            raise ValueError(
                ("The initial state has {0:d} dimensions but the {1:s} model\n"
                 "has a {2:d}-dim phase space\n"
                 ).format(len(state), self.model, self.dim))

        self.y0 = y0
        self.y = state
        # wipe history as well, replace with initial state;
        # history must be a list of scalars (first state variable only)
        self.hist = [y0] if np.isscalar(y0) else [state[0]]
        # forget the input remembered by RK4, otherwise the first step
        # after a reset would interpolate from a stale input value
        self.x_prev = None

    def set_history(self, hist_len):
        """
        Constructor helper function,
        initializes history list of previous neuron states
        (containing only the initial state)

        Parameters
        ----------
        hist_len
            maximum length of history list
        """
        self.hist_len = hist_len
        if np.isscalar(self.y0):
            self.hist = [self.y0]
        else:
            self.hist = [self.y0[0]]

    def set_model_params(self, mkwargs):
        """
        Constructor helper function, sets neuron's evolution function

        Builds ``self.f(x, y)`` by binding ``mkwargs`` as keyword arguments
        of the model function, and records the parameters in ``self.mpars``.

        Parameters
        ----------
        mkwargs
            dict of model specific parameters for evolution function
        """
        self.f = lambda x, y: self.fun(x, y, **mkwargs)
        self.mpars = mkwargs
        if not self.mpars:
            # nothing was passed: report the defaults declared in the
            # signature of the model function instead
            signature = inspect.signature(self.fun)
            self.mpars = {
                name: par.default
                for name, par in signature.parameters.items()
                if par.default is not inspect.Parameter.empty
            }

    #### STEPPER FUNCTIONS (THE HEART OF THE ODE SOLVER) ####

    def _record_output(self):
        """Push the current output onto ``hist`` and trim it to ``hist_len``."""
        if self.dim > 1:
            # first element is the "state" (output) variable
            self.hist.insert(0, self.y[0])
        else:
            # NOTE(review): for 1-d models this stores the length-1 state
            # array rather than a scalar; kept as-is for compatibility.
            self.hist.insert(0, self.y)
        # trim the history from the back if it grows too big
        if len(self.hist) > self.hist_len:
            self.hist.pop()

    def step_Euler(self, x):
        """
        Steps the neuron forward one time step using Euler's method

        Computes the state of the neuron at the next time step, given
        the current input signal and neuron state:
        :math:`y(n+1)=y(n)+dt\\cdot f(x(n), y(n))`

        Parameters
        ----------
        x
            the input signal at the current time (a scalar)

        Returns
        ----------
        np.array
            The state of the neuron at t+dt
        """
        self.y = self.y + self.dt * self.f(x, self.y)
        self._record_output()
        return self.y  # return output y (t+dt)

    def step_RK4(self, x):
        """
        Steps the neuron forward one time step using
        the fourth-order Runge-Kutta method

        Computes the state of the neuron at the next time step, given
        the current and previous input signal and the neuron state:
        :math:`y(n+1)=y(n)+dt\\cdot(k_1+2k_2+2k_3+k_4)/6`,
        where the midpoint stages use the average of the previous
        and current input.

        Parameters
        ----------
        x
            the input signal at the current time (a scalar)

        Returns
        ----------
        np.array
            The state of the neuron at t+dt
        """
        # RK4 needs to know the previous input as well; on the very first
        # step (or right after a reset) fall back to the current input
        x_prev = getattr(self, 'x_prev', None)
        if x_prev is None:
            x_prev = x
        x_mid = 0.5 * (x + x_prev)

        k1 = self.f(x_prev, self.y)
        k2 = self.f(x_mid, self.y + 0.5 * self.dt * k1)
        k3 = self.f(x_mid, self.y + 0.5 * self.dt * k2)
        k4 = self.f(x, self.y + self.dt * k3)

        self.x_prev = x
        self.y = self.y + (k1 / 6 + k2 / 3 + k3 / 3 + k4 / 6) * self.dt
        self._record_output()
        return self.y  # return output y (t+dt)

    def solve(self, x):
        """
        Calculates the neuron's dynamics in response to an input signal x(t).

        Uses the solver defined in ``solver`` to step through each element
        of x(t) and compute the resultant neuron evolution. The first row of
        the result is the state before any step is taken, so only the first
        ``len(x) - 1`` inputs are consumed.

        Parameters
        ----------
        x
            the time-dependent input signal (1d array)

        Returns
        ----------
        np.array
            the resultant neuron phase space dynamics
            as a (num_timesteps X neuron.dim) array;
            empty (0 X neuron.dim) if x is empty
        """
        n_steps = len(x)
        y_out = np.zeros((n_steps, self.dim))
        if n_steps == 0:
            # nothing to integrate; avoid indexing row 0 of an empty array
            return y_out

        y_out[0, :] = self.y  # initial state
        for i in range(n_steps - 1):
            y_out[i + 1, :] = self.step(x[i])
        return y_out

    def steady_state(self, yguess=None):
        """
        Solve for the no-input steady-state of the neuron.

        Choose yguess "wisely" to avoid unphysical roots. We recommend
        testing steady state by running ``myneuron.solve(np.zeros(N))``
        and verifying the dynamics converge to the calculated steady-state

        Parameters
        ----------
        yguess
            an initial guess of the steady-state passed to the root finder
            (``scipy.optimize.fsolve``). Default is ``myneuron.y0``

        Returns
        ----------
        np.array
            The steady state of the neuron as calculated by setting
            :math:`f(0, y)=0` and solving for y starting from yguess
        """
        if yguess is None:
            yguess = self.y0

        def zero_input_rhs(y):
            return self.f(0, y)

        return optimize.fsolve(zero_input_rhs, yguess)

    def visualize_plot(self, x_in, output, time=None, ysteady=None):
        """
        Generate a simple and easy to read plot of the neuron dynamics.

        After solving a neuron for a given input signal, pass this and
        computed dynamics to generate a plot.
        Use returned figure handle to update plot parameters from defaults
        if desired.

        Parameters
        -----------
        x_in
            the time-dependent input signal (1d array)
        output
            the resultant neuron phase space dynamics
            as a (num_timesteps X neuron.dim) array
        time
            an array of time points which input and output are plotted over;
            computed from ``dt`` and the signal length when omitted
        ysteady
            Optional neuron steady state to include in plot

        Returns
        ----------
        matplotlib.figure.Figure
            A matplotlib figure instance showing the neuron dynamics
        """
        Len_t = output.shape[0]  # length of time vector
        msg2 = "input and output should both the same temporal length"
        assert (len(x_in) == int(Len_t)), msg2

        if time is None:
            # didn't pass time, so compute from dt and signal length
            time = np.linspace(0., (Len_t - 1) * self.dt, num=Len_t)

        colors = ['b', 'r', 'g', 'c']  # one color per phase-space variable
        # for the 3-d Yamada models the third variable Q is drawn as -Q
        flip_q = self.model in ('Yamada_1', 'Yamada_2')

        fig = plt.figure()
        ax1 = fig.add_axes([0, 0.0, 1, 0.6])   # first state variable
        ax2 = ax1.twinx()                      # remaining state variables
        ax3 = fig.add_axes([0, .7, 1, 0.3])    # input signal

        if ysteady is not None:
            # if have steady states, plot them first (dashed lines)
            flat = np.ones(Len_t)
            ax1.plot(time, ysteady[0] * flat, '--' + colors[0], linewidth=2)
            for ind in range(1, self.dim):
                level = ysteady[ind]
                if ind == 2 and flip_q:
                    level = -level
                ax2.plot(time, level * flat, '--' + colors[ind], linewidth=2)

        # plot Neuron state and input current
        ax1.plot(time, output[:, 0], 'b')
        for ind in range(1, self.dim):
            trace = output[:, ind]
            if ind == 2 and flip_q:
                # want to flip Q->-Q in this case
                trace = -trace
            ax2.plot(time, trace, '-.' + colors[ind])
        ax3.plot(time, x_in, '-k')

        # raw strings: "\g" and "\," are not valid Python escape sequences
        ax1.set_xlabel(r't [$1/\gamma$]')
        ax1.set_ylabel('$I$ [arb units]')
        ax3.set_ylabel(r'$i_{in}$ [$\gamma$]')
        ax2.set_ylabel('$J$ [arb units]')

        for axis in (ax1, ax2, ax3):
            axis.set_xlim(time[0], time[-1])

        if flip_q:
            # Q is plotted with flipped sign, so relabel the shared axis
            ax2.set_ylabel(r'$G,\,-Q$ [arb units]')
        return fig