import sys
import asyncio
import logging
import unittest
from unittest import mock
import ipaddress
import socket
from packaging.version import Version
from datetime import date
from iblutil.io.net import base, app
# Running interpreter version (major.minor.micro) as a comparable Version object
ver = Version('.'.join(str(part) for part in sys.version_info[:3]))
class TestBase(unittest.IsolatedAsyncioTestCase):
    """Tests for the base network utilities in iblutil.io.net.base.

    NB: This requires internet access.
    """

    def test_parse_uri(self):
        """Test base.validate_uri with full, partial, resolvable and invalid inputs."""
        full_uri = 'udp://192.168.0.1:9999'
        validated = base.validate_uri(full_uri)
        self.assertEqual(full_uri, validated)
        # Same URI with the leading 'udp://' removed
        no_scheme = validated[6:]
        self.assertEqual(full_uri, base.validate_uri(no_scheme))
        self.assertEqual(full_uri.replace('udp', 'ws'), base.validate_uri(no_scheme, default_proc='ws'))
        # Same URI with the trailing ':9999' removed
        self.assertEqual(full_uri, base.validate_uri(validated[:-5], default_port=9999))
        # An ip_address object in place of a string
        from_ip_obj = base.validate_uri(ipaddress.ip_address('192.168.0.1'), default_port=9999)
        self.assertEqual(full_uri, from_ip_obj)
        self.assertEqual('udp://foobar:11001', base.validate_uri('foobar', resolve_host=False))

        # Check IP resolved
        resolved = base.validate_uri('http://google.com:80', resolve_host=True)
        ip_types = (ipaddress.IPv4Address, ipaddress.IPv6Address)
        # Slice off 'http://' and ':80' to leave only the host part
        self.assertIsInstance(ipaddress.ip_address(resolved[7:-3]), ip_types)

        # Check validations
        invalid = {'ip': '256.168.0.0000', 'hostname': 'foo@bar$', 'port': 'foobar:00'}
        for field, bad_value in invalid.items():
            with self.subTest(**{field: bad_value}), self.assertRaises(ValueError):
                base.validate_uri(bad_value, resolve_host=False)
        with self.assertRaises(ValueError):
            base.validate_uri(' ', resolve_host=True)
        with self.assertRaises(TypeError):
            base.validate_uri(b'localhost')

    def test_external_ip(self):
        """Test that base.external_ip returns a non-private address."""
        address = ipaddress.ip_address(base.external_ip())
        self.assertFalse(address.is_private)

    def test_ExpMessage(self):
        """Test for ExpMessage.validate method."""
        validate = base.ExpMessage.validate
        # An ExpMessage member should be returned as-is
        self.assertIs(validate(base.ExpMessage.EXPINFO), base.ExpMessage.EXPINFO)
        # Integer input
        as_int = int(base.ExpMessage.EXPCLEANUP)
        self.assertIs(validate(as_int), base.ExpMessage.EXPCLEANUP)
        # String input (note the leading space and lower case)
        self.assertIs(validate(' expstatus'), base.ExpMessage.EXPSTATUS)
        # Unsupported type and unknown name
        with self.assertRaises(TypeError):
            validate(b'EXPSTART')
        with self.assertRaises(ValueError):
            validate('EXPSTOP')
        # Test allow_bitwise kwarg
        combined = base.ExpMessage.any()
        self.assertIs(validate(combined), combined)
        with self.assertRaises(ValueError):
            validate(combined, allow_bitwise=False)

    def test_encode(self):
        """Tests for iblutil.io.net.base.Communicator.encode"""
        expected = b'[null, 21, "message"]'
        encoded = base.Communicator.encode([None, 21, 'message'])
        self.assertEqual(encoded, expected)
        # Encoding the already-encoded bytes is expected to give the same bytes
        self.assertEqual(base.Communicator.encode(encoded), expected)

    def test_decode(self):
        """Tests for iblutil.io.net.base.Communicator.decode"""
        payload = b'[null, 21, "message"]'
        self.assertEqual(base.Communicator.decode(payload), [None, 21, 'message'])
        # A malformed payload should warn and come back as a plain string
        malformed = payload + b'"'
        with self.assertWarns(Warning):
            result = base.Communicator.decode(malformed)
        self.assertEqual(result, '[null, 21, "message"]"')

    async def test_is_success(self):
        """Tests for iblutil.io.net.base.is_success function."""
        loop = asyncio.get_running_loop()
        # Expect True only when set_result has been called.
        resolved = loop.create_future()
        self.assertFalse(base.is_success(resolved))
        resolved.set_result(None)
        self.assertTrue(base.is_success(resolved))
        # Cancelled future
        cancelled = loop.create_future()
        cancelled.cancel()
        self.assertFalse(base.is_success(cancelled))
        # Future holding an exception
        failed = loop.create_future()
        failed.set_exception(RuntimeError)
        self.assertFalse(base.is_success(failed))
@unittest.skipIf(ver < Version('3.9'), 'only version 3.9 or later supported')
class TestUDP(unittest.IsolatedAsyncioTestCase):
    """Test net.app.EchoProtocol client/server pair over datagram sockets."""

    async def asyncSetUp(self):
        """Create a connected server and client on localhost."""
        self.server = await app.EchoProtocol.server('localhost', name='server')
        self.client = await app.EchoProtocol.client('localhost', name='client')

    async def test_start(self):
        """Tests confirmed send via start command."""
        # Check socket type
        for communicator in (self.server, self.client):
            self.assertIs(communicator._socket.type, app.socket.SOCK_DGRAM)
        spy = mock.MagicMock()
        self.server.assign_callback('expstart', spy)
        exp_ref = '2022-01-01_1_subject'
        with self.assertLogs(self.server.logger, logging.INFO) as log:
            await self.client.start(exp_ref)
            expected_log = f'Received \'[{base.ExpMessage.EXPSTART}, "{exp_ref}", null]'
            self.assertIn(expected_log, log.records[-1].message)
        client_addr = self.client._socket.getsockname()
        spy.assert_called_with([exp_ref, None], client_addr)

    async def test_callback_error(self):
        """Tests behaviour when callback raises exception."""
        client_addr = self.client._socket.getsockname()
        failing = mock.MagicMock()
        failing.side_effect = ValueError('Callback failed')
        self.server.assign_callback('expinit', failing)
        error_task = asyncio.create_task(self.client.on_event(0))
        with self.assertLogs(self.server.logger, logging.ERROR) as server_log:
            base.Communicator._receive(self.server, b'[1, null]', client_addr)
            self.assertEqual(1, len(server_log.records))
            self.assertIn('Callback failed', server_log.records[-1].message)
        # Check error propagated back to client
        with self.assertLogs(self.client.logger, logging.ERROR) as client_log:
            (err, evt), _ = await error_task
            self.assertEqual(1, len(client_log.records))
            self.assertIn('Callback failed', client_log.records[-1].message)
        self.assertEqual(base.ExpMessage.EXPINIT.value, evt)
        self.assertIn('Callback failed', err)
        # Check behaviour when future already done
        done_fut = asyncio.get_running_loop().create_future()
        self.server.assign_callback('EXPSTART', done_fut)
        done_fut.set_result(True)
        task = asyncio.create_task(self.client.on_event(2))
        with self.assertLogs(self.server.logger, logging.WARNING) as server_log:
            base.Communicator._receive(self.server, b'[2, null]', client_addr)
            self.assertEqual(1, len(server_log.records))
            self.assertRegex(server_log.records[-1].getMessage(), 'Future .+ already resolved')

    async def test_on_event(self):
        """Test on_event method as well as init, stop, etc."""

        async def exchange(event, send, *args, **kwargs):
            """Listen for `event` on the server, call the client method, return the received data."""
            listener = asyncio.create_task(self.server.on_event(event))
            await send(*args, **kwargs)
            data, _ = await listener
            return data

        # INIT
        self.assertEqual([42], await exchange('expinit', self.client.init, 42))
        # CLEANUP
        self.assertEqual([8], await exchange(base.ExpMessage.EXPCLEANUP, self.client.cleanup, 8))
        # STOP
        self.assertEqual(['foo'], await exchange('EXPEND', self.client.stop, 'foo'))
        # INTERRUPT
        received = await exchange('expinterrupt', self.client.stop, 'foo', immediately=True)
        self.assertEqual(['foo'], received)
        # START (experiment reference given as a string)
        received = await exchange('expstart', self.client.start, '2020-01-01_1_baz', {'foo': 'bar'})
        self.assertEqual(['2020-01-01_1_baz', {'foo': 'bar'}], received)
        # START (experiment reference given as a dict)
        ref = {'subject': 'baz', 'date': date(2020, 1, 1), 'sequence': 1}
        received = await exchange('expstart', self.client.start, ref, {'foo': 'bar'})
        self.assertEqual(['2020-01-01_1_baz', {'foo': 'bar'}], received)
        # STATUS (enumeration, then string)
        received = await exchange('EXPSTATUS', self.client.status, base.ExpStatus.STOPPED)
        self.assertEqual([base.ExpStatus.STOPPED.value], received)
        received = await exchange('EXPSTATUS', self.client.status, 'CONNECTED')
        self.assertEqual([base.ExpStatus.CONNECTED.value], received)
        # INFO
        received = await exchange('expinfo', self.client.info, base.ExpStatus.RUNNING, {'foo': 'bar'})
        self.assertEqual([base.ExpStatus.RUNNING.value, {'foo': 'bar'}], received)

    async def test_alyx(self):
        """Test iblutil.io.net.app.EchoProtocol.alyx method."""
        # Mock an AlyxClient instance that is logged in
        alyx = mock.MagicMock()
        alyx.is_logged_in = True
        alyx.base_url = 'https://alyx.website.net'
        alyx.user = 'foo'
        alyx._token = {'token': '4157aa522b855239cd05f4d23d40563aa0518359'}

        # When Alyx instance is logged in, expect the token to be broadcast
        listener = asyncio.create_task(self.server.on_event('ALYX'))
        returned = await self.client.alyx(alyx)
        self.assertIsNone(returned, 'unexpected argument returned')
        actual, _ = await listener  # wait for server process request
        expected = ['https://alyx.website.net', {'foo': {'token': '4157aa522b855239cd05f4d23d40563aa0518359'}}]
        self.assertEqual(expected, actual, 'client failed to send alyx token to server')

        async def serve_token(expected_request):
            """Await a token request on the server, check its payload, then reply with the logged-in mock."""
            data, addr = await self.server.on_event('ALYX')
            self.assertEqual(expected_request, data)
            await self.server.alyx(alyx, addr)

        # Client should request and return a token when Alyx instance is not logged in
        # Mock an AlyxClient instance that is not logged in
        alyx_logged_out = mock.MagicMock()
        alyx_logged_out.is_logged_in = False
        alyx_logged_out.base_url = 'https://alyx.website.net'
        alyx_logged_out.user = alyx_logged_out._token = None
        responder = asyncio.create_task(serve_token(['https://alyx.website.net', {}]))
        token = await self.client.alyx(alyx_logged_out)
        await responder  # wait for server process request
        self.assertEqual(expected, token, 'failed to return requested token from server')

        # Try the same thing but with None instead of an Alyx instance
        responder = asyncio.create_task(serve_token([None, {}]))
        token = await self.client.alyx(None)
        await responder  # wait for server process request
        self.assertEqual(expected, token)

    async def test_confirmed_send_validation(self):
        """Basic tests for iblutil.io.net.app.EchoProtocol.confirmed_send exception handling."""
        wait_for_target = 'iblutil.io.net.app.asyncio.wait_for'
        # Expect to raise in server role when no address provided
        with self.assertRaises(TypeError):
            await self.server.confirmed_send(None)
        # Expect to raise in client role when provided address does not match server URI
        wrong_addr = ('localhost', self.client.port + 100)
        with self.assertRaises(ValueError):
            await self.client.confirmed_send(None, addr=wrong_addr)
        # Expect to raise when echo timeout is 0
        try:
            self.client.default_echo_timeout = 0
            with self.assertRaises(ValueError):
                await self.client.confirmed_send(None)
        finally:
            # Put back the class default whatever the outcome
            self.client.default_echo_timeout = app.EchoProtocol.default_echo_timeout
        # Expect timeout arg to override default echo timeout, expect error raised on timeout
        assert self.client.is_connected
        with mock.patch(wait_for_target, side_effect=asyncio.TimeoutError) as m:
            with self.assertRaises(TimeoutError):
                await self.client.confirmed_send(None, timeout=0.2)
        self.assertFalse(self.client.is_connected, 'failed to close communicator on echo timeout error')
        m.assert_awaited_once_with(mock.ANY, timeout=0.2)
        # Expect to raise RuntimeError with explanation when messages don't match
        with mock.patch(wait_for_target, side_effect=RuntimeError):
            with self.assertRaises(RuntimeError) as cm:
                await self.client.confirmed_send(None)
        self.assertIn('unexpected response', str(cm.exception).lower())

    def test_communicator(self):
        """Basic tests for iblutil.io.net.app.EchoProtocol, namely the role setter."""
        # Check role validation
        for communicator, role in ((self.server, 'server'), (self.client, 'client')):
            self.assertEqual(communicator.role, role)
        with self.assertRaises(ValueError):
            app.EchoProtocol('localhost', 'foo')
        with self.assertRaises(AttributeError):
            self.client.role = 'foo'

    async def test_receive_validation(self):
        """Test for behaviour when non-standard message received."""
        addr = (self.server.hostname, self.server.port)
        # A message that isn't valid should raise a warning (send is mocked out)
        with self.assertWarns(RuntimeWarning):
            with mock.patch.object(self.client, 'send'):
                self.client._receive(b'foo', addr)
        # A reply that differs from the last sent message should set an exception on the future
        fut = asyncio.get_running_loop().create_future()
        self.client._last_sent[addr] = (b'foo', fut)
        with self.assertLogs(self.client.name, logging.ERROR):
            self.client._receive(b'bar', addr)
        self.assertIsInstance(fut.exception(), RuntimeError)
        # Upon receiving message from unknown host, should log warning and return
        unknown_host = ('192.168.0.0', self.server.port)
        with self.assertLogs(self.client.name, logging.WARNING):
            with mock.patch.object(self.client, '_receive') as receive_mock:
                self.client.datagram_received(b'foo', unknown_host)
                receive_mock.assert_not_called()

    def test_connection_made_validation(self):
        """Test for connection_made method"""
        transport = mock.MagicMock()
        # First with an unconfigured mock transport...
        with self.assertRaises(RuntimeError):
            self.client.connection_made(transport)
        # ...then with a transport reporting a stream socket
        transport.get_extra_info().type = socket.SOCK_STREAM
        with self.assertRaises(RuntimeError):
            self.client.connection_made(transport)

    async def test_awaiting_response(self):
        """Test for awaiting_response method."""
        server_addr = (self.server.hostname, self.server.port)
        self.assertFalse(self.client.awaiting_response())
        pending = asyncio.get_running_loop().create_future()
        self.client._last_sent[server_addr] = (b'foo', pending)
        self.assertTrue(self.client.awaiting_response())
        self.assertFalse(self.client.awaiting_response(addr=('localhost', 8080)))
        pending.cancel()
        self.assertFalse(self.client.awaiting_response())

    async def test_close(self):
        """Test for close/cleanup routine."""
        client = self.client
        self.assertTrue(client.is_connected)
        loop = asyncio.get_running_loop()
        # One future registered as an event callback, another as a pending echo
        event_fut = loop.create_future()
        client.assign_callback('EXPCLEANUP', event_fut)
        echo_fut = loop.create_future()
        server_addr = (self.server.hostname, self.server.port)
        client._last_sent[server_addr] = (None, echo_fut)

        client.close()

        self.assertFalse(client.is_connected)
        self.assertTrue(event_fut.cancelled())
        self.assertTrue(echo_fut.cancelled())
        self.assertFalse(any(client._callbacks.values()))
        self.assertTrue(client._transport.is_closing())
        self.assertEqual('Close called on communicator', await client.on_connection_lost)
        self.assertTrue(client.on_eof_received.cancelled())
        self.assertTrue(client.on_error_received.cancelled())
        # self.assertEqual(-1, self.client._socket.fileno())  # Closed later on in loop

    async def test_on_error_received(self):
        """Test for on_error_received callback."""
        error = ValueError('foo')
        with self.assertLogs(self.client.name, logging.ERROR):
            self.client.error_received(error)
        error_fut = self.client.on_error_received
        self.assertTrue(error_fut.done())
        self.assertEqual(error, error_fut.result())

    def tearDown(self):
        """Close the client, then the server."""
        self.client.close()
        self.server.close()
@unittest.skipIf(ver < Version('3.9'), 'only version 3.9 or later supported')
class TestWebSockets(unittest.IsolatedAsyncioTestCase):
    """Test net.app.EchoProtocol with a TCP/IP transport layer"""

    async def asyncSetUp(self):
        """Create a connected server and client using the ws protocol on port 8888."""
        uri = 'ws://localhost:8888'
        self.server = await app.EchoProtocol.server(uri, name='server')
        self.client = await app.EchoProtocol.client(uri, name='client')

    async def test_start(self):
        """Tests confirmed send via start command."""
        # Check socket indeed TCP
        for communicator in (self.server, self.client):
            self.assertIs(communicator._socket.type, app.socket.SOCK_STREAM)
        spy = mock.MagicMock()
        self.server.assign_callback('expstart', spy)
        exp_ref = '2022-01-01_1_subject'
        with self.assertLogs(self.server.logger, logging.INFO) as log:
            await self.client.start(exp_ref)
            expected_log = f'Received \'[{base.ExpMessage.EXPSTART}, "{exp_ref}", null]'
            self.assertIn(expected_log, log.records[-1].message)
        client_addr = self.client._socket.getsockname()
        spy.assert_called_with([exp_ref, None], client_addr)

    def test_send_validation(self):
        """Test for Communicator.send method."""
        payload = b'foo'
        with mock.patch.object(self.client, '_transport') as transport:
            write = transport.write
            # Without an address the message should be written to the transport
            self.client.send(payload)
            write.assert_called_with(payload)
            write.reset_mock()
            # Check returns when external address used
            with self.assertLogs(self.client.name, logging.WARNING):
                self.client.send(payload, addr=('192.168.0.0', 0))
            write.assert_not_called()

    def test_connection_made_validation(self):
        """Test for connection_made method"""
        transport = mock.MagicMock()
        # A datagram socket should be rejected by the stream-based client
        transport.get_extra_info().type = socket.SOCK_DGRAM
        with self.assertRaises(RuntimeError):
            self.client.connection_made(transport)

    def tearDown(self):
        """Close the client, then the server."""
        self.client.close()
        self.server.close()
@unittest.skipIf(ver < Version('3.9'), 'only version 3.9 or later supported')
class TestServices(unittest.IsolatedAsyncioTestCase):
    """Tests for the app.Services class"""

    async def asyncSetUp(self):
        """Create two servers and three clients on localhost."""
        # On each acquisition PC
        self.server_1 = await app.EchoProtocol.server('localhost', name='server')
        # On main experiment PC
        self.client_1 = await app.EchoProtocol.client('localhost', name='client1')
        self.client_2 = await app.EchoProtocol.client('localhost', name='client2')
        # For some tests we'll need multiple servers (avoids having to run on multiple threads)
        self.server_2 = await app.EchoProtocol.server('localhost:10002', name='server2')
        self.client_3 = await app.EchoProtocol.client('localhost:10002', name='client3')

    async def test_type(self):
        """Test that services are immutable"""
        services = app.Services([self.client_1, self.client_2])
        # Ensure our services stack is immutable
        with self.assertRaises(TypeError):
            services['client2'] = app.EchoProtocol
        with self.assertRaises(TypeError):
            services.pop('client1')
        # Ensure inputs are validated
        with self.assertRaises(TypeError):
            app.Services([self.client_1, None])

    async def test_close(self):
        """Test Services.close method"""
        clients = [self.client_1, self.client_2]
        assert all(x.is_connected for x in clients)
        services = app.Services(clients)
        self.assertTrue(services.is_connected)
        services.close()
        self.assertFalse(services.is_connected)
        self.assertFalse(any(x.is_connected for x in clients))

    async def test_assign(self):
        """Tests for Services.assign_callback and Services.clear_callbacks"""
        # Assign a callback for an event
        callback = mock.MagicMock(spec_set=True)
        clients = (self.client_1, self.client_2)
        services = app.Services(clients)
        services.assign_callback('EXPINIT', callback)
        for client in clients:
            await self.server_1.init('foo', addr=client._socket.getsockname())
        self.assertEqual(2, callback.call_count)
        callback.assert_called_with(['foo'], ('127.0.0.1', 11001))
        # Check return_service arg
        callback2 = mock.MagicMock(spec_set=True)
        services.assign_callback('EXPINIT', callback2, return_service=True)
        for client in clients:
            await self.server_1.init('foo', addr=client._socket.getsockname())
        self.assertEqual(2, callback2.call_count)
        callback2.assert_called_with(['foo'], ('127.0.0.1', 11001), self.client_2)
        # Check validation
        with self.assertRaises(TypeError):
            services.assign_callback('EXPEND', 'foo')
        # Check clear callbacks
        services.assign_callback('EXPINIT', callback2)
        removed = services.clear_callbacks('EXPINIT', callback)
        self.assertEqual({'client1': 1, 'client2': 1}, removed)
        removed = services.clear_callbacks('EXPINIT')
        self.assertEqual({'client1': 2, 'client2': 2}, removed)
        # Check futures cancelled
        fut = asyncio.get_running_loop().create_future()
        services.assign_callback('EXPINIT', fut)
        assert not fut.cancelled()
        services.clear_callbacks('EXPINIT')
        self.assertTrue(fut.cancelled())

    async def test_init(self):
        """Test init of services.

        Unfortunately this test is convoluted due to the client and server being on the same
        machine.
        """
        clients = (self.client_1, self.client_3)
        # Require two servers as we'll need two callbacks
        servers = (self.server_1, self.server_2)

        # Set up the client response callbacks that the server (Services object) will await.
        async def respond(server, fut):
            """Response callback for the server"""
            data, addr = await fut
            await asyncio.sleep(.1)  # FIXME Should be able to somehow use loop.call_soon
            await server.init(42, addr)

        def schedule_responses():
            """Schedule one response task per server and return the tasks."""
            return [
                asyncio.create_task(respond(server, server.on_event(base.ExpMessage.EXPINIT)))
                for server in servers
            ]

        def check_responses(responses):
            """Assert every client's response is the resolved value [42], not a future."""
            self.assertFalse(any(map(asyncio.isfuture, responses.values())))
            for name, value in responses.items():
                with self.subTest(client=name):
                    self.assertEqual([42], value)

        # The event loop only keeps weak references to tasks, so hold on to the response tasks
        # for the duration of the test to ensure they can't be garbage collected mid-execution.
        response_tasks = schedule_responses()
        # Create the services and initialize them, awaiting the callbacks we just set up
        services = app.Services(clients)
        check_responses(await services.init('foo'))

        # Add back the callbacks to test sequential init
        response_tasks += schedule_responses()
        # Initialize services sequentially, awaiting the callbacks we just set up
        check_responses(await services.init('foo', concurrent=False))

    async def test_service_methods(self):
        """Test start, stop, etc. methods.

        For a more complete test, see test_init.
        """
        clients = [mock.AsyncMock(spec=app.EchoProtocol), mock.AsyncMock(spec=app.EchoProtocol)]
        services = app.Services(clients)
        # Init
        await services.init([42, 'foo'])
        for client in clients:
            client.init.assert_awaited_once_with(data=[42, 'foo'])
        # Start
        ref = '2020-01-01_1_subject'
        await services.start(ref)
        for client in clients:
            client.start.assert_awaited_once_with(ref, data=None)
        # Info
        await services.info(base.ExpStatus.RUNNING, {'exp_ref': ref})
        for client in clients:
            client.info.assert_awaited_once_with(base.ExpStatus.RUNNING, data={'exp_ref': ref})
        # Status
        await services.status(base.ExpStatus.STOPPED)
        for client in clients:
            client.status.assert_awaited_once_with(base.ExpStatus.STOPPED)
        # Stop
        await services.stop(immediately=True)
        for client in clients:
            client.stop.assert_awaited_once_with(data=None, immediately=True)
        # Cleanup
        await services.cleanup(data=[42, 'foo'])
        for client in clients:
            client.cleanup.assert_awaited_once_with(data=[42, 'foo'])
        # Alyx
        alyx = mock.MagicMock()
        await services.alyx(alyx)
        for client in clients:
            client.alyx.assert_awaited_once_with(alyx)

    async def test_sequential_signal(self):
        """Test for Services._signal method with concurrent=False"""
        clients = [mock.AsyncMock(spec=app.EchoProtocol), mock.AsyncMock(spec=app.EchoProtocol)]
        for i, client in enumerate(clients):
            client.name = f'client_{i}'
            # Each mocked client reports its own index as the event data
            client.on_event.return_value = ([i], (self.client_1.hostname, self.client_1.port))
        services = app.Services(clients)
        responses = await services._signal(base.ExpMessage.EXPINIT, 'init', 'foo', concurrent=False)
        for client in clients:
            client.init.assert_awaited_once()
        self.assertEqual(responses, {'client_0': [0], 'client_1': [1]})

    def tearDown(self):
        """Close every client and server created in asyncSetUp."""
        self.client_1.close()
        self.client_2.close()
        self.server_1.close()
        self.server_2.close()
        self.client_3.close()
if __name__ == '__main__':
    # When run as a script, set up debug-level logging for the app module before running the tests
    from iblutil.util import setup_logger as _setup_logger

    _setup_logger(app.__name__, level=logging.DEBUG)
    unittest.main()