Homework: refactoring
In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#can install with "pip install wget" if necessary
import wget
Firing rates
In [2]:
# 1. Load the data.
# Download the CSV into the current directory on first run so the notebook
# survives Restart & Run All on a fresh checkout; afterwards the local copy is reused.
DATA_URL = "https://people.duke.edu/~jmp33/quantitative-neurobio/data/week1/roitman_data.csv"
DATA_PATH = "roitman_data.csv"
if not os.path.exists(DATA_PATH):
    wget.download(DATA_URL, out=DATA_PATH)
df = pd.read_csv(DATA_PATH)
# Plain 2-D array view of the table; columns keep the CSV order
data = df.to_numpy()
In [3]:
# Peek at the first five rows to check the column layout
df.head(n=5)
Out[3]:
| | count | time | trial | stimulus | coherence | choice | correct | unit | into_RF |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 1 | 71 | 0 | 1 | 1 | 107 | 1 |
1 | 0 | 0 | 22 | 71 | 0 | 1 | 1 | 107 | 1 |
2 | 0 | 0 | 31 | 71 | 0 | 1 | 1 | 107 | 1 |
3 | 0 | 0 | 46 | 71 | 0 | 1 | 1 | 107 | 1 |
4 | 0 | 0 | 61 | 71 | 0 | 1 | 1 | 107 | 1 |
In [4]:
# 2. Parse the data into one 1-D array per column of interest.
# Column order in `data`: count, time, trial, stimulus, coherence, choice,
# correct, unit, into_RF (see the df.head() output above).
count, timebins, trials, stimulus, coherence, choice = data.T[:6]
unit = data.T[7]
In [5]:
# Number of distinct values in each parsed column
(n_timebins, n_trials, n_stimulus,
 n_coherence, n_choice, n_units) = (
    len(np.unique(column))
    for column in (timebins, trials, stimulus, coherence, choice, unit)
)
n_timebins, n_trials, n_stimulus, n_coherence, n_choice, n_units
Out[5]:
(209, 88, 12, 6, 2, 38)
Let's plot the total row count of the data matrix for each trial number. If each condition had the same number of trials, and trials lasted a fixed duration, then this would be a flat distribution.
In [6]:
# Number of rows in the data matrix for each trial number.
# np.unique(..., return_counts=True) works for any trial labels; np.bincount
# needs non-negative ints and allocates an array up to the largest label.
trial_ids, rows_per_trial = np.unique(trials, return_counts=True)
fig, ax = plt.subplots(figsize=(14, 3))
ax.bar(trial_ids, rows_per_trial)
ax.set(xlabel="Trial number", ylabel="Row count",
       title="Rows in the data matrix per trial number");
Let's construct a 3-D array of spike counts indexed by time, coherence, and choice. We will need to sum over stimuli, trials, and units.
In [7]:
spk_bin = np.zeros((n_timebins, n_coherence, n_choice))
trial_count = np.zeros_like(spk_bin)
# Columns of `data` used here: 0 = spike count, 1 = time, 5 = choice (1 or 2).
# Per-row values get `row_` names so the notebook-level `count` and `choice`
# arrays parsed earlier are NOT overwritten by this cell (re-running the
# unique-count cell afterwards would otherwise give wrong results).
for coh_idx, val in enumerate(np.unique(coherence)):
    print("Processing coherence level %i" % val)
    # All rows of the data matrix that match this coherence level
    rows = data[coherence == val]
    row_count = rows[:, 0]
    # Time -> bin index (5-unit bins); astype(int) truncates like int(time/5)
    t_idx = (rows[:, 1] / 5).astype(int)
    row_choice = rows[:, 5]
    if not np.isin(row_choice, (1, 2)).all():
        raise ValueError("Unknown choice value")
    choice_idx = (row_choice - 1).astype(int)  # choice 1 -> 0, choice 2 -> 1
    # np.add.at accumulates every row even when (time, choice) indices repeat;
    # plain fancy-index `+=` would keep only one update per repeated index.
    np.add.at(spk_bin, (t_idx, coh_idx, choice_idx), row_count)
    np.add.at(trial_count, (t_idx, coh_idx, choice_idx), 1)
Processing coherence level 0 Processing coherence level 32 Processing coherence level 64 Processing coherence level 128 Processing coherence level 256 Processing coherence level 512
We now need to convert the summed counts to an average. We'll divide by the trial count, and then by the original time-bin width (5 ms) to get a firing rate in spikes per second.
In [8]:
# Mean spike count per row in each (time, coherence, choice) cell, divided by
# the 5 ms bin width to give spikes/s. Cells with no rows stay NaN (what 0/0
# produced before) but no divide-by-zero RuntimeWarning is emitted.
BIN_WIDTH_S = 5e-3
mean_count = np.divide(
    spk_bin,
    trial_count,
    out=np.full_like(spk_bin, np.nan),
    where=trial_count > 0,
)
spk_bin_avg = mean_count / BIN_WIDTH_S
Let's plot the unfiltered average firing rate for coherence level 0 and choice 1.
In [9]:
# Unfiltered average rate for coherence index 0 and choice index 0
# (coherence value 0, choice value 1). Bins are 5 ms wide.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(np.arange(n_timebins) * 5, spk_bin_avg[:, 0, 0])
# Trailing `;` suppresses the return-value repr that used to appear as Out[9]
ax.set(xlabel="Time (ms)", ylabel="Rate (sp/s)",
       title="Unfiltered rate: coherence 0, choice 1");
Out[9]:
Text(0.5, 0, 'Time')
Let's smooth the average firing rates with the boxcar filter covered in class.
In [10]:
def rate_estimate(spikes, w):
    """Smooth every time series in `spikes` along its first (time) axis.

    Parameters
    ----------
    spikes : array_like, shape (n_time, ...)
        Values to filter. Every index into the trailing axes (e.g. coherence,
        choice) selects one independent time series. 1-D input is allowed.
    w : 1-D array_like
        Filter kernel, e.g. ``np.ones(k) / k`` for a boxcar of width k.
        Must not be longer than the time axis.

    Returns
    -------
    ndarray, same shape as `spikes`
        ``np.convolve(w, series, mode='same')`` for each series: the output is
        centred on the input and the edges are implicitly zero-padded.
        Integer input is promoted to float64 rather than truncated.

    Raises
    ------
    ValueError
        If `w` is longer than the time axis of `spikes`.
    """
    spikes = np.asarray(spikes)
    w = np.asarray(w)
    if w.size > spikes.shape[0]:
        raise ValueError("filter kernel is longer than the time axis")
    # Keep float dtypes as-is; promote ints so convolved values aren't truncated
    out_dtype = spikes.dtype if np.issubdtype(spikes.dtype, np.inexact) else np.float64
    estimate = np.zeros(spikes.shape, dtype=out_dtype)
    # Loop over the trailing dims of the array itself rather than the
    # notebook-level n_coherence / n_choice globals.
    for idx in np.ndindex(spikes.shape[1:]):
        series_idx = (slice(None),) + idx
        estimate[series_idx] = np.convolve(w, spikes[series_idx], mode='same')
    return estimate
In [11]:
# Boxcar kernel of `filter_width` bins (20 bins x 5 ms = 100 ms) with unit area
filter_width = 20  # in the unit of bins
r = rate_estimate(spk_bin_avg, np.full(filter_width, 1.0 / filter_width))
Finally, plot our reproduction of figure 9A.
In [12]:
# Smoothed firing rate vs. time, one colour per coherence level;
# solid lines = choice index 0, dashed = choice index 1.
fig, ax = plt.subplots(figsize=(3, 5))
time_ms = np.arange(n_timebins) * 5  # bins are 5 ms wide
# Colour per coherence index, from lowest coherence (index 0) upwards
colors = ["black", "green", "red", "blue", "purple", "orange"]
handles = []
# Plot the strongest motion first so the legend runs from high to low coherence
for coh_idx in reversed(range(n_coherence)):
    for choice_idx in range(n_choice):
        linestyle = "solid" if choice_idx == 0 else "dashed"
        (line,) = ax.plot(
            time_ms,
            r[:, coh_idx, choice_idx],
            color=colors[coh_idx],
            linestyle=linestyle,
            alpha=0.75)
        # One legend entry per coherence level (the solid line)
        if choice_idx == 0:
            handles.append(line)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Firing rate (sp/s)")
ax.set_ylim(20, 60)
# The boxcar smoothing zero-pads the edges, so the earliest bins are biased
# low; cut off the first 60 ms of the data.
ax.set_xlim(60, 1000)
# Coherence values are stored in tenths of a percent (e.g. 512 -> 51.2 %);
# labels follow the same high-to-low order as `handles`.
labels = ["%g" % (level / 10) for level in np.unique(coherence)[::-1]]
ax.legend(
    handles,
    labels,
    title="Motion strength (% coh)",
    bbox_to_anchor=(1.05, 1.02));
In [ ]: