Source code for brainbox.quality.permutation_test

"""
Quality control for arbitrary metrics, using permutation testing.

Written by Sebastian Bruijns
"""

import numpy as np
import time
import matplotlib.pyplot as plt
# TODO: take in eids and download data yourself?


def permut_test(data1, data2, metric, n_permut=1000, show=False, title=None):
    """
    Compute the probability of observing the metric difference between two
    datasets, via permutation testing.

    The metric is applied to every element of both datasets, and the absolute
    difference between the two group means of these values is the test
    statistic. Absolute values are used because the order of the dataset
    inputs shouldn't matter.

    Parameters
    ----------
    data1 : array-like
        First data set, list or array of data-entities to use for permutation
        test. Always pass a list/array of entities, even if it holds just one.
    data2 : array-like
        Second data set, also list or array of data-entities to use for
        permutation test
    metric : function, array-like -> float
        Metric to use for permutation test, will be used to reduce elements of
        data1 and data2 to one number
    n_permut : integer (optional)
        Number of permutations to use for test
    show : Boolean (optional)
        Whether or not to plot the permutation distribution and a marker for
        the position of the true difference in relation to this distribution
    title : str (optional)
        If given, the plot is produced as well and `title` is forwarded to
        plot_permut_test

    Returns
    -------
    p : float
        Fraction of permutations whose absolute difference of group means is
        strictly larger than the true difference

    Notes
    -----
    Only group means are compared; applying a more complicated function to the
    permutation result is not supported.
    Permutations are drawn with the global NumPy random state
    (``np.random.shuffle``), so seed it with ``np.random.seed`` for
    reproducible results.
    """
    # Reduce every data-entity to a single number, then take the observed difference
    metrics1 = [metric(d) for d in data1]
    metrics2 = [metric(d) for d in data2]
    true_diff = np.abs(np.mean(metrics1) - np.mean(metrics2))

    # Pool the metric values; each permutation re-splits the pool into two groups
    # of the original sizes
    size1 = len(metrics1)
    pooled = np.concatenate((metrics1, metrics2))
    permutations = np.zeros((n_permut, pooled.size), dtype=np.int32)

    # The same index array is reshuffled in place for every row
    indices = np.arange(pooled.size)
    for i in range(n_permut):
        np.random.shuffle(indices)
        permutations[i] = indices

    permut_diffs = np.abs(np.mean(pooled[permutations[:, :size1]], axis=1) -
                          np.mean(pooled[permutations[:, size1:]], axis=1))

    # Strict comparison: permutations that merely tie with the true difference
    # are not counted as more extreme
    p = np.count_nonzero(permut_diffs > true_diff) / n_permut

    if show or title:
        plot_permut_test(permut_diffs=permut_diffs, true_diff=true_diff, p=p, title=title)

    return p


def plot_permut_test(permut_diffs, true_diff, p, title=None):
    """
    Plot permutation test result.

    Draws a histogram of the permutation differences on the current pyplot
    axes and marks the true difference with a red star. The p-value is used as
    the axes title. If `title` is given, the figure is saved as
    ``title + '.png'`` and then closed.
    """
    counts, _, _ = plt.hist(permut_diffs)
    # Place the marker at 1/20 of the tallest bar so it sits just above the x-axis
    marker_height = np.max(counts) / 20
    plt.plot(true_diff, marker_height, '*r', markersize=12)

    # Prettify plot: drop the top and right frame lines
    for side in ('top', 'right'):
        plt.gca().spines[side].set_visible(False)
    plt.title("p = {}".format(p))

    if not title:
        return
    plt.savefig(title + '.png')
    plt.close()


if __name__ == '__main__':
    # Timing demo: two synthetic normal datasets with slightly different means
    rng = np.random.RandomState(2)
    data1 = rng.normal(0, 1, (23, 5))
    data2 = rng.normal(0.1, 1, (32, 5))
    t = time.time()
    # permut_test has no `plot` keyword (passing it raises TypeError);
    # the plotting switch is called `show`
    p = permut_test(data1, data2, np.mean, show=True)
    print(time.time() - t)
    print(p)