Exploring neural spike data#
In this section, we’ll explore single-unit electrophysiology data recorded by Tim Darlington in the Lisberger Lab. The data are recordings from three units in macaque frontal eye field (FEF), stored in Matlab’s .mat format, and consist of three variables:

- `spikes`: a cell array in which each cell contains an array of spike timestamps (trials x units, measured in milliseconds)
- `H` and `V`: horizontal and vertical eye positions (trials x time points, sampled at 1000 samples per second)
Loading the data#
Exercise

1. Download the data from Google Drive. In Colab, you can do this by running

   `!gdown --id 1bsqBXmV0SrMwycyRR0jY5gyw7wqjYQI_`

   where the long string is the file’s Google Drive identifier.
2. Load the data. You will want to `import scipy.io` and use `loadmat` to load the data. The result is a dictionary in which each key is a variable name.
3. Write code that extracts the number of units, trials, and timepoints from the data.
Solutions#
Show code cell content
# 1. Download the data
!gdown --id 1bsqBXmV0SrMwycyRR0jY5gyw7wqjYQI_
Show code cell content
# 2. Load the data
import numpy as np
import scipy.io
mat_dict = scipy.io.loadmat("fef_spikes.mat")
spikes = mat_dict["spikes"]
H,V = mat_dict["H"], mat_dict["V"]
Show code cell content
# 3. Write code that extracts the number of units, trials,
# and timepoints from the data.
n_trials = spikes.shape[0]
n_units = spikes.shape[1]
n_timepoints = H.shape[1]
print(f"Number of trials: {n_trials}\nNumber of units: {n_units}\nNumber of time points: {n_timepoints}")
Number of trials: 52
Number of units: 3
Number of time points: 550
Converting from events to time series#
As we learned about in the previous section, we often need to convert from a list of event times to a time series representation of the same data. For instance, rather than listing all times at which a spike occurred, we might want to know how many spikes occurred at each moment in time. That is, we’d like to make a histogram of spike counts using some small bin size.
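As a toy illustration of the idea (the spike times here are invented, not the FEF data), `np.histogram` with 1 ms bins performs exactly this event-to-time-series conversion:

```python
import numpy as np

# Hypothetical spike timestamps in milliseconds (invented for illustration)
event_times = np.array([1.2, 3.7, 3.9, 7.5])

# One bin per millisecond over a 10 ms window: edges 0, 1, ..., 10
counts, edges = np.histogram(event_times, bins=np.arange(11))

print(counts)  # the bin [3, 4) contains two spikes
```

Each entry of `counts` is the number of events whose timestamp fell in the corresponding 1 ms bin; summing it recovers the total number of events.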
To perform the conversion, we’ll start by taking only a small subset of the data, spikes for a single unit on a single trial, and converting that list to a time series.
Exercise
1. Extract the spike times for unit 1 on trial 1.
2. Get a count of spikes during each millisecond of the trial. Note that spike timestamps are recorded in milliseconds.
3. Write a test or tests to check whether you got the right answer. Each test should be a line of code that evaluates to true or false based on the raw data and time series. (Hint: tests often check mathematical properties that should be true both before and after some operation is performed.)
Solutions#
Show code cell content
# 1. Extract the spike times for unit 1 on trial 1
spks = spikes[0,0].flatten()
Show code cell content
# 2. Get spike counts
hist, bin_edges = np.histogram(spks, bins=range(n_timepoints+1))
Show code cell content
# 3. Write tests
# sum over histogram = # total spikes
print(sum(hist) == spks.size)
# the locations of the nonzero histogram bins are
# the timestamps of the spikes (rounded down)
print(np.all(hist.nonzero()[0] == np.floor(spks)))
True
True
Refactor: multiple trials and units#
Now that you have code that works on a single list, you can generalize to code that works on multiple trials and multiple units.
Exercise
1. Generalize your code to do this conversion over multiple trials and units. The result should be a time points x trials x units array.
2. Extend your test code to perform the same set of checks on the array created by this new code.
Solutions#
Show code cell content
# 1. Generalize over multiple trials and units.
bins = range(n_timepoints+1)
spkbin = np.zeros((n_timepoints, n_trials, n_units))
for i, trial in enumerate(spikes):
    for j, unit in enumerate(trial):
        spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
Show code cell content
# 2. Perform the same sets of checks on the array.
for i in range(n_trials):
    for j in range(n_units):
        if not sum(spkbin[:,i,j]) == spikes[i,j].size:
            print("Test 1 failed")
        if not np.all(spkbin[:,i,j].nonzero()[0] == np.floor(spikes[i,j])):
            print("Test 2 failed")
print("All tests passed")
All tests passed
Refactor: make it a function#
Once we have a pattern in code we know we’ll be repeating, it’s helpful to pull it into a function for several reasons:
- Functions make our code more readable. Functions give names to blocks of code, and when we choose these names to reflect what the code does, the flow of logic in our program becomes clearer.
- Functions help keep our code DRY. When changes need to be made, we only have to edit one location.
- Functions make our code safer. Does your code create lots of temporary little variables that you need only for intermediate steps? Do you sometimes reuse these variable names later? Functions have well-defined inputs and outputs, and all other variables inside them exist only for the lifetime of the function. This keeps our workspace neater and avoids some subtle bugs.
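A minimal sketch of the scoping point (the function and its names are invented for illustration): a temporary like `total` exists only while the function runs, so it can never collide with a name in the surrounding workspace.

```python
# Hypothetical helper, invented for illustration
def mean_rate(counts, duration_s):
    # 'total' exists only while the function runs
    total = sum(counts)
    return total / duration_s

rate = mean_rate([2, 0, 3, 1], duration_s=2.0)
print(rate)  # 3.0; 'total' is gone from the workspace afterwards
```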
Exercise
1. Pull your spike binning code into a function. The function should take as input a cell array like the `spikes` variable, a minimum time, and a maximum time, and return the same kind of spike count array as before.
2. Make sure your tests still pass on the output of the function.
Solutions#
Show code cell content
# 1. Binning code as a function. Times assumed to be in ms.
def spikes_to_bin(spikes, t_min, t_max):
    if t_max < t_min:
        return np.zeros((0, n_trials, n_units))
    bins = range(t_min, t_max + 1)
    spkbin = np.zeros((t_max - t_min, n_trials, n_units))
    for i, trial in enumerate(spikes):
        for j, unit in enumerate(trial):
            spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
    return spkbin
Show code cell content
# 2. Make sure tests still pass.
def test(spkbin, spikes, t_min, t_max):
    for i in range(n_trials):
        for j in range(n_units):
            spk_trial = spikes[i,j].flatten()
            spks = spk_trial[(t_min <= spk_trial) & (spk_trial <= t_max)]
            if not sum(spkbin[:,i,j]) == spks.size:
                print("Test 1 failed for trial %i, unit %i, t_min: %i, t_max: %i" % (i,j,t_min,t_max))
                return
            # If there is a spike at t_max, we consider it to fall in [t_max-1, t_max], inclusive
            if np.any(spks == t_max):
                spks[spks == t_max] -= 1.
            if not np.all(t_min + spkbin[:,i,j].nonzero()[0] == np.floor(spks)):
                print("Test 2 failed for trial %i, unit %i, t_min: %i, t_max: %i" % (i,j,t_min,t_max))
                return
    print("All tests passed for t_min: %i, t_max: %i" % (t_min,t_max))

for t_min in [0, 100, 250, n_timepoints]:
    for t_max in [0, 100, 250, n_timepoints]:
        if t_max <= t_min:
            continue
        spkbin = spikes_to_bin(spikes, t_min=t_min, t_max=t_max)
        test(spkbin, spikes, t_min=t_min, t_max=t_max)
All tests passed for t_min: 0, t_max: 100
All tests passed for t_min: 0, t_max: 250
All tests passed for t_min: 0, t_max: 550
All tests passed for t_min: 100, t_max: 250
All tests passed for t_min: 100, t_max: 550
All tests passed for t_min: 250, t_max: 550
Refactor: generalizing to different bin widths#
There are plenty of cases where we’d like to get spike counts over larger bins.
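For example (with invented spike times over a 100 ms window), counting in 10 ms bins just means passing coarser bin edges to `np.histogram`:

```python
import numpy as np

spike_times = np.array([2.0, 5.0, 14.0, 38.0, 39.5, 71.0])  # ms, invented
bin_width = 10
edges = np.arange(0, 100 + bin_width, bin_width)  # 0, 10, ..., 100
counts, _ = np.histogram(spike_times, edges)

print(counts)  # the first bin, [0, 10), holds two spikes
```

The total spike count is unchanged by the bin width; only how the counts are grouped changes.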
Exercise

1. Generalize your function to take a bin width as a parameter.
2. Make sure your tests still pass on the output of the function.
Solutions#
Show code cell content
# 1. Function with bin width as a parameter. Times assumed to be in ms.
def spikes_to_bin_varwidth(spikes, t_min=0, t_max=n_timepoints, bin_width=1):
    # Note: assumes (t_max - t_min) is a multiple of bin_width
    if t_max < t_min:
        return np.zeros((0, n_trials, n_units))
    bins = np.arange(t_min, t_max + bin_width, bin_width)
    spkbin = np.zeros(((t_max - t_min) // bin_width, n_trials, n_units))
    for i, trial in enumerate(spikes):
        for j, unit in enumerate(trial):
            spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
    return spkbin
Show code cell content
# 2. Make sure tests still pass.
for t_min in [0, 100, 250, n_timepoints]:
    for t_max in [0, 100, 250, n_timepoints]:
        if t_max <= t_min:
            continue
        spkbin = spikes_to_bin_varwidth(spikes, t_min=t_min, t_max=t_max, bin_width=1)
        test(spkbin, spikes, t_min=t_min, t_max=t_max)
All tests passed for t_min: 0, t_max: 100
All tests passed for t_min: 0, t_max: 250
All tests passed for t_min: 0, t_max: 550
All tests passed for t_min: 100, t_max: 250
All tests passed for t_min: 100, t_max: 550
All tests passed for t_min: 250, t_max: 550
Visualizing spike data#
A common format for displaying raw event data is the raster plot, in which each row represents one trial, time increases along the horizontal axis, and events are indicated by dots or tick marks.
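As an aside, Matplotlib can draw this kind of plot directly from lists of event times via `eventplot`; here is a minimal sketch on fabricated spike times (the trial count and times are invented):

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
# Fabricated spike times: 5 trials, each with 20 random event times in ms
trial_spikes = [np.sort(rng.uniform(0, 500, size=20)) for _ in range(5)]

fig, ax = plt.subplots()
# One row of tick marks per trial, time along the horizontal axis
ax.eventplot(trial_spikes, colors="black", linelengths=0.8)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Trial")
```

Each inner list becomes one row of ticks, so no binning is needed for this version of the plot.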
With our array of spike counts in time, we have all we need to construct the simplest version of this plot. If we think of the first two dimensions of the count array (time and trial) as dimensions of an image, we can plot one image per unit by taking a slice of the count array and plotting the result as pixels.
Exercise
Plot a raster using the data in the count array. Time should be the horizontal axis. (Hint: `matshow` scales the colormap to the range of the data.)
Solutions#
Show code cell content
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']
# 1. Raster plot of spike data
t_min = 0
t_max = n_timepoints
bin_width = 10
spk_bin = spikes_to_bin_varwidth(spikes, bin_width=bin_width)
X, Y = np.mgrid[slice(t_min, t_max + bin_width, bin_width), slice(0,n_trials+1,1)]
vmax = spk_bin.max()
fig, axes = plt.subplots(1,3,figsize=(11,2));
for n, ax in enumerate(axes):
    im = ax.pcolormesh(X, Y, spk_bin[:,:,n], vmin=0, vmax=vmax)
    ax.set_title("Unit " + str(n+1))
    ax.set_xlabel("Time (ms)")
    if n == 0:
        ax.set_ylabel("Trial")
    else:
        ax.set_yticks([])
plt.subplots_adjust(hspace=0.1)
cbar = fig.colorbar(im,ax=axes)
cbar.set_label("Spike count");
Refactor: adding behavior#
We might also like to compare spiking with behavior. We can use the `subplots` command to make multipanel figures that facilitate comparisons between the two. Let’s make a figure with three rows and a single column that plots spiking, horizontal eye position, and vertical eye position as functions of time.
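A minimal sketch of the multipanel pattern, with an invented stand-in trace (using `sharex` is one optional way to keep the time axes aligned; the data and titles here are placeholders):

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(500)         # toy time base in ms
trace = np.sin(t / 50.0)   # invented stand-in for a data panel

# sharex keeps the time axes aligned across the three rows
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(4, 6))
for ax, title in zip(axes, ["Spiking", "Horizontal position", "Vertical position"]):
    ax.plot(t, trace)
    ax.set_title(title)
axes[-1].set_xlabel("Time (ms)")
fig.tight_layout()
```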
Exercise
Make the plot described above. Make sure the axes align. (Hint: you may want to use `xlim` and `ylim` to adjust for alignment and better readability.) Eye traces should all be plotted in the same color. Each panel should be titled.
Solutions#
Show code cell content
# 1. Make a figure with behavior.
fig, axes = plt.subplots(3,1,figsize=(4,9));
im1 = axes[0].pcolormesh(X,Y,spk_bin[:,:,0]);
im2 = axes[1].pcolormesh(H,cmap="jet")
im3 = axes[2].pcolormesh(V,cmap="jet")
plt.colorbar(im1,ax=axes[0])
plt.colorbar(im2,ax=axes[1])
plt.colorbar(im3,ax=axes[2])
axes[0].set_title("Unit 1")
axes[1].set_title("Horizontal position")
axes[2].set_title("Vertical position")
for i in range(3):
    axes[i].set_ylabel("Trials")
axes[2].set_xlabel("Time (ms)")
fig.tight_layout()
Refactor: make it a function#
For plots we make often, it can be useful to repackage code into a function.
Exercise
Create a function to produce this plot. The function should take all needed data as inputs and only handle the plotting and layout.
Solutions#
Show code cell content
# 1. In function form:
def plot_summary(unit, spk_bin, H, V):
    # X and Y are the mesh grids defined above
    fig, axes = plt.subplots(3, 1, figsize=(4, 9))
    im1 = axes[0].pcolormesh(X, Y, spk_bin[:,:,unit])
    im2 = axes[1].pcolormesh(H, cmap="jet")
    im3 = axes[2].pcolormesh(V, cmap="jet")
    plt.colorbar(im1, ax=axes[0])
    plt.colorbar(im2, ax=axes[1])
    plt.colorbar(im3, ax=axes[2])
    # Title reflects the requested unit rather than a hardcoded "Unit 1"
    axes[0].set_title("Unit " + str(unit + 1))
    axes[1].set_title("Horizontal position")
    axes[2].set_title("Vertical position")
    fig.tight_layout()
plot_summary(2, spk_bin, H, V)
Show code cell content
import os
os.remove('fef_spikes.mat')