Source code for brainbox.io.spikeglx
import shutil
import logging
from pathlib import Path
import time
import json
import string
import random
import numpy as np
from one.alf.files import remove_uuid_string
import spikeglx
_logger = logging.getLogger('ibllib')
[docs]
def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
"""
Extracts spike waveforms from binary ephys data file, after (optionally)
common-average-referencing (CAR) spatial noise.
Parameters
----------
ephys_file : string
The file path to the binary ephys data.
ts : ndarray_like
The timestamps (in s) of the spikes.
ch : ndarray_like
The channels on which to extract the waveforms.
t : numeric (optional)
The time (in ms) of each returned waveform.
sr : int (optional)
The sampling rate (in hz) that the ephys data was acquired at.
n_ch_probe : int (optional)
The number of channels of the recording.
car: bool (optional)
A flag to perform CAR before extracting waveforms.
Returns
-------
waveforms : ndarray
An array of shape (#spikes, #samples, #channels) containing the waveforms.
Examples
--------
1) Extract all the waveforms for unit1 with and without CAR.
>>> import numpy as np
>>> import brainbox as bb
>>> import one.alf.io as alfio
>>> import ibllib.ephys.spikes as e_spks
(*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Get a clusters bunch and a units bunch from a spikes bunch from an alf directory.
>>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> units_b = bb.processing.get_units_bunch(spks, ['times'])
# Get the timestamps and 20 channels around the max amp channel for unit1, and extract the
# two sets of waveforms.
>>> ts = units_b['times']['1']
>>> max_ch = max_ch = clstrs_b['channels'][1]
>>> if max_ch < 10: # take only channels greater than `max_ch`.
>>> ch = np.arange(max_ch, max_ch + 20)
>>> elif (max_ch + 10) > 385: # take only channels less than `max_ch`.
>>> ch = np.arange(max_ch - 20, max_ch)
>>> else: # take `n_c_ch` around `max_ch`.
>>> ch = np.arange(max_ch - 10, max_ch + 10)
>>> wf = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=False)
>>> wf_car = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=True)
"""
# Get memmapped array of `ephys_file`
with spikeglx.Reader(ephys_file) as s_reader:
file_m = s_reader.data # the memmapped array
n_wf_samples = int(sr / 1000 * (t / 2)) # number of samples to return on each side of a ts
ts_samples = np.array(ts * sr).astype(int) # the samples corresponding to `ts`
t_sample_first = ts_samples[0] - n_wf_samples
# Exception handling for impossible channels
ch = np.asarray(ch)
ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
if np.any(ch < 0) or np.any(ch > n_ch_probe):
raise Exception('At least one specified channel number is impossible. '
f'The minimum channel number was {np.min(ch)}, '
f'and the maximum channel number was {np.max(ch)}. '
'Check specified channel numbers and try again.')
if car: # compute spatial noise in chunks
# see https://github.com/int-brain-lab/iblenv/issues/5
raise NotImplementedError("CAR option is not available")
# Initialize `waveforms`, extract waveforms from `file_m`, and CAR.
waveforms = np.zeros((len(ts), 2 * n_wf_samples, ch.size))
# Give time estimate for extracting waveforms.
t0 = time.perf_counter()
for i in range(5):
waveforms[i, :, :] = \
file_m[i * n_wf_samples * 2 + t_sample_first:
i * n_wf_samples * 2 + t_sample_first + n_wf_samples * 2, ch].reshape(
(n_wf_samples * 2, ch.size))
dt = time.perf_counter() - t0
print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
.format(dt * len(ts) / 60 / 5, time.ctime()))
for spk, _ in enumerate(ts): # extract waveforms
spk_ts_sample = ts_samples[spk]
spk_samples = np.arange(spk_ts_sample - n_wf_samples, spk_ts_sample + n_wf_samples)
# have to reshape to add an axis to broadcast `file_m` into `waveforms`
waveforms[spk, :, :] = \
file_m[spk_samples[0]:spk_samples[-1] + 1, ch].reshape((spk_samples.size, ch.size))
print('Done. ({})'.format(time.ctime()))
return waveforms
[docs]
class Streamer(spikeglx.Reader):
"""
pid = 'e31b4e39-e350-47a9-aca4-72496d99ff2a'
one = ONE()
sr = Streamer(pid=pid, one=one)
raw_voltage = sr[int(t0 * sr.fs):int((t0 + nsecs) * sr.fs), :]
"""
def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
self.target_dir = None # last chunk directory download or read
self.one = one
self.pid = pid
self.cache_folder = cache_folder or Path(self.one.alyx._par.CACHE_DIR).joinpath('cache', typ)
self.remove_cached = remove_cached
self.eid, self.pname = self.one.pid2eid(pid)
self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
self.url_cbin = self.one.record2url(cbin_rec)[0]
with open(self.file_chunks, 'r') as f:
self.chunks = json.load(f)
self.chunks['chunk_bounds'] = np.array(self.chunks['chunk_bounds'])
super(Streamer, self).__init__(meta_file, ignore_warnings=True)
[docs]
def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True, volts=True):
"""
overload the read function by downloading the necessary chunks
"""
first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start) - 1)
last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop) - 1)
n0 = self.chunks['chunk_bounds'][first_chunk]
_logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
self.cache_folder.mkdir(exist_ok=True, parents=True)
sr, file_cbin = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
if not volts:
data = np.copy(sr._raw[nsel.start - n0:nsel.stop - n0, csel])
else:
data = sr[nsel.start - n0: nsel.stop - n0, csel]
sr.close()
if self.remove_cached:
shutil.rmtree(self.target_dir)
return data
def _download_raw_partial(self, first_chunk=0, last_chunk=0):
"""
downloads one or several chunks of a mtscomp file and copy ch files and metadata to return
a spikeglx.Reader object
:param first_chunk:
:param last_chunk:
:return: spikeglx.Reader of the current chunk, Pathlib.Path of the directory where it is stored
:return: cbin local path
"""
assert str(self.url_cbin).endswith('.cbin')
webclient = self.one.alyx
relpath = Path(self.url_cbin.replace(webclient._par.HTTP_DATA_SERVER, '.')).parents[0]
# write the temp file into a subdirectory
tdir_chunk = f"chunk_{str(first_chunk).zfill(6)}_to_{str(last_chunk).zfill(6)}"
# for parallel processes, there is a risk of collisions if the removed cached flag is set to True
# if the folder is to be removed append a unique identifier to avoid having duplicate names
if self.remove_cached:
tdir_chunk += ''.join([random.choice(string.ascii_letters) for _ in np.arange(10)])
self.target_dir = Path(self.cache_folder, relpath, tdir_chunk)
Path(self.target_dir).mkdir(parents=True, exist_ok=True)
ch_file_stream = self.target_dir.joinpath(self.file_chunks.name).with_suffix('.stream.ch')
# Get the first sample index, and the number of samples to download.
i0 = self.chunks['chunk_bounds'][first_chunk]
ns_stream = self.chunks['chunk_bounds'][last_chunk + 1] - i0
total_samples = self.chunks['chunk_bounds'][-1]
# handles the meta file
meta_local_path = ch_file_stream.with_suffix('.meta')
if not meta_local_path.exists():
shutil.copy(self.file_chunks.with_suffix('.meta'), meta_local_path)
# if the cached version happens to be the same as the one on disk, just load it
if ch_file_stream.exists() and ch_file_stream.with_suffix('.cbin').exists():
with open(ch_file_stream, 'r') as f:
cmeta_stream = json.load(f)
if (cmeta_stream.get('chopped_first_sample', None) == i0 and
cmeta_stream.get('chopped_total_samples', None) == total_samples):
return spikeglx.Reader(ch_file_stream.with_suffix('.cbin'), ignore_warnings=True), ch_file_stream
else:
shutil.copy(self.file_chunks, ch_file_stream)
assert ch_file_stream.exists()
cmeta = self.chunks.copy()
# prepare the metadata file
cmeta['chunk_bounds'] = cmeta['chunk_bounds'][first_chunk:last_chunk + 2]
cmeta['chunk_bounds'] = [int(_ - i0) for _ in cmeta['chunk_bounds']]
assert len(cmeta['chunk_bounds']) >= 2
assert cmeta['chunk_bounds'][0] == 0
first_byte = cmeta['chunk_offsets'][first_chunk]
cmeta['chunk_offsets'] = cmeta['chunk_offsets'][first_chunk:last_chunk + 2]
cmeta['chunk_offsets'] = [_ - first_byte for _ in cmeta['chunk_offsets']]
assert len(cmeta['chunk_offsets']) >= 2
assert cmeta['chunk_offsets'][0] == 0
n_bytes = cmeta['chunk_offsets'][-1]
assert n_bytes > 0
# Save the chopped chunk bounds and offsets.
cmeta['sha1_compressed'] = None
cmeta['sha1_uncompressed'] = None
cmeta['chopped'] = True
cmeta['chopped_first_sample'] = int(i0)
cmeta['chopped_samples'] = int(ns_stream)
cmeta['chopped_total_samples'] = int(total_samples)
with open(ch_file_stream, 'w') as f:
json.dump(cmeta, f, indent=2, sort_keys=True)
# Download the requested chunks
retries = 0
while True:
try:
cbin_local_path = webclient.download_file(
self.url_cbin, chunks=(first_byte, n_bytes),
target_dir=self.target_dir, clobber=True, return_md5=False)
break
except Exception as e:
retries += 1
if retries > 5:
raise e
_logger.warning(f'Failed to download chunk {first_chunk} to {last_chunk}, retrying')
time.sleep(1)
cbin_local_path_renamed = remove_uuid_string(cbin_local_path).with_suffix('.stream.cbin')
cbin_local_path.replace(cbin_local_path_renamed)
assert cbin_local_path_renamed.exists()
reader = spikeglx.Reader(cbin_local_path_renamed, ignore_warnings=True)
return reader, cbin_local_path