Exploring neural spike data#

In this section, we’ll explore single-unit electrophysiology data recorded by Tim Darlington in the Lisberger Lab. The data are recordings from three units in macaque frontal eye field (FEF), stored in MATLAB’s .mat format, and consist of three variables:

  • spikes: a cell array in which each cell contains an array of spike timestamps (trials x units, measured in milliseconds)

  • H and V: horizontal and vertical eye positions (trials x time points, sampled at 1000 samples per second)
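As a rough sketch of what to expect after loading: `scipy.io.loadmat` returns MATLAB cell arrays as NumPy arrays with `dtype=object`, where each element is itself a numeric array. The toy array below imitates that layout with made-up spike times (two trials, three units). The name `toy_spikes`, the row-vector shape of each cell, and all values are illustrative assumptions, not the real recording.

```python
import numpy as np

# Hypothetical stand-in for the loaded `spikes` variable: 2 trials x 3 units.
toy_spikes = np.empty((2, 3), dtype=object)
made_up_times = [
    [[12.5, 40.1], [3.0], []],
    [[7.7], [], [100.2, 101.9, 250.0]],
]
for i, trial in enumerate(made_up_times):
    for j, times in enumerate(trial):
        # Assumption: each cell holds a 2-D row vector of spike times (ms)
        toy_spikes[i, j] = np.array(times, dtype=float).reshape(1, len(times))

n_toy_trials, n_toy_units = toy_spikes.shape
first_cell = toy_spikes[0, 0].flatten()  # 1-D array of spike times
print(n_toy_trials, n_toy_units, first_cell)
```

Flattening each cell before use, as done here, sidesteps the question of whether a given cell arrived as a row or a column vector.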

Loading the data#

Exercise

  1. Download the data from Google Drive. In Colab, you can do this by running

    !gdown --id 1bsqBXmV0SrMwycyRR0jY5gyw7wqjYQI_
    

    where the long string is the file’s Google Drive identifier.

  2. Load the data. You will want to import scipy.io and use loadmat to load the data. The result is a dictionary whose keys are the variable names.

  3. Write code that extracts the number of units, trials, and timepoints from the data.

Solutions#

# 1. Download the data

!gdown --id 1bsqBXmV0SrMwycyRR0jY5gyw7wqjYQI_
# 2. Load the data

import numpy as np
import scipy.io

mat_dict = scipy.io.loadmat("fef_spikes.mat")
spikes = mat_dict["spikes"]
H,V = mat_dict["H"], mat_dict["V"]
# 3. Write code that extracts the number of units, trials,
#    and timepoints from the data.

n_trials = spikes.shape[0]
n_units = spikes.shape[1]
n_timepoints = H.shape[1]
print(f"Number of trials: {n_trials}\nNumber of units: {n_units}\nNumber of time points: {n_timepoints}")
Number of trials: 52
Number of units: 3
Number of time points: 550

Converting from events to time series#

As we learned in the previous section, we often need to convert from a list of event times to a time series representation of the same data. For instance, rather than listing all times at which a spike occurred, we might want to know how many spikes occurred at each moment in time. That is, we’d like to make a histogram of spike counts using some small bin size.

To perform the conversion, we’ll start by taking only a small subset of the data, spikes for a single unit on a single trial, and converting that list to a time series.
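To make the idea concrete before touching the real data, here is a minimal sketch using `np.histogram` on made-up spike times. The values in `toy_times` and the 10 ms trial length are assumptions chosen only for illustration.

```python
import numpy as np

# Made-up spike times in ms for one unit on one trial (illustrative only)
toy_times = np.array([0.5, 2.2, 2.9, 7.0])
trial_length_ms = 10  # assumed trial duration for this toy example

# Bin edges at 0, 1, ..., 10 give ten 1 ms bins
edges = np.arange(trial_length_ms + 1)
counts, _ = np.histogram(toy_times, bins=edges)
print(counts)  # [1 0 2 0 0 0 0 1 0 0]
```

Each entry of `counts` is the number of spikes whose timestamp falls in that millisecond, so the two spikes at 2.2 and 2.9 ms land in the same bin.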

Exercise

  1. Extract the spike times for unit 1 on trial 1.

  2. Get a count of spikes during each millisecond of the trial. Note that spike timestamps are recorded in milliseconds.

  3. Write a test or tests to check whether you got the right answer. Each test should be a line of code that evaluates to true or false based on the raw data and time series. (Hint: tests often check mathematical properties that should be true both before and after some operation is performed.)

Solutions#

# 1. Extract the spike times for unit 1 on trial 1

spks = spikes[0,0].flatten()
# 2. Get spike counts

hist, bin_edges = np.histogram(spks, bins=range(n_timepoints+1))
# 3. Write tests

# sum over histogram = # total spikes
print(sum(hist) == spks.size)

# the locations of the nonzero histogram bins are 
# the timestamps of the spikes (rounded down)
print(np.all(hist.nonzero()[0] == np.floor(spks)))
True
True
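The second check above compares bin indices to spike times element by element, which implicitly assumes that the spike times are sorted and that no two spikes share a bin. A hedged variant that drops the second assumption compares against the unique floored times instead. The spike times here are made up for illustration:

```python
import numpy as np

toy_times = np.array([1.2, 1.8, 4.0, 9.5])  # two spikes share the 1 ms bin
edges = np.arange(11)                        # ten 1 ms bins covering 0-10 ms
counts, _ = np.histogram(toy_times, bins=edges)

# Invariant 1: binning conserves the total number of spikes
total_ok = counts.sum() == toy_times.size

# Invariant 2: occupied bins are exactly the distinct floored spike times
occupied = np.flatnonzero(counts)
expected = np.unique(np.floor(toy_times)).astype(int)
bins_ok = np.array_equal(occupied, expected)
print(total_ok, bins_ok)  # True True
```

One caveat: `np.histogram` treats the last bin as closed on the right, so a spike exactly at the final edge is counted in the last bin even though its floored time is one larger.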

Refactor: multiple trials and units#

Now that you have code that works on a single list, you can generalize to code that works on multiple trials and multiple units.

Exercise

  1. Generalize your code to do this conversion over multiple trials and units. The result should be a time points x trials x units array.

  2. Extend your test code to perform the same sets of checks on the array created by this new code.

Solutions#

# 1. Generalize over multiple trials and units. 

bins = range(n_timepoints+1)
spkbin = np.zeros((n_timepoints, n_trials, n_units))
for i, trial in enumerate(spikes):
    for j, unit in enumerate(trial):
        spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
# 2. Perform the same sets of checks on the array. 

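all_passed = True  # only report success if no check fails
for i in range(n_trials):
    for j in range(n_units):
        spks = spikes[i,j].flatten()
        if not sum(spkbin[:,i,j]) == spks.size:
            print("Test 1 failed")
            all_passed = False
        if not np.all(spkbin[:,i,j].nonzero()[0] == np.floor(spks)):
            print("Test 2 failed")
            all_passed = False
if all_passed:
    print("All tests passed")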
All tests passed

Refactor: make it a function#

Once we have a pattern in code we know we’ll be repeating, it’s helpful to pull it into a function for several reasons:

  • Functions make our code more readable. Functions give names to blocks of code, and when we choose these names to reflect what the code does, the flow of logic in our program becomes clearer.

  • Functions help keep our code DRY (Don’t Repeat Yourself). When changes need to be made, we only have to edit one location.

  • Functions make our code safer. Does your code create lots of little temporary variables that you need only for intermediate steps? Do you sometimes reuse these variable names later? Functions have well-defined inputs and outputs, but all other variables inside exist only for the lifetime of the function. This keeps our workspace neater and avoids some subtle bugs.
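As a small illustration of the last point, variables assigned inside a function do not leak into the calling workspace. The function name `mean_rate_hz` and its arguments are hypothetical, chosen only for this sketch:

```python
def mean_rate_hz(spike_count, duration_ms):
    # `duration_s` is a temporary that exists only while the function runs
    duration_s = duration_ms / 1000.0
    return spike_count / duration_s

rate = mean_rate_hz(spike_count=5, duration_ms=500)
print(rate)                       # 10.0
print("duration_s" in globals())  # False: the temporary never reached the workspace
```

Had the same two lines been written at the top level of a notebook, `duration_s` would linger and could silently be picked up by a later cell.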

Exercise

  1. Pull your spike binning code into a function. That function should take as input a cell array like the spikes variable, a minimum time, and a maximum time, and return the same kind of spike count array as before.

  2. Make sure your tests still pass on the output of the function.

Solutions#

# 1. Binning code as a function. Times assumed to be in ms. 

def spikes_to_bin(spikes, t_min, t_max):
    # Infer dimensions from the input instead of relying on global variables
    n_trials, n_units = spikes.shape
    if t_max < t_min:
        return np.zeros((0,n_trials,n_units))
    bins = range(t_min,t_max+1)
    spkbin = np.zeros((t_max-t_min, n_trials, n_units))
    for i, trial in enumerate(spikes):
        for j, unit in enumerate(trial):
            spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
    return spkbin
# 2. Make sure tests still pass.

def test(spkbin, spikes, t_min, t_max):
    for i in range(n_trials):
        for j in range(n_units):
            spk_trial = spikes[i,j].flatten()
            spks = spk_trial[(t_min <= spk_trial) & (spk_trial <= t_max)]
            if not sum(spkbin[:,i,j]) == spks.size:
                print("Test 1 failed for trial %i, unit %i, t_min: %i, t_max: %i" % (i,j,t_min,t_max))
                return
            # If there is a spike at t_max, we consider it to fall between [t_max-1,t_max], inclusive
            if np.any(spks == t_max):
                spks[spks == t_max] -= 1.
            if not np.all(t_min+spkbin[:,i,j].nonzero()[0] == np.floor(spks)):
                print("Test 2 failed for trial %i, unit %i, t_min: %i, t_max: %i" % (i,j,t_min,t_max))
                return
    print("All tests passed for t_min: %i, t_max: %i" % (t_min,t_max))

for t_min in [0,100,250,n_timepoints]:
    for t_max in [0,100,250,n_timepoints]:
        if t_max <= t_min: continue
        spkbin = spikes_to_bin(spikes, t_min=t_min, t_max=t_max)
        test(spkbin, spikes, t_min=t_min, t_max=t_max)        
All tests passed for t_min: 0, t_max: 100
All tests passed for t_min: 0, t_max: 250
All tests passed for t_min: 0, t_max: 550
All tests passed for t_min: 100, t_max: 250
All tests passed for t_min: 100, t_max: 550
All tests passed for t_min: 250, t_max: 550

Refactor: generalizing to different bin widths#

There are plenty of cases where we’d like to get spike counts over larger bins.
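One way to see what wider bins mean: when the bin width divides the trial length evenly, counting spikes in 10 ms bins gives the same result as summing consecutive groups of ten 1 ms bins. A sketch with made-up spike times and an assumed 50 ms trial:

```python
import numpy as np

toy_times = np.array([0.5, 3.1, 12.0, 19.9, 20.0, 47.3])  # illustrative, in ms
t_max, bin_width = 50, 10                                  # assumed toy values

# Direct approach: histogram with 10 ms bin edges (0, 10, ..., 50)
wide_counts, _ = np.histogram(toy_times, bins=np.arange(0, t_max + 1, bin_width))

# Indirect approach: 1 ms bins, then sum each block of `bin_width` bins
fine_counts, _ = np.histogram(toy_times, bins=np.arange(0, t_max + 1))
summed = fine_counts.reshape(-1, bin_width).sum(axis=1)

print(wide_counts)  # [2 2 1 0 1]
print(summed)       # [2 2 1 0 1]
```

The agreement between the two results is itself a useful test for any function that accepts a bin width.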

Exercise

  1. Generalize your function to take a bin width as a parameter.

  2. Make sure your tests still pass on the output of the function.

Solutions#

# 1. Function with bin width as a parameter. Times assumed to be in ms.

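def spikes_to_bin_varwidth(spikes, t_min=0, t_max=n_timepoints, bin_width=1):
    # Infer dimensions from the input instead of relying on global variables
    n_trials, n_units = spikes.shape
    if t_max < t_min:
        return np.zeros((0,n_trials,n_units))
    # Use whole bins only, so the number of edges always matches the output
    # array; spikes in a trailing partial bin are dropped.
    n_bins = (t_max-t_min)//bin_width
    bins = t_min + bin_width*np.arange(n_bins+1)
    spkbin = np.zeros((n_bins, n_trials, n_units))
    for i, trial in enumerate(spikes):
        for j, unit in enumerate(trial):
            spkbin[:,i,j] = np.histogram(unit.flatten(), bins)[0]
    return spkbin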
# 2. Make sure tests still pass.

for t_min in [0,100,250,n_timepoints]:
    for t_max in [0,100,250,n_timepoints]:
        if t_max <= t_min: continue
        spkbin = spikes_to_bin_varwidth(spikes, t_min=t_min, t_max=t_max, bin_width=1)
        test(spkbin, spikes, t_min=t_min, t_max=t_max)  
All tests passed for t_min: 0, t_max: 100
All tests passed for t_min: 0, t_max: 250
All tests passed for t_min: 0, t_max: 550
All tests passed for t_min: 100, t_max: 250
All tests passed for t_min: 100, t_max: 550
All tests passed for t_min: 250, t_max: 550

Visualizing spike data#

A common format for displaying raw event data is the raster plot, in which each row represents one trial, time increases along the horizontal axis, and events are indicated by dots or tick marks.

With our array of spike counts in time, we have all we need to construct the simplest version of this plot. If we think of the first two dimensions of the count array (time and trial) as the dimensions of an image, we can plot one image per unit by taking a slice of the count array and plotting the result as pixels.
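When drawing the counts as an image with a function such as matplotlib’s `pcolormesh`, one option is to pass grids of bin edges, which have one more entry than the data along each axis. A minimal sketch of building such grids with `np.mgrid`, using assumed toy dimensions (50 ms of data, 10 ms bins, 4 trials) and hypothetical variable names:

```python
import numpy as np

t_min, t_max, bin_width, n_toy_trials = 0, 50, 10, 4  # toy values

# Edges in time (0, 10, ..., 50) and in trial number (0, 1, ..., 4)
time_edges, trial_edges = np.mgrid[t_min:t_max + bin_width:bin_width,
                                   0:n_toy_trials + 1]

# One unit's slice of a (time bins x trials x units) count array has this shape
toy_slice = np.zeros(((t_max - t_min) // bin_width, n_toy_trials))
print(time_edges.shape, toy_slice.shape)  # (6, 5) (5, 4)
```

Because the edge grids carry real time values, the horizontal axis of the resulting image is labeled in milliseconds rather than in bin indices.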

Exercise

  1. Make a raster plot using the data in the count array. Time should be the horizontal axis. (Hint: matshow scales the colormap to the range of data.)

Solutions#

import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']

# 1. Raster plot of spike data 
t_min = 0
t_max = n_timepoints
bin_width = 10
spk_bin = spikes_to_bin_varwidth(spikes, bin_width=bin_width)
X, Y = np.mgrid[slice(t_min, t_max + bin_width, bin_width), slice(0,n_trials+1,1)]
vmax = spk_bin.max()
fig, axes = plt.subplots(1,3,figsize=(11,2));
for n, ax in enumerate(axes):
    im = ax.pcolormesh(X, Y, spk_bin[:,:,n], vmin=0, vmax=vmax)
    ax.set_title("Unit " + str(n+1))
    ax.set_xlabel("Time (ms)")
    if n == 0: 
        ax.set_ylabel("Trial")
    else:
        ax.set_yticks([])
plt.subplots_adjust(wspace=0.1)  # panels sit side by side, so adjust horizontal spacing
cbar = fig.colorbar(im,ax=axes)
cbar.set_label("Spike count");
[Figure: spike-count rasters for Units 1, 2, and 3. Horizontal axis: time (ms); vertical axis: trial; color: spike count.]

Refactor: adding behavior#

We might also like to compare spiking with behavior. We can use the subplots command to make multipanel figures that facilitate such comparisons. Let’s make a figure with three rows and a single column that plots spiking, horizontal eye position, and vertical eye position as functions of time.
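One possible way to keep the time axes aligned is the `sharex=True` option of `plt.subplots`. The sketch below uses randomly generated stand-ins for the spike counts and eye traces (all names prefixed with `toy_` are hypothetical) and draws the traces as lines, which is only one of several reasonable choices:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
toy_counts = rng.poisson(0.2, size=(5, 100))      # trials x time, made up
toy_H = rng.normal(size=(5, 100)).cumsum(axis=1)  # made-up eye traces
toy_V = rng.normal(size=(5, 100)).cumsum(axis=1)

# sharex=True ties the time axes of all three panels together
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(4, 6))
axes[0].pcolormesh(toy_counts)
axes[1].plot(toy_H.T, color="k")  # one line per trial, all in the same color
axes[2].plot(toy_V.T, color="k")
titles = ["Spikes", "Horizontal position", "Vertical position"]
for ax, title in zip(axes, titles):
    ax.set_title(title)
axes[2].set_xlabel("Time (ms)")
axes[0].set_xlim(0, 100)          # applies to every panel because x is shared
fig.tight_layout()
```

With a shared axis, a single call to `set_xlim` updates all three panels, so the panels cannot drift out of alignment.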

Exercise

  1. Make the plot described above. Make sure the axes align. (Hint: you may want to use xlim and ylim to adjust for alignment and better readability.) Eye traces should all be plotted in the same color. Each panel should be titled.

Solutions#

# 1. Make a figure with behavior.

fig, axes = plt.subplots(3,1,figsize=(4,9));
im1 = axes[0].pcolormesh(X,Y,spk_bin[:,:,0]);
im2 = axes[1].pcolormesh(H,cmap="jet")
im3 = axes[2].pcolormesh(V,cmap="jet")
plt.colorbar(im1,ax=axes[0])
plt.colorbar(im2,ax=axes[1])
plt.colorbar(im3,ax=axes[2])
axes[0].set_title("Unit 1")
axes[1].set_title("Horizontal position")
axes[2].set_title("Vertical position")
for i in range(3):
    axes[i].set_ylabel("Trials")
axes[2].set_xlabel("Time (ms)")
fig.tight_layout()
[Figure: three stacked panels titled "Unit 1", "Horizontal position", and "Vertical position". Horizontal axis: time (ms); vertical axis: trials.]

Refactor: make it a function#

For plots we make often, it can be useful to repackage code into a function.

Exercise

  1. Create a function to produce this plot. The function should take all needed data as inputs and only handle the plotting and layout.

Solutions#

# 1. In function form:

def plot_summary(unit, spk_bin, H, V):
    # Note: X and Y are the bin-edge grids defined earlier with np.mgrid
    fig, axes = plt.subplots(3,1,figsize=(4,9))
    im1 = axes[0].pcolormesh(X, Y, spk_bin[:,:,unit])
    im2 = axes[1].pcolormesh(H,cmap="jet")
    im3 = axes[2].pcolormesh(V,cmap="jet")
    plt.colorbar(im1,ax=axes[0])
    plt.colorbar(im2,ax=axes[1])
    plt.colorbar(im3,ax=axes[2])
    # unit is a zero-based index, so label the panel with unit + 1
    axes[0].set_title(f"Unit {unit + 1}")
    axes[1].set_title("Horizontal position")
    axes[2].set_title("Vertical position")
    fig.tight_layout()

plot_summary(2, spk_bin, H, V)
[Figure: output of plot_summary, with stacked panels for the selected unit's spike counts, horizontal position, and vertical position.]
import os
os.remove('fef_spikes.mat')