import unittest
from tempfile import TemporaryDirectory
from pathlib import Path
import logging
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from one.api import ONE
from one.alf import spec
from iblutil.util import Bunch
from ibllib.tests import TEST_DB
from ibllib.qc.camera import CameraQC, get_task_collection
from ibllib.io.raw_data_loaders import load_camera_ssv_times
from ibllib.tests.fixtures import utils
[docs]
class TestCameraQC(unittest.TestCase):
backend = ''
[docs]
@classmethod
def setUpClass(cls) -> None:
cls.one = ONE(**TEST_DB)
cls.backend = matplotlib.get_backend()
matplotlib.use('Agg')
[docs]
@classmethod
def tearDownClass(cls) -> None:
if cls.backend:
matplotlib.use(cls.backend)
# Clear overwritten methods by destroying cached instance
ONE.cache_clear()
[docs]
def setUp(self) -> None:
self.tempdir = TemporaryDirectory()
self.session_path = utils.create_fake_session_folder(self.tempdir.name)
utils.create_fake_raw_video_data_folder(self.session_path)
self.eid = 'd3372b15-f696-4279-9be5-98f15783b5bb'
self.qc = CameraQC(self.session_path, 'left', one=self.one, n_samples=5, stream=False)
self.qc._type = 'ephys'
[docs]
def tearDown(self) -> None:
self.tempdir.cleanup()
plt.close('all')
[docs]
def test_check_brightness(self):
self.qc.data['frame_samples'] = self.qc.load_reference_frames('left')
n = len(self.qc.data['frame_samples'])
self.qc.frame_samples_idx = np.arange(n)
self.qc.n_samples = 3
self.assertIs(spec.QC.PASS, self.qc.check_brightness(display=True))
# Check plots
fig = plt.gcf()
self.assertEqual(3, len(fig.axes))
expected = np.array([72.00794903, 70.98228997, 78.25575987])
np.testing.assert_array_almost_equal(fig.axes[0].lines[0]._y, expected)
# Make frames a third as bright
self.qc.data['frame_samples'] = (self.qc.data['frame_samples'] / 3).astype(np.int32)
self.assertIs(spec.QC.WARNING, self.qc.check_brightness())
# Change proportion of passing frames
self.qc.n_samples = 1000
self.assertIs(spec.QC.FAIL, self.qc.check_brightness())
self.qc.n_samples = 3
# Change thresholds
self.qc.data['frame_samples'] = self.qc.load_reference_frames('left')
self.assertIs(spec.QC.FAIL, self.qc.check_brightness(bounds=(10, 20)))
self.assertIs(spec.QC.FAIL, self.qc.check_brightness(max_std=1e-6))
# Check outcome when no frame samples loaded
self.qc.data['frame_samples'] = None
self.assertIs(spec.QC.NOT_SET, self.qc.check_brightness())
[docs]
def test_check_framerate(self):
FPS = 60.
self.qc.data['video'] = {'fps': FPS}
self.qc.data['timestamps'] = np.array([round(1 / FPS, 4)] * 1000).cumsum()
outcome, frate = self.qc.check_framerate()
self.assertIs(spec.QC.PASS, outcome)
self.assertEqual(59.88, frate)
self.assertIs(spec.QC.FAIL, self.qc.check_framerate(threshold=1e-2)[0])
self.qc.data['timestamps'] = None
self.assertIs(spec.QC.NOT_SET, self.qc.check_framerate())
[docs]
def test_check_pin_state(self):
FPS = 60.
self.assertIs(spec.QC.NOT_SET, self.qc.check_pin_state())
# Add some dummy data
self.qc.data.timestamps = np.array([round(1 / FPS, 4)] * 5).cumsum()
self.qc.data.pin_state = np.zeros((self.qc.data.timestamps.size, 4), dtype=bool)
self.qc.data.pin_state[1:-1, -1] = True # Pulse on 4th pin
self.qc.data['video'] = {'fps': FPS, 'length': len(self.qc.data.timestamps)}
self.qc.data.audio = self.qc.data.timestamps[[0, -1]] - 10e-3
# Check passes and plots results
outcome, *_ = self.qc.check_pin_state(display=True)
self.assertIs(spec.QC.PASS, outcome)
a, b = [ln.get_xdata() for ln in plt.gcf().axes[0].lines]
self.assertEqual(a, self.qc.data.timestamps[1])
np.testing.assert_array_equal(b, self.qc.data.audio)
# Fudge some numbers
self.qc.data['video']['length'] = self.qc.data.pin_state.shape[0] - 3
self.assertIs(spec.QC.WARNING, self.qc.check_pin_state()[0])
self.qc.data['video']['length'] = 10
outcome, *dTTL = self.qc.check_pin_state()
self.assertIs(spec.QC.FAIL, outcome)
self.assertEqual([0, -5], dTTL)
[docs]
def test_check_dropped_frames(self):
n = 20
self.qc.data.count = np.arange(n)
self.qc.data.video = {'length': n}
self.assertIs(spec.QC.PASS, self.qc.check_dropped_frames()[0])
# Drop some frames
dropped = 6
self.qc.data.count = np.append(self.qc.data.count, n + dropped)
outcome, dframe, sz_diff = self.qc.check_dropped_frames()
self.assertIs(spec.QC.FAIL, outcome)
self.assertEqual(dropped, dframe)
self.assertEqual(1, sz_diff)
# Verify threshold arg; should be warning due to size diff
outcome, *_ = self.qc.check_dropped_frames(threshold=70.)
self.assertIs(spec.QC.WARNING, outcome)
# Verify critical outcome
self.qc.data.count = np.random.permutation(self.qc.data.count) # Count out of order
self.assertIs(spec.QC.CRITICAL, self.qc.check_dropped_frames([0]))
# Verify not set outcome
self.qc.data.video = None
self.assertIs(spec.QC.NOT_SET, self.qc.check_dropped_frames())
[docs]
def test_check_focus(self):
self.qc.label = 'left'
self.qc.frame_samples_idx = np.linspace(0, 100, 20, dtype=int)
outcome = self.qc.check_focus(test=True, display=True)
self.assertIs(spec.QC.PASS, outcome)
# Verify figures
figs = plt.get_fignums()
self.assertEqual(len(plt.figure(figs[0]).axes), 16)
# Verify Laplacian on blurred images
expected = np.array([11.82, 12.94, 13.84, 14.52, 15.68, 16.76, 18.85, 21.9,
25.45, 31.3, 40.48, 54.05, 81.53, 133.19, 425.15, 425.15])
actual = [round(x, 2) for x in plt.figure(figs[1]).axes[3].lines[0]._y.tolist()]
np.testing.assert_array_equal(expected, actual)
# Verify fft on blurred images
expected = np.array([6.91, 7.2, 7.61, 8.08, 8.76, 9.47, 10.35, 11.22,
11.04, 11.42, 11.35, 11.94, 12.45, 13.22, 13.6, 13.6])
actual = [round(x, 2) for x in plt.figure(figs[2]).axes[3].lines[0]._y.tolist()]
np.testing.assert_array_almost_equal(expected, actual, 1)
# Verify not set outcome
outcome = self.qc.check_focus()
self.assertIs(spec.QC.NOT_SET, outcome)
# Verify ROI
self.qc.data.frame_samples = self.qc.load_reference_frames('left')
outcome = self.qc.check_focus(roi=None)
self.assertIs(spec.QC.PASS, outcome)
[docs]
def test_check_position(self):
# Verify test mode
outcome = self.qc.check_position(test=True, display=True)
self.assertIs(spec.QC.PASS, outcome)
# Verify plots
axes = plt.gcf().axes
self.assertEqual(3, len(axes))
expected = np.array([100., 93.74829841, 93.2494463])
np.testing.assert_almost_equal(axes[2].lines[0]._y, expected)
# Verify not set (no frame samples and not in test mode)
outcome = self.qc.check_position()
self.assertIs(spec.QC.NOT_SET, outcome)
# Verify percent threshold as False
thresh = (75, 80)
outcome = self.qc.check_position(test=True, pct_thresh=False,
hist_thresh=thresh, display=True)
self.assertIs(spec.QC.FAIL, outcome)
fig = plt.get_fignums()[-1]
thr = [ln._y[0] for ln in plt.figure(fig).axes[2].lines[1:]]
self.assertCountEqual(thr, thresh, 'unexpected thresholds in figure')
[docs]
def test_check_resolution(self):
self.qc.data['video'] = {'width': 1280, 'height': 1024}
self.assertIs(spec.QC.PASS, self.qc.check_resolution())
self.qc.data['video']['width'] = 150
self.assertIs(spec.QC.FAIL, self.qc.check_resolution())
self.qc.data['video'] = None
self.assertIs(spec.QC.NOT_SET, self.qc.check_resolution())
[docs]
def test_check_timestamps(self):
FPS = 60.
n = 1000
self.qc.data['video'] = Bunch({'fps': FPS, 'length': n})
self.qc.data['timestamps'] = np.array([round(1 / FPS, 4)] * n).cumsum()
# Verify passes
self.assertIs(spec.QC.PASS, self.qc.check_timestamps())
# Verify fails
self.qc.data['timestamps'] = np.array([round(1 / 30, 4)] * 100).cumsum()
self.assertIs(spec.QC.FAIL, self.qc.check_timestamps())
# Verify not set
self.qc.data['video'] = None
self.assertIs(spec.QC.NOT_SET, self.qc.check_timestamps())
[docs]
def test_check_camera_times(self):
outcome = self.qc.check_camera_times()
self.assertIs(spec.QC.NOT_SET, outcome)
# Verify passes
self.qc.label = 'body'
ts_path = Path(__file__).parents[1].joinpath('extractors', 'data', 'session_ephys')
ssv_times = load_camera_ssv_times(ts_path, self.qc.label)
self.qc.data.bonsai_times, self.qc.data.camera_times = ssv_times
self.qc.data.video = Bunch({'length': self.qc.data.bonsai_times.size})
outcome, _ = self.qc.check_camera_times()
self.assertIs(spec.QC.PASS, outcome)
# Verify warning
n_over = 14
self.qc.data.video['length'] -= n_over
outcome, actual = self.qc.check_camera_times()
self.assertIs(spec.QC.WARNING, outcome)
self.assertEqual(n_over, actual)
[docs]
def test_check_wheel_alignment(self):
"""This just checks data validation. Integration tests test the MotionAlignment class"""
outcome = self.qc.check_wheel_alignment()
self.assertIs(spec.QC.NOT_SET, outcome)
# Expect FAIL when no overlapping timestamps between wheel and camera
self.qc.data['wheel'] = {
'timestamps': np.arange(4000),
'position': np.random.random(4000),
'period': np.array([3000, 3050])
}
self.qc.data['timestamps'] = np.arange(5000, 6000)
outcome = self.qc.check_wheel_alignment()
self.assertIs(spec.QC.FAIL, outcome)
# Expect NOT_SET when some overlapping timestamps but chosen period out of range
self.qc.data['timestamps'] -= 1500
with self.assertLogs(logging.getLogger('ibllib'), logging.WARNING):
outcome = self.qc.check_wheel_alignment()
self.assertIs(spec.QC.NOT_SET, outcome)
[docs]
def test_get_active_wheel_period(self):
"""Check that warning is raised, period is returned None, and QC is NOT_SET
if there is active wheel period to be found"""
wheel_keys = ('timestamps', 'position')
wheel_data = (np.arange(1000), np.ones(1000))
self.qc.data['wheel'] = Bunch(zip(wheel_keys, wheel_data))
with self.assertLogs(logging.getLogger('ibllib'), logging.WARNING):
period = self.qc.get_active_wheel_period(self.qc.data['wheel'])
self.assertEqual(None, period)
outcome = self.qc.check_wheel_alignment()
self.assertIs(spec.QC.NOT_SET, outcome)
[docs]
def test_get_task_collection(self):
"""Test for ibllib.qc.camera.get_task_collection"""
params = {'version': '1.0.0', 'tasks': [
{'passiveChoiceWorld': {'collection': 'raw_task_data_00'}},
{'ephysChoiceWorld': {'collection': 'raw_task_data_01'}}
]}
self.assertEqual('raw_behavior_data', get_task_collection(None))
self.assertEqual('raw_task_data_01', get_task_collection(params))
if __name__ == '__main__':
unittest.main(exit=False, verbosity=2)