# Source code for brainbox.io.spikeglx

import shutil
import logging
from pathlib import Path
import time
import json
import string
import random

import numpy as np
from one.alf.files import remove_uuid_string

import spikeglx

_logger = logging.getLogger('ibllib')


def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    Extracts spike waveforms from binary ephys data file, after (optionally)
    common-average-referencing (CAR) spatial noise.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data (opened with ``spikeglx.Reader``).
    ts : ndarray_like
        The timestamps (in s) of the spikes. Plain lists are accepted as well as arrays.
    ch : ndarray_like
        The channels on which to extract the waveforms (a scalar selects one channel).
    t : numeric (optional)
        The time (in ms) of each returned waveform.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording; valid channel indices satisfy
        ``0 <= ch < n_ch_probe``.
    car: bool (optional)
        A flag to perform CAR before extracting waveforms. CAR is not implemented yet:
        ``car=True`` (the default) raises ``NotImplementedError``, so pass ``car=False``.

    Returns
    -------
    waveforms : ndarray
        A float array of shape (#spikes, #samples, #channels) containing the waveforms,
        with #samples = ``2 * int(sr / 1000 * (t / 2))``: for each spike sample ``s`` the
        samples ``s - n`` up to (excluding) ``s + n`` are returned.

    Raises
    ------
    ValueError
        If any channel index lies outside ``[0, n_ch_probe)``.
    NotImplementedError
        If ``car`` is True.

    Notes
    -----
    Spike sample indices are obtained by truncating ``ts * sr`` to integers. Spikes whose
    window runs past the start or end of the recording are not supported (presumably the
    extraction then fails when the truncated window cannot be reshaped).

    Examples
    --------
    1) Extract all the waveforms for unit1 (without CAR, which is not implemented).
    >>> import numpy as np
    >>> import brainbox as bb
    >>> import one.alf.io as alfio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get a clusters bunch and a units bunch from a spikes bunch from an alf directory.
    >>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times'])
    # Get the timestamps and 20 channels around the max amp channel for unit1, and extract
    # the waveforms.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    ...     ch = np.arange(max_ch, max_ch + 20)
    ... elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    ...     ch = np.arange(max_ch - 20, max_ch)
    ... else:  # take `n_c_ch` around `max_ch`.
    ...     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=False)
    """
    # Validate the arguments before opening the file so bad inputs fail fast.
    # ravel: a scalar becomes a single channel, and indexing with a 1-D channel array always
    # yields a (samples, channels) window.
    ch = np.asarray(ch).ravel()
    # channel indices are 0-based, so `n_ch_probe` itself is already out of range
    if np.any(ch < 0) or np.any(ch >= n_ch_probe):
        raise ValueError('At least one specified channel number is impossible. '
                         f'The minimum channel number was {np.min(ch)}, '
                         f'and the maximum channel number was {np.max(ch)}. '
                         'Check specified channel numbers and try again.')
    if car:  # compute spatial noise in chunks
        # see https://github.com/int-brain-lab/iblenv/issues/5
        raise NotImplementedError("CAR option is not available")

    n_wf_samples = int(sr / 1000 * (t / 2))  # number of samples to return on each side of a ts
    # convert to an array *before* scaling: with a plain list, `ts * sr` repeats the list
    ts_samples = np.ravel(np.asarray(ts) * sr).astype(int)  # the samples corresponding to `ts`
    n_spikes = ts_samples.size
    waveforms = np.zeros((n_spikes, 2 * n_wf_samples, ch.size))
    if n_spikes == 0:
        return waveforms

    # number of spikes extracted before logging an estimate of the total extraction time
    n_timing = min(5, n_spikes)
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        t0 = time.perf_counter()
        for spk, spk_ts_sample in enumerate(ts_samples):
            first = spk_ts_sample - n_wf_samples
            last = spk_ts_sample + n_wf_samples
            # reshape to add an axis to broadcast `file_m` into `waveforms`
            waveforms[spk, :, :] = file_m[first:last, ch].reshape((2 * n_wf_samples, ch.size))
            if spk + 1 == n_timing:
                dt = time.perf_counter() - t0
                _logger.info('Performing waveform extraction. Estimated time is %.2f mins. (%s)',
                             dt * n_spikes / 60 / n_timing, time.ctime())
        _logger.info('Done. (%s)', time.ctime())
    return waveforms
class Streamer(spikeglx.Reader):
    """
    Stream a remote compressed (mtscomp ``.cbin``) spikeglx recording chunk by chunk.

    Only the compressed chunks covering the requested samples are downloaded into a local
    cache folder, together with a chopped ``.ch`` chunk table and a copy of the ``.meta``
    file, and are then read with a regular ``spikeglx.Reader``.

    Example::

        pid = 'e31b4e39-e350-47a9-aca4-72496d99ff2a'
        one = ONE()
        sr = Streamer(pid=pid, one=one)
        raw_voltage = sr[int(t0 * sr.fs):int((t0 + nsecs) * sr.fs), :]

    :param pid: probe insertion id, resolved to (eid, probe name) with ``one.pid2eid``
    :param one: ONE instance used to find and download the datasets
    :param typ: stream type used in the dataset names (default 'ap')
    :param cache_folder: str or Path where downloaded chunks are stored; defaults to
        ``<one.alyx._par.CACHE_DIR>/cache/<typ>``
    :param remove_cached: if True, the downloaded chunk directory is deleted after each read
    """

    def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
        self.target_dir = None  # last chunk directory downloaded or read
        self.one = one
        self.pid = pid
        # wrap in Path so str arguments work too (`read` calls `.mkdir` on it)
        if cache_folder:
            self.cache_folder = Path(cache_folder)
        else:
            self.cache_folder = Path(self.one.alyx._par.CACHE_DIR).joinpath('cache', typ)
        self.remove_cached = remove_cached
        self.eid, self.pname = self.one.pid2eid(pid)
        collection = f"*{self.pname}"
        self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=collection)
        meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=collection)
        cbin_rec = self.one.list_datasets(
            self.eid, collection=collection, filename=f'*{typ}.*bin', details=True)
        self.url_cbin = self.one.record2url(cbin_rec)[0]
        # the .ch file is JSON; 'chunk_bounds' holds the sample index of each chunk boundary
        # (its last value is the total number of samples), 'chunk_offsets' the byte offsets
        with open(self.file_chunks, 'r') as f:
            self.chunks = json.load(f)
        self.chunks['chunk_bounds'] = np.array(self.chunks['chunk_bounds'])
        super().__init__(meta_file, ignore_warnings=True)

    def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True, volts=True):
        """
        Overload the read function: download (or reuse cached) chunks covering ``nsel`` and
        read the requested samples from them.

        :param nsel: slice of sample indices. ``None`` and negative bounds are resolved against
            the total number of samples like standard Python slicing, and a stop past the end
            is clipped; the slice step is ignored.
        :param csel: channel selection, passed through to the chunk reader
        :param sync: unused, kept for signature compatibility with ``spikeglx.Reader.read``
        :param volts: if True, index the chunk reader (``sr[...]``); if False, return a copy
            of the reader's raw ``_raw`` samples
        :return: the selected samples
        """
        bounds = self.chunks['chunk_bounds']
        # normalise the slice: a None start/stop used to crash, and a stop beyond the last
        # sample selected a chunk index past the end of the chunk table
        start, stop, _ = nsel.indices(int(bounds[-1]))
        first_chunk = np.maximum(0, np.searchsorted(bounds, start) - 1)
        last_chunk = np.maximum(0, np.searchsorted(bounds, stop) - 1)
        n0 = bounds[first_chunk]  # first sample of the downloaded portion
        _logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
        self.cache_folder.mkdir(exist_ok=True, parents=True)
        sr, _ = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
        try:
            if volts:
                data = sr[start - n0:stop - n0, csel]
            else:
                data = np.copy(sr._raw[start - n0:stop - n0, csel])
        finally:
            sr.close()  # release the chunk file even if the read fails
        if self.remove_cached:
            shutil.rmtree(self.target_dir)
        return data

    def _download_raw_partial(self, first_chunk=0, last_chunk=0):
        """
        Download one or several chunks of the remote mtscomp file and write the matching
        chopped ``.ch`` chunk table and ``.meta`` file next to them, so that the partial
        file can be opened with a ``spikeglx.Reader``.

        A previous download found in the same target directory is reused when its ``.ch``
        file records the same first sample and the same total number of samples.

        :param first_chunk: index of the first chunk to download
        :param last_chunk: index of the last chunk to download (inclusive)
        :return: spikeglx.Reader opened on the local ``.stream.cbin`` file
        :return: pathlib.Path of that local ``.stream.cbin`` file
        """
        assert str(self.url_cbin).endswith('.cbin')
        webclient = self.one.alyx
        relpath = Path(self.url_cbin.replace(webclient._par.HTTP_DATA_SERVER, '.')).parents[0]
        # write the temp file into a subdirectory
        tdir_chunk = f"chunk_{str(first_chunk).zfill(6)}_to_{str(last_chunk).zfill(6)}"
        # for parallel processes, there is a risk of collisions if the removed cached flag is set to True
        # if the folder is to be removed append a unique identifier to avoid having duplicate names
        if self.remove_cached:
            tdir_chunk += ''.join(random.choices(string.ascii_letters, k=10))
        self.target_dir = Path(self.cache_folder, relpath, tdir_chunk)
        self.target_dir.mkdir(parents=True, exist_ok=True)
        ch_file_stream = self.target_dir.joinpath(self.file_chunks.name).with_suffix('.stream.ch')
        cbin_file_stream = ch_file_stream.with_suffix('.cbin')

        # Get the first sample index, and the number of samples to download.
        i0 = self.chunks['chunk_bounds'][first_chunk]
        ns_stream = self.chunks['chunk_bounds'][last_chunk + 1] - i0
        total_samples = self.chunks['chunk_bounds'][-1]

        # handles the meta file
        meta_local_path = ch_file_stream.with_suffix('.meta')
        if not meta_local_path.exists():
            shutil.copy(self.file_chunks.with_suffix('.meta'), meta_local_path)

        # if the cached version happens to be the same as the one on disk, just load it
        if ch_file_stream.exists() and cbin_file_stream.exists():
            with open(ch_file_stream, 'r') as f:
                cmeta_stream = json.load(f)
            if (cmeta_stream.get('chopped_first_sample', None) == i0 and
                    cmeta_stream.get('chopped_total_samples', None) == total_samples):
                return spikeglx.Reader(cbin_file_stream, ignore_warnings=True), cbin_file_stream
        else:
            shutil.copy(self.file_chunks, ch_file_stream)
        assert ch_file_stream.exists()

        # prepare the metadata file: chunk bounds/offsets relative to the first downloaded chunk
        cmeta = self.chunks.copy()
        cmeta['chunk_bounds'] = [int(bound - i0)
                                 for bound in cmeta['chunk_bounds'][first_chunk:last_chunk + 2]]
        assert len(cmeta['chunk_bounds']) >= 2
        assert cmeta['chunk_bounds'][0] == 0

        first_byte = cmeta['chunk_offsets'][first_chunk]
        cmeta['chunk_offsets'] = [offset - first_byte
                                  for offset in cmeta['chunk_offsets'][first_chunk:last_chunk + 2]]
        assert len(cmeta['chunk_offsets']) >= 2
        assert cmeta['chunk_offsets'][0] == 0
        n_bytes = cmeta['chunk_offsets'][-1]
        assert n_bytes > 0

        # Save the chopped chunk bounds and offsets.
        cmeta['sha1_compressed'] = None
        cmeta['sha1_uncompressed'] = None
        cmeta['chopped'] = True
        cmeta['chopped_first_sample'] = int(i0)
        cmeta['chopped_samples'] = int(ns_stream)
        cmeta['chopped_total_samples'] = int(total_samples)

        with open(ch_file_stream, 'w') as f:
            json.dump(cmeta, f, indent=2, sort_keys=True)

        # Download the requested chunks: one attempt plus up to `n_retries` retries
        n_retries = 5
        for attempt in range(n_retries + 1):
            try:
                cbin_local_path = webclient.download_file(
                    self.url_cbin, chunks=(first_byte, n_bytes),
                    target_dir=self.target_dir, clobber=True, return_md5=False)
                break
            except Exception:
                if attempt == n_retries:
                    raise
                _logger.warning(f'Failed to download chunk {first_chunk} to {last_chunk}, retrying')
                time.sleep(1)
        cbin_local_path_renamed = remove_uuid_string(cbin_local_path).with_suffix('.stream.cbin')
        cbin_local_path.replace(cbin_local_path_renamed)
        assert cbin_local_path_renamed.exists()

        reader = spikeglx.Reader(cbin_local_path_renamed, ignore_warnings=True)
        # return the renamed file: the downloaded path no longer exists after `replace`
        return reader, cbin_local_path_renamed