import asyncio
import re
import json
import socket
import warnings
import logging
from asyncio import isfuture
from abc import ABC, abstractmethod
from urllib.parse import urlparse
import urllib.request
import ipaddress
from enum import IntEnum
# Port on which to listen for commands; also the default port filled in by `validate_uri`
# when a URI does not specify one.
LISTEN_PORT = 10001
def external_ip():
    """
    Fetch WAN IP address.

    NB: Requires internet.

    Returns
    -------
    ipaddress.IPv4Address, ipaddress.IPv6Address
        The computer's default WAN IP address.

    Raises
    ------
    urllib.error.URLError
        The lookup service could not be reached.
    ValueError
        The service response could not be parsed as an IP address.
    """
    # Use a context manager so the HTTP response (and its underlying socket) is always closed
    with urllib.request.urlopen('https://ident.me') as response:
        address = response.read().decode('utf8')
    # Surrounding whitespace (e.g. a trailing newline) would make ip_address reject the value
    return ipaddress.ip_address(address.strip())
def is_valid_ip(ip_address) -> bool:
    """
    Check whether a value can be parsed as an IPv4 or IPv6 address.

    Parameters
    ----------
    ip_address : str
        The candidate IP address.

    Returns
    -------
    bool
        True if `ipaddress.ip_address` accepts the value, False if it raises a ValueError.
    """
    try:
        ipaddress.ip_address(ip_address)
    except ValueError:
        return False
    else:
        return True
def hostname2ip(hostname=None):
    """
    Resolve hostname to IP address.

    Parameters
    ----------
    hostname : str, optional
        The hostname to resolve. If None (or empty), resolves this computer's hostname.

    Returns
    -------
    ipaddress.IPv4Address, ipaddress.IPv6Address
        The resolved IP address.

    Raises
    ------
    ValueError
        Failed to resolve IP for hostname. The underlying socket error is attached as the
        exception's cause.
    """
    hostname = hostname or socket.gethostname()
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError as ex:
        # socket.error is an alias of OSError and socket.gaierror a subclass of it, so this
        # catches exactly what the (socket.error, socket.gaierror) tuple did
        raise ValueError(f'Failed to resolve IP for hostname "{hostname}"') from ex
    return ipaddress.ip_address(ip_address)
def validate_uri(uri, resolve_host=True, default_port=LISTEN_PORT, default_proc='udp'):
    """
    Ensure URI is complete and correct.

    Parameters
    ----------
    uri : str, ipaddress.IPv4Address, ipaddress.IPv6Address
        A full URI, hostname or hostname and port.
    resolve_host : bool
        If the URI is not an IP address, attempt to resolve hostname to IP.
    default_port : int, str
        If the port is absent from the URI, append this one.
    default_proc : str
        If the URI scheme is missing, prepend this one.

    Returns
    -------
    str
        The complete URI in the form <scheme>://<host>:<port>.

    Raises
    ------
    TypeError
        URI type not supported.
    ValueError
        Failed to resolve host name to IP address.
        URI host contains invalid characters (expects only lower-case alphanumeric + hyphen);
        only checked when `resolve_host` is False.
        Port number is not an integer or not within range (must be >= 1, <= 65535).

    Notes
    -----
    A string URI is split into host and port on its first colon (after the scheme has been
    removed), so bare IPv6 address strings are not parsed as a host here.
    """
    if not isinstance(uri, (str, ipaddress.IPv4Address, ipaddress.IPv6Address)):
        raise TypeError(f'Unsupported URI "{uri}" of type {type(uri)}')
    # Validate URI scheme: strip a leading '<scheme>://' if present, otherwise use the default
    if isinstance(uri, str) and (proc := re.match(r'(?P<proc>^[a-zA-Z]+(?=://))', uri)):
        proc = proc.group()
        uri = uri[len(proc) + 3:]  # + 3 skips the '://' separator
    else:
        proc = default_proc
    # Validate hostname
    if isinstance(uri, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
        host = str(uri)
        port = default_port
    elif ':' in uri:
        host, port = uri.split(':', 1)
    else:
        host = uri
        port = None
    if isinstance(uri, str) and not is_valid_ip(host):
        if resolve_host:
            host = hostname2ip(host)
        elif not re.match(r'^[a-z0-9-]+$', host):
            raise ValueError(f'Invalid hostname "{host}"')
    # Validate port. Remember the requested value before conversion so that the error message
    # reports what was actually asked for (e.g. a port of 0 is falsy after conversion).
    requested_port = port or default_port
    try:
        port = int(requested_port)
    except ValueError as ex:
        raise ValueError(f'Invalid port number: {requested_port}') from ex
    # Explicit range check instead of `assert`, which is stripped when running with -O
    if not 1 <= port <= 65535:
        raise ValueError(f'Invalid port number: {requested_port}')
    return f'{proc or default_proc}://{host}:{port}'
class ExpMessage(IntEnum):
    """A set of standard experiment messages for communicating between rigs."""

    #: Experiment is initializing.
    EXPINIT = 10
    #: Experiment has begun.
    EXPSTART = 20
    #: Experiment has stopped.
    EXPEND = 30
    #: Experiment cleanup begun.
    EXPCLEANUP = 40
    #: Experiment interrupted.
    EXPINTERRUPT = 50
    #: Experiment status.
    EXPSTATUS = 1
    #: Experiment info, including task protocol start and end.
    EXPINFO = 2
    #: Alyx token.
    ALYX = 3

    @staticmethod
    def validate(event):
        """
        Convert an event message to its ExpMessage enumeration, raising if there is no match.

        Parameters
        ----------
        event : str, int, ExpMessage
            The event to validate. Strings are matched against member names after stripping
            surrounding whitespace and upper-casing; integers are matched against member values.

        Returns
        -------
        ExpMessage
            The corresponding event enumeration.

        Raises
        ------
        TypeError
            event is neither a string, integer nor enumeration.
        ValueError
            event does not correspond to any ExpMessage enumeration, neither in its integer
            form nor its string name.

        Examples
        --------
        >>> ExpMessage.validate('expstart')
        <ExpMessage.EXPSTART: 20>
        >>> ExpMessage.validate(10)
        <ExpMessage.EXPINIT: 10>
        >>> ExpMessage.validate(ExpMessage.EXPEND)
        <ExpMessage.EXPEND: 30>
        """
        # Already an enumeration: nothing to do
        if isinstance(event, ExpMessage):
            return event
        # Look up by (case-insensitive) member name
        if isinstance(event, str):
            try:
                return ExpMessage[event.strip().upper()]
            except KeyError:
                choices = tuple(member.name for member in ExpMessage)
                raise ValueError(f'Unrecognized event "{event}". '
                                 f'Choices: {choices}')
        # Look up by value; the enum constructor raises ValueError for unknown integers
        if isinstance(event, int):
            return ExpMessage(event)
        raise TypeError(f'Unknown event type {type(event)}')
class Service(ABC):
    """An abstract base class for auxiliary experiment services.

    Each abstract method has a default body that returns the message tuple for that action,
    thereby defining the expected message signature. Subclasses must override every method
    and are expected to serialize the tuple and hand it to their transport layer.
    """

    __slots__ = 'name'

    @abstractmethod
    def init(self, data=None):
        """
        Initialize an experiment.

        Defines the message signature only; the overriding method should serialize the
        returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPINIT
            The EXPINIT event.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPINIT, data

    @abstractmethod
    def start(self, exp_ref, data=None):
        """
        Start an experiment.

        Defines the message signature only; the overriding method should serialize the
        returned values and pass them to the transport layer.

        Parameters
        ----------
        exp_ref : str
            An experiment reference string in the form yyyy-mm-dd_n_subject.
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPSTART
            The EXPSTART event.
        str
            The experiment reference string.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPSTART, exp_ref, data

    @abstractmethod
    def stop(self, data=None, immediately=False):
        """
        Stop an experiment.

        Defines the message signature only; the overriding method should serialize the
        returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.
        immediately : bool
            If True, an EXPINTERRUPT message is returned instead of EXPEND.

        Returns
        -------
        ExpMessage.EXPINTERRUPT, ExpMessage.EXPEND
            The EXPEND event, or EXPINTERRUPT if immediately is True.
        any, None
            Optional extra data.
        """
        event = ExpMessage.EXPINTERRUPT if immediately else ExpMessage.EXPEND
        return event, data

    @abstractmethod
    def cleanup(self, data=None):
        """
        Clean up an experiment.

        Defines the message signature only; the overriding method should serialize the
        returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPCLEANUP
            The EXPCLEANUP event.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPCLEANUP, data

    @abstractmethod
    def alyx(self, alyx):
        """
        Request/send Alyx token.

        Defines the message signature only; the overriding method should serialize the
        returned values and pass them to the transport layer.

        Parameters
        ----------
        alyx : one.webclient.AlyxClient
            The Alyx client whose `base_url` and `_par.TOKEN` are sent.

        Returns
        -------
        ExpMessage.ALYX
            The ALYX event.
        str
            The Alyx database URL.
        dict
            The Alyx token in the form {user: token}.
        """
        return ExpMessage.ALYX, alyx.base_url, alyx._par.TOKEN
class Communicator(Service):
    """A base class for communicating between experimental rigs.

    Attributes
    ----------
    name : str
        An arbitrary label for the remote host
    server_uri : str
        The full URI of the remote device, e.g. udp://192.168.0.1:1001
    logger : logging.Logger
        The logger used for reporting; defaults to a logger named after `name`.
    """

    __slots__ = ('server_uri', '_callbacks', 'logger')

    def __init__(self, server_uri, name=None, logger=None):
        self.server_uri = validate_uri(server_uri)
        # validate_uri accepts ipaddress objects but logging.getLogger requires a str name,
        # so the fallback label is always converted to a string
        self.name = name or str(server_uri)
        self.logger = logger or logging.getLogger(self.name)
        # One (initially empty) list of callbacks per known event
        self._callbacks = {event: [] for event in ExpMessage}

    def assign_callback(self, event, callback):
        """
        Assign a callback to be called when an event occurs.

        NB: An assigned callable may be triggered multiple times, whereas a future is set
        at most once, after which it is removed from the list of callbacks.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event for which the callback is registered.
        callback : function, asyncio.Future
            A function or Future to trigger when an event occurs.

        Raises
        ------
        TypeError
            The callback is neither callable nor a Future.

        See Also
        --------
        Communicator._receive
            The method that processes the callbacks upon receiving a message.
        """
        event = ExpMessage.validate(event)
        if not (callable(callback) or isfuture(callback)):
            raise TypeError('Callback must be callable or a Future')
        self._callbacks.setdefault(event, []).append(callback)

    def clear_callbacks(self, event, callback=None):
        """
        For a given event, remove the provided callback, or all callbacks if none were provided.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event for which the callback was registered.
        callback : function, asyncio.Future
            The callback or future to remove. If None, all callbacks for the event are removed.

        Returns
        -------
        int
            The number of callbacks removed.
        """
        event = ExpMessage.validate(event)
        registered = self._callbacks[event]
        # Compare against None rather than truthiness so that a callback object that happens
        # to be falsy does not cause every callback to be cleared
        if callback is not None:  # clear specific callback
            # Wrapped callables have an id attribute containing the hash of the inner function
            cb_id = getattr(callback, 'id', None) or hash(callback)
            keep = [cb for cb in registered if (getattr(cb, 'id', None) or hash(cb)) != cb_id]
            n_removed = len(registered) - len(keep)
            registered[:] = keep  # update the list in place
        else:  # clear all callbacks for event
            n_removed = len(registered)
            self._callbacks[event] = []
        self.logger.debug('[%s] %i callbacks cleared', self.name, n_removed)
        return n_removed

    async def on_event(self, event):
        """
        Await an event from the remote host.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event to wait on.

        Returns
        -------
        tuple
            The response data returned by the remote host and the source address, in the form
            (data, addr).

        Examples
        --------
        >>> data, addr = await com.on_event('EXPSTART')

        >>> event = asyncio.create_task(com.on_event('EXPSTART'))
        >>> ...
        >>> data, addr = await event
        """
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self.assign_callback(event, fut)
        return await fut

    @property
    def port(self) -> int:
        """int: the remote port"""
        return urlparse(self.server_uri).port

    @property
    def hostname(self) -> str:
        """str: the remote hostname or IP address"""
        return urlparse(self.server_uri).hostname

    @property
    def protocol(self) -> str:
        """str: the protocol scheme, e.g. udp, ws"""
        return urlparse(self.server_uri).scheme

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """bool: True if the remote device is connected"""
        pass

    @abstractmethod
    def send(self, data, addr=None):
        """Serialize and pass data to the transport layer"""
        pass

    def _receive(self, data, addr):
        """
        Process data received from remote host and notify event listeners.

        This is called by the transport layer when a message is received and should not be
        called by the user.

        Futures registered for the event are set with (data, addr) and then removed.
        Callables are called with (data, addr); if one raises, the error is logged, an
        EXPINTERRUPT message describing the error is sent back to `addr` and the remaining
        callbacks are skipped.

        Parameters
        ----------
        data : bytes
            The serialized data received by the transport layer.
        addr : (str, int)
            The source address as (hostname, port)

        Warnings
        --------
        Expects the deserialized data to be a list or tuple where the first element is an
        ExpMessage. A RuntimeWarning is issued if the data is not a list or tuple, or has
        fewer than two elements.

        TODO Perhaps for every event only the first future should be set.
        """
        data = self.decode(data)
        if isinstance(data, (list, tuple)) and len(data) > 1:
            event, *data = data
            event = ExpMessage.validate(event)
            # Iterate over a copy as callbacks are removed from the list during the loop
            for f in self._callbacks[event].copy():
                if isfuture(f):
                    # done() is True for cancelled futures and for futures whose result was
                    # already set; calling set_result on those raises InvalidStateError
                    if not f.done():
                        f.set_result((data, addr))
                    self.clear_callbacks(event, f)  # Remove future from list
                else:
                    try:
                        f(data, addr)
                    except Exception as ex:
                        self.logger.error('Callback "%s" failed with error "%s"', f, ex)
                        message = self.encode(
                            [ExpMessage.EXPINTERRUPT, f'{type(ex).__name__}: {ex}'])
                        self.send(message, addr)  # FIXME should be confirmed send
                        break
        else:
            warnings.warn(f'Expected list, got {data}', RuntimeWarning)

    @staticmethod
    def encode(data) -> bytes:
        """
        Serialize data for transmission.

        Non-string and non-bytes objects are encoded as JSON before converting to bytes.

        Parameters
        ----------
        data : any
            The data to serialize.

        Returns
        -------
        bytes
            The encoded data.
        """
        if isinstance(data, bytes):
            return data
        if not isinstance(data, str):
            data = json.dumps(data)
        return data.encode()

    @staticmethod
    def decode(data: bytes):
        """
        De-serialize and parse bytes data.

        This function attempts to decode the data as a JSON object. On failing that, a warning
        is issued and the data are returned as a string.

        Parameters
        ----------
        data : bytes
            The data to de-serialize.

        Returns
        -------
        any
            Deserialized data.
        """
        try:
            data = json.loads(data)
        except json.JSONDecodeError:
            warnings.warn('Failed to decode as JSON')
            # Only bytes need decoding; a str that is not valid JSON is returned unchanged
            # instead of failing on the missing `decode` method
            if isinstance(data, (bytes, bytearray)):
                data = data.decode()
        return data

    def close(self):
        """De-register all callbacks and cancel futures"""
        for event in self._callbacks:
            for fut in filter(isfuture, self._callbacks[event]):
                fut.cancel('Close called on communicator')
            self._callbacks[event] = []