from pathlib import Path
import tempfile
import unittest
import uuid
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from urllib.parse import urlparse
from one.api import ONE
from one.webclient import http_download_file
import ibllib.plots.misc
from ibllib.tests import TEST_DB
from ibllib.tests.fixtures.utils import register_new_session
from ibllib.plots.snapshot import Snapshot
from ibllib.plots.figures import dlc_qc_plot
WIDTH, HEIGHT = 1000, 100
[docs]
class TestSnapshot(unittest.TestCase):
[docs]
@classmethod
def setUpClass(cls):
# Make a small image and store in tmp file
cls.tmp_dir = tempfile.TemporaryDirectory()
cls.img_file = Path(cls.tmp_dir.name).joinpath('test.png')
image = Image.new('RGBA', size=(WIDTH, HEIGHT), color=(155, 0, 0))
image.save(cls.img_file, 'png')
image.close()
# set up ONE
cls.one = ONE(**TEST_DB)
# Collect all notes to delete them later
cls.notes = []
# make a new test session
_, eid = register_new_session(cls.one, subject='ZM_1150')
cls.eid = str(eid)
def _get_image(self, url):
# This is a bit of a hack because when running the server locally, the request to the media folder fails
rel_path = urlparse(url).path[1:]
try:
img_file = list(Path('/var/www/').rglob(rel_path))[0]
except IndexError:
img_file = http_download_file(url, target_dir=Path(self.tmp_dir.name), username=TEST_DB['username'],
password=TEST_DB['password'], silent=True)
return img_file
[docs]
def test_class_setup(self):
# Tests that the creation of the class works and that the defaults are working
object_id = str(uuid.uuid4())
snp = Snapshot(object_id, one=self.one)
self.assertEqual(snp.object_id, object_id)
self.assertEqual(snp.content_type, 'session') # default
self.assertTrue(len(snp.images) == 0)
# Test with different content type
snp = Snapshot(object_id, content_type='probeinsertion', one=self.one)
self.assertEqual(snp.object_id, object_id)
self.assertEqual(snp.content_type, 'probeinsertion')
self.assertTrue(len(snp.images) == 0)
[docs]
def test_snapshot_default(self):
# Test a case where object_id and content type match
object_id = self.one.alyx.rest('subjects', 'list', limit=1)[0]['id']
snp = Snapshot(object_id, content_type='subject', one=self.one)
with self.assertLogs('ibllib', 'INFO'):
self.notes.append(snp.register_image(self.img_file, text='default size'))
# Check image size is scaled to default width (defined in alyx settings.py)
img_db = self._get_image(self.notes[-1]['image'])
with Image.open(img_db) as im:
self.assertEqual(im.size, (800, HEIGHT * 800 / WIDTH))
# Test a case where they don't match
snp = Snapshot(str(uuid.uuid4()), content_type='session', one=self.one)
with self.assertLogs('ibllib', 'ERROR'):
note = snp.register_image(self.img_file, text='default size')
self.assertIsNone(note)
[docs]
def test_image_scaling(self):
# make a new session
object_id = self.eid
snp = Snapshot(object_id, content_type='session', one=self.one)
# Image in original size
self.notes.append(snp.register_image(self.img_file, text='original size', width='orig'))
img_db = self._get_image(self.notes[-1]['image'])
with Image.open(img_db) as im:
self.assertEqual(im.size, (WIDTH, HEIGHT))
# Scale to width 100
self.notes.append(snp.register_image(self.img_file, text='original size', width=100))
img_db = self._get_image(self.notes[-1]['image'])
with Image.open(img_db) as im:
self.assertEqual(im.size, (100, HEIGHT * 100 / WIDTH))
[docs]
def test_register_multiple(self):
expected_texts = ['first', 'second', 'third']
expected_sizes = [(800, HEIGHT * 800 / WIDTH), (WIDTH, HEIGHT), (200, HEIGHT * 200 / WIDTH)]
object_id = self.one.alyx.rest('datasets', 'list', limit=1)[0]['url'][-36:]
snp = Snapshot(object_id, content_type='dataset', one=self.one)
# Register multiple figures by giving a list
self.notes.extend(snp.register_images([self.img_file, self.img_file, self.img_file],
texts=['first', 'second', 'third'], widths=[None, 'orig', 200]))
for i in range(3):
self.assertEqual(self.notes[i - 3]['text'], expected_texts[i])
img_db = self._get_image(self.notes[i - 3]['image'])
with Image.open(img_db) as im:
self.assertEqual(im.size, expected_sizes[i])
# Registering multiple figures by adding to self.figures
self.assertEqual(len(snp.images), 0)
with self.assertLogs('ibllib', 'WARNING'):
out = snp.register_images()
self.assertIsNone(out)
snp.images.extend([self.img_file, self.img_file, self.img_file])
self.notes.extend(snp.register_images(texts=['always the same'], widths=[200]))
for i in range(3):
self.assertEqual(self.notes[i - 3]['text'], 'always the same')
img_db = self._get_image(self.notes[i - 3]['image'])
with Image.open(img_db) as im:
self.assertEqual(im.size, expected_sizes[2])
[docs]
def test_generate_image(self):
snp = Snapshot(str(uuid.uuid4()), one=self.one)
def make_img(size, out_path):
image = Image.new('RGBA', size=size, color=(100, 100, 100))
image.save(out_path, 'png')
return out_path
out_path = Path(self.tmp_dir.name).joinpath('test_generate.png')
snp.generate_image(make_img, {'size': (WIDTH, HEIGHT), 'out_path': out_path})
self.assertEqual(len(snp.images), 1)
self.assertEqual(snp.images[0], out_path)
with Image.open(out_path) as im:
self.assertEqual(im.size, (WIDTH, HEIGHT))
[docs]
@classmethod
def tearDownClass(cls):
# Clean up tmp dir
cls.tmp_dir.cleanup()
# Delete all notes
for note in cls.notes:
cls.one.alyx.rest('notes', 'delete', id=note['id'])
# Delete the new session that was made
cls.one.alyx.rest('sessions', 'delete', id=cls.eid)
[docs]
class TestDlcQcPlot(unittest.TestCase):
[docs]
@classmethod
def setUpClass(cls):
cls.tmp_dir = tempfile.TemporaryDirectory()
cls.one = ONE(**TEST_DB)
[docs]
@classmethod
def tearDownClass(cls):
# Clean up tmp dir
cls.tmp_dir.cleanup()
# fig = dlc_qc_plot(self.one.eid2path(eid), self.one)
# fig_path = (Path(self.tmp_dir.name).joinpath('dlc_qc_plot.png'))
# fig.savefig(fig_path)
# with Image.open(fig_path) as im:
# self.assertEqual(im.size, (1700, 1000))
[docs]
class TestMiscPlot(unittest.TestCase):
[docs]
def test_star_plot(self):
r = np.random.rand(6)
ax = ibllib.plots.misc.starplot(['a', 'b', 'c', 'd', 'e', 'f'], r, ylim=[0, 1])
r = np.random.rand(6)
ibllib.plots.misc.starplot(['a', 'b', 'c', 'd', 'e', 'f'], r, ax=ax, color='r')
plt.close('all')
[docs]
def test_wiggle(self):
w = np.random.rand(500, 40) - 0.5
ibllib.plots.misc.wiggle(w, fs=30000)
ibllib.plots.misc.Traces(w, fs=30000, color='r')
plt.close('all')