"""
Low-level functions to work in frequency domain for n-dim arrays
"""
import numpy as np
import scipy.fft
from ibllib.dsp.utils import fcn_cosine
def convolve(x, w, mode='full'):
    """
    Frequency domain convolution along the last dimension (2d arrays)
    Will broadcast if a matrix is convolved with a vector
    :param x: numpy array, convolved along its last dimension
    :param w: numpy array (kernel), convolved along its last dimension
    :param mode: 'full' returns nsx + nsw samples: the linear convolution (nsx + nsw - 1
     samples) followed by one trailing sample that is numerically ~0.
     'same' returns nsx samples, trimmed from the full output by about nsw / 2 at each end
    :return: convolution
    :raises ValueError: if mode is neither 'full' nor 'same'
    """
    if mode not in ('full', 'same'):
        raise ValueError(f"mode must be 'full' or 'same', got {mode!r}")
    nsx = x.shape[-1]
    nsw = w.shape[-1]
    ns = ns_optim_fft(nsx + nsw)
    # rfft zero-pads both inputs to ns >= nsx + nsw, so the circular product does not wrap
    spectrum = np.fft.rfft(x, n=ns, axis=-1) * np.fft.rfft(w, n=ns, axis=-1)
    # n=ns is mandatory: ns may be odd (9, 27, ...) and irfft would then default to ns - 1
    xw = np.fft.irfft(spectrum, n=ns, axis=-1)
    xw = xw[..., :(nsx + nsw)]  # remove 0 padding
    if mode == 'full':
        return xw
    first = int(np.floor(nsw / 2)) - ((nsw + 1) % 2)
    last = int(np.ceil(nsw / 2)) + ((nsw + 1) % 2)
    return xw[..., first:-last]
def ns_optim_fft(ns):
    """
    Gets the smallest number >= ns whose only prime factors are 2 and 3, to compute
    efficient ffts (e.g. 7 -> 8, 9 -> 9, 10 -> 12)
    :param ns: minimum number of samples
    :return: nsoptim (numpy integer)
    """
    # explicit int64: the largest product 2**24 * 3**14 (~8e13) overflows the 32-bit
    # default integer numpy uses on some platforms (e.g. Windows with numpy < 2)
    p2, p3 = np.meshgrid(2 ** np.arange(25, dtype=np.int64), 3 ** np.arange(15, dtype=np.int64))
    sz = np.unique((p2 * p3).flatten())  # sorted candidate sizes
    return sz[np.searchsorted(sz, ns)]
def dephas(w, phase, axis=-1):
    """
    Rotates the phase of a signal by a constant angle in degrees.
    The non-negative frequency bins are multiplied by exp(-1j * phase); the negative-frequency
    bins are then rebuilt as their complex conjugates, and the real part is returned.
    :param w: numpy array containing the signal
    :param phase: phase in degrees
    :param axis: axis along which the signal is transformed (last axis by default)
    :return: real numpy array with the same shape as w
    """
    nsamples = w.shape[axis]
    rotation = np.exp(- 1j * phase / 180 * np.pi)
    # keep DC .. Nyquist only, rotate, then restore hermitian symmetry
    half_spectrum = freduce(np.fft.fft(w, axis=axis), axis=axis)
    full_spectrum = fexpand(half_spectrum * rotation, ns=nsamples, axis=axis)
    return np.real(np.fft.ifft(full_spectrum, axis=axis))
def fscale(ns, si=1, one_sided=False):
    """
    Frequency scale in FFT bin order, with the Nyquist frequency kept positive.
    numpy.fft.fftfreq returns Nyquist as a negative frequency so we propose this instead
    :param ns: number of samples
    :param si: sampling interval in seconds
    :param one_sided: if True, returns only the non-negative frequencies (DC to floor(ns / 2))
    :return: fscale: numpy vector containing frequencies in Hertz
    """
    positive = np.arange(0, np.floor(ns / 2) + 1) / ns / si
    if one_sided:
        return positive
    # negative side: positive bins in reverse order, without DC and (even ns) without Nyquist
    start = (ns % 2) - 2
    negative = -positive[start:0:-1]
    return np.concatenate((positive, negative), axis=0)
def freduce(x, axis=None):
    """
    Reduces a spectrum to positive frequencies only
    Keeps the first floor(n / 2) + 1 bins along the axis (DC up to Nyquist for an even n),
    i.e. the same layout as numpy.fft.rfft.
    :param x: numpy.ndarray
    :param axis: axis along which to perform reduction (last axis by default)
    :return: numpy.ndarray
    """
    ax = x.ndim - 1 if axis is None else axis
    nkeep = x.shape[ax] // 2 + 1
    return np.take(x, np.arange(nkeep), axis=ax)
def fexpand(x, ns=1, axis=None):
    """
    Reconstructs full spectrum from positive frequencies
    The negative-frequency bins are appended as complex conjugates of the positive ones, in
    reverse order along axis (hermitian symmetry).
    :param x: numpy.ndarray, one-sided spectrum (e.g. from freduce or numpy.fft.rfft)
    :param ns: number of samples of the full spectrum: bins 1 .. ceil(ns / 2) - 1 of x are
     mirrored. Defaults to 1 (nothing mirrored)
    :param axis: axis along which to perform the expansion (last axis by default)
    :return: numpy.ndarray
    """
    ax = x.ndim - 1 if axis is None else axis
    # DC never has a counterpart, nor does Nyquist for an even ns
    n_mirror = int((ns + ns % 2) // 2)
    mirrored = np.take(x, np.arange(1, n_mirror), axis=ax)
    negative = np.conj(np.flip(mirrored, axis=ax))
    return np.concatenate((x, negative), axis=ax)
def bp(ts, si, b, axis=None):
    """
    Band-pass filter in frequency domain
    :param ts: time series
    :param si: sampling interval in seconds
    :param b: cut-off frequencies in Hz, 4 elements vector or list: b[0:2] bound the high-pass
     transition band, b[2:4] bound the low-pass transition band
    :param axis: axis along which to filter (last axis by default)
    :return: filtered time series
    """
    return _freq_filter(ts, si=si, b=b, axis=axis, typ='bp')
def lp(ts, si, b, axis=None):
    """
    Low-pass filter in frequency domain
    :param ts: time series
    :param si: sampling interval in seconds
    :param b: cut-off frequencies in Hz, 2 elements vector or list bounding the transition band
    :param axis: axis along which to filter (last axis by default)
    :return: filtered time series
    """
    return _freq_filter(ts, si=si, b=b, axis=axis, typ='lp')
def hp(ts, si, b, axis=None):
    """
    High-pass filter in frequency domain
    :param ts: time series
    :param si: sampling interval in seconds
    :param b: cut-off frequencies in Hz, 2 elements vector or list bounding the transition band
    :param axis: axis along which to filter (last axis by default)
    :return: filtered time series
    """
    return _freq_filter(ts, si=si, b=b, axis=axis, typ='hp')
def _freq_filter(ts, si, b, axis=None, typ='lp'):
    """
    Wrapper for hp/lp/bp filters: multiplies the spectrum of ts along one axis by the
    amplitude vector built by _freq_vector, and returns the real part of the inverse FFT.
    :param ts: time series, numpy array of any dimension
    :param si: sampling interval in seconds
    :param b: cut-off frequencies in Hz: 2 elements for 'lp'/'hp'; 4 elements for 'bp',
     b[0:2] for the high-pass side and b[2:4] for the low-pass side
    :param axis: axis along which to filter (last axis by default, negative values allowed)
    :param typ: 'lp', 'hp' or 'bp'
    :return: filtered time series, same shape as ts
    """
    if axis is None:
        axis = ts.ndim - 1
    ns = ts.shape[axis]  # also rejects an out-of-range axis (IndexError)
    # normalize negative axes so that the filter is reshaped along the right dimension
    axis = axis % ts.ndim
    f = fscale(ns, si=si, one_sided=True)
    if typ == 'bp':
        filc = _freq_vector(f, b[0:2], typ='hp') * _freq_vector(f, b[2:4], typ='lp')
    else:
        filc = _freq_vector(f, b, typ=typ)
    # two-sided filter shaped (1, ..., ns, ..., 1) so that it only broadcasts along `axis`
    shape = [1] * ts.ndim
    shape[axis] = ns
    filc = fexpand(filc, ns, axis=0).reshape(shape)
    return np.real(np.fft.ifft(np.fft.fft(ts, axis=axis) * filc, axis=axis))
def _freq_vector(f, b, typ='lp'):
    """
    Returns a frequency modulated vector for filtering
    :param f: frequency vector, uniform and monotonic
    :param b: 2 bounds array, passed to fcn_cosine (presumably the transition band limits)
    :param typ: 'hp'/'highpass' returns fcn_cosine(b)(f) itself, 'lp'/'lowpass' returns
     1 - fcn_cosine(b)(f) (case insensitive)
    :return: amplitude modulated frequency vector
    :raises ValueError: if typ is not a recognized filter type
    """
    kind = typ.lower()
    if kind not in ('hp', 'highpass', 'lp', 'lowpass'):
        # previously fell through and returned None, failing later with an obscure error
        raise ValueError(f"Unknown filter type {typ!r}: expected 'lp'/'lowpass' or 'hp'/'highpass'")
    filc = fcn_cosine(b)(f)
    return filc if kind in ('hp', 'highpass') else 1 - filc
def fshift(w, s, axis=-1, ns=None):
    """
    Shifts a 1D or 2D signal in frequency domain, to allow for accurate non-integer shifts
    The shift is circular: samples pushed past one end wrap around to the other.
    :param w: input signal (if complex, it is taken as a rfft spectrum and ns must be provided)
    :param s: shift in samples, positive shifts forward. Scalar, or array-like holding one
     shift per trace, reshaped to w.shape with size 1 along axis
    :param axis: axis along which to shift (last axis by default)
    :param ns: if a rfft frequency domain array is provided, give a number of samples as there
     is an ambiguity
    :return: shifted signal. Real input gives real output cast to w.dtype; complex (rfft)
     input gives the shifted spectrum. The input array is never modified.
    """
    ns = ns or w.shape[axis]
    # phase of a one-sample delay for each rfft bin k: -2 pi k / ns, laid along `axis`
    nf = ns // 2 + 1
    kshape = [1] * w.ndim
    kshape[axis] = nf
    phase_per_sample = (-2 * np.pi / ns * np.arange(nf)).reshape(kshape)
    do_fft = not np.iscomplexobj(w)
    if do_fft:
        W = scipy.fft.rfft(w, axis=axis)
    else:
        # copy: the in-place product below must not alter the caller's spectrum
        W = np.array(w, copy=True)
    # if multiple shifts, broadcast along the other dimensions, otherwise keep a single vector
    if not np.isscalar(s):
        s_shape = np.array(w.shape)
        s_shape[axis] = 1
        s = np.asarray(s).reshape(s_shape)
    W *= np.exp(1j * phase_per_sample * s)
    if do_fft:
        W = np.real(scipy.fft.irfft(W, ns, axis=axis))
        W = W.astype(w.dtype)
    return W
def fit_phase(w, si=1, fmin=0, fmax=None, axis=-1):
    """
    Performs a linear regression on the unwrapped phase of a wavelet to obtain a time-delay
    :param w: wavelet (usually a cross-correlation)
    :param si: sampling interval
    :param fmin: lower bound (exclusive) of the frequency band used for the fit, in 1 / si units
    :param fmax: upper bound (exclusive) of the frequency band used for the fit,
     defaults to Nyquist (1 / si / 2)
    :param axis: axis along which the wavelet is sampled (last axis by default)
    :return: dt, time-delay in the units of si: a scalar for a 1D wavelet, one value per
     trace for a 2D array
    """
    if fmax is None:
        fmax = 1 / si / 2
    ns = w.shape[axis]
    freqs = freduce(fscale(ns, si=si))
    # unwrap along the frequency axis: np.unwrap defaults to the last axis otherwise
    phi = np.unwrap(np.angle(freduce(np.fft.fft(w, axis=axis), axis=axis)), axis=axis)
    indf = np.logical_and(fmin < freqs, freqs < fmax)
    # polyfit expects the samples along the first axis; the phase slope is -2 pi dt
    slope = np.polyfit(freqs[indf], np.swapaxes(phi.compress(indf, axis=axis), axis, 0), 1)[0]
    dt = - slope / np.pi / 2
    return dt
def dft(x, xscale=None, axis=-1, kscale=None):
    """
    1D discrete fourier transform. Vectorized.
    Computed as an explicit matrix product, so samples need not be regularly spaced and any
    set of fourier coefficients can be requested.
    :param x: numpy array to be transformed (any number of dimensions)
    :param xscale: (optional) array-like, time or spatial index of each sample along axis,
     in samples. Defaults to 0, 1, ..., ns - 1
    :param axis: for multidimensional arrays, axis along which the ft is computed
    :param kscale: (optional) array-like of fourier coefficient indices. Defaults to all ns
     coefficients if any element of x has a non-zero imaginary part, otherwise to the
     ns // 2 + 1 non-negative ones
    :return: complex numpy array, same shape as x except along axis where the size is nk
    """
    ns = x.shape[axis]
    xscale = np.arange(ns) if xscale is None else np.asarray(xscale)
    if kscale is None:
        # for real input the negative coefficients are redundant (hermitian symmetry)
        nk = ns if np.any(np.iscomplex(x)) else int(np.ceil((ns + 1) / 2))
        kscale = np.arange(nk)
    else:
        kscale = np.asarray(kscale)
        nk = kscale.size
    if axis != 0:
        # the axis of the transform always needs to be the first
        x = np.swapaxes(x, axis, 0)
    shape = np.array(x.shape)
    # flatten the other dimensions into columns: a single matrix product does all transforms
    x = np.reshape(x, (ns, -1))
    # exp[k, n] = exp(-2i pi kscale[k] xscale[n] / ns)
    exp = np.exp(- 1j * 2 * np.pi / ns * xscale * kscale[:, np.newaxis])
    X = np.matmul(exp, x)
    shape[0] = nk
    X = X.reshape(shape)
    if axis != 0:
        X = np.swapaxes(X, axis, 0)
    return X
def dft2(x, r, c, nk, nl):
    """
    Irregularly sampled 2D dft by projecting into sines/cosines. Vectorized.
    :param x: vector (nrc,) or 2d matrix of shape (nrc, nt)
    :param r: vector (nrc) of normalized positions along the k dimension (axis 0)
    :param c: vector (nrc) of normalized positions along the l dimension (axis 1)
    :param nk: output size along axis 0
    :param nl: output size along axis 1
    :return: Matrix X (nk, nl, nt), with nt = 1 for a vector input
    """
    # it would be interesting to compare performance with numba straight loops (easier to write)
    # GPU/C implementation should implement straight loops
    x = np.asarray(x)
    if x.ndim == 1:
        # a single trace: treat it as (nrc, 1) so the output keeps the documented 3D shape
        x = x[:, np.newaxis]
    r = np.asarray(r)
    c = np.asarray(c)
    nt = x.shape[-1]
    k, h = [v.flatten() for v in np.meshgrid(np.arange(nk), np.arange(nl), indexing='ij')]
    # exp has dimension (kh, rc): exp[kh, rc] = exp(-2i pi (r[rc] k[kh] + c[rc] h[kh]))
    exp = np.exp(- 1j * 2 * np.pi * (r[np.newaxis] * k[:, np.newaxis] +
                                     c[np.newaxis] * h[:, np.newaxis]))
    return np.matmul(exp, x).reshape((nk, nl, nt))