Source code for iblutil.io.net.base

import asyncio
import re
import json
import socket
import warnings
import logging
from asyncio import isfuture
from abc import ABC, abstractmethod
from urllib.parse import urlparse
import urllib.request
import ipaddress
from operator import or_
from functools import reduce
from enum import IntFlag, IntEnum, auto  # py3.11 STRICT

# Default port on which services listen for commands.
LISTEN_PORT = 11001
# Shared version string for the ExpMessage / ExpStatus enumerations and the Service base class.
PROTOCOL_VERSION = '1.0.0'


def is_success(future: asyncio.Future) -> bool:
    """Return True only if *future* has finished with a result.

    A pending, cancelled or failed (exception-set) future counts as unsuccessful.
    """
    # Guard order matters: exception() raises on pending or cancelled futures.
    if not future.done() or future.cancelled():
        return False
    return future.exception() is None
def external_ip():
    """
    Fetch WAN IP address.

    NB: Requires internet.

    Returns
    -------
    ipaddress.IPv4Address, ipaddress.IPv6Address
        The computer's default WAN IP address.

    Raises
    ------
    urllib.error.URLError
        The lookup service could not be reached.
    ValueError
        The service responded with something that is not an IP address.
    """
    # Use the response as a context manager so the underlying connection is always closed
    with urllib.request.urlopen('https://ident.me') as response:
        body = response.read().decode('utf8')
    # Strip surrounding whitespace (e.g. a trailing newline) which ip_address would reject
    return ipaddress.ip_address(body.strip())
def is_valid_ip(ip_address) -> bool:
    """
    Report whether a value parses as an IPv4 or IPv6 address.

    Parameters
    ----------
    ip_address : str
        The candidate IP address.

    Returns
    -------
    bool
        True if `ipaddress.ip_address` accepts the value, otherwise False.
    """
    try:
        ipaddress.ip_address(ip_address)
    except ValueError:
        return False
    return True
def hostname2ip(hostname=None):
    """
    Look up the IP address of a host.

    Parameters
    ----------
    hostname : str, optional
        The host to look up; defaults to this computer's own hostname.

    Returns
    -------
    ipaddress.IPv4Address, ipaddress.IPv6Address
        The address returned by `socket.gethostbyname`.

    Raises
    ------
    ValueError
        The name lookup failed.
    """
    target = hostname or socket.gethostname()
    try:
        resolved = socket.gethostbyname(target)
    except (socket.error, socket.gaierror):
        raise ValueError(f'Failed to resolve IP for hostname "{target}"')
    return ipaddress.ip_address(resolved)
def validate_uri(uri, resolve_host=True, default_port=LISTEN_PORT, default_proc='udp'):
    """
    Ensure URI is complete and correct.

    Parameters
    ----------
    uri : str, ipaddress.IPv4Address, ipaddress.IPv6Address
        A full URI, hostname or hostname and port. IPv6 literals may be given bare
        (e.g. '::1') or bracketed with an optional port (e.g. '[::1]:11001').
    resolve_host : bool
        If the URI is not an IP address, attempt to resolve hostname to IP.
    default_port : int, str
        If the port is absent from the URI, append this one.
    default_proc : str
        If the URI scheme is missing, prepend this one.

    Returns
    -------
    str
        The complete URI. IPv6 hosts are enclosed in square brackets so that the result
        can be parsed by `urllib.parse.urlparse`.

    Raises
    ------
    TypeError
        URI type not supported.
    ValueError
        Failed to resolve host name to IP address.
        URI host contains invalid characters (expects only lowercase alphanumeric + hyphen).
        Port number not within range (must be >= 1, <= 65535).
    """
    if not isinstance(uri, (str, ipaddress.IPv4Address, ipaddress.IPv6Address)):
        raise TypeError(f'Unsupported URI "{uri}" of type {type(uri)}')

    # Validate URI scheme
    if isinstance(uri, str) and (match := re.match(r'(?P<proc>^[a-zA-Z]+(?=://))', uri)):
        proc = match.group()
        uri = uri[len(proc) + 3:]
    else:
        proc = default_proc

    # Split host and port
    if isinstance(uri, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
        host, port = str(uri), default_port
    elif is_valid_ip(uri):
        # A bare IP; checked before splitting on ':' because IPv6 addresses contain colons
        host, port = uri, None
    elif (ip_literal := re.fullmatch(r'\[(?P<host>[^\]]+)\](?::(?P<port>.*))?', uri)):
        # Bracketed IPv6 literal with optional port, e.g. '[::1]:11001'
        host, port = ip_literal['host'], ip_literal['port']
    elif ':' in uri:
        host, port = uri.split(':', 1)
    else:
        host, port = uri, None

    # Validate hostname
    if isinstance(uri, str) and not is_valid_ip(host):
        if resolve_host:
            host = hostname2ip(host)
        elif not re.match(r'^[a-z0-9-]+$', host):
            raise ValueError(f'Invalid hostname "{host}"')

    # Validate port. NB: explicit checks rather than `assert`, which is stripped under `python -O`
    try:
        port = int(port or default_port)
    except (TypeError, ValueError):
        raise ValueError(f'Invalid port number: {port or default_port}')
    if not 1 <= port <= 65535:
        # Report the offending port itself (e.g. 0), not the default
        raise ValueError(f'Invalid port number: {port}')

    # IPv6 literals must be enclosed in brackets within a URI (RFC 3986)
    host = str(host)
    if ':' in host:
        host = f'[{host}]'
    return f'{proc}://{host}:{port}'
# class ExpMessage(IntFlag, boundary=STRICT): # py3.11
class ExpMessage(IntFlag):
    """A set of standard experiment messages for communicating between rigs."""

    # NB: `#:` comments document the member that follows (a bare string would attach to the
    # preceding member in Sphinx).
    #: Experiment is initializing.
    EXPINIT = auto()
    #: Experiment has begun.
    EXPSTART = auto()
    #: Experiment has stopped.
    EXPEND = auto()
    #: Experiment cleanup begun.
    EXPCLEANUP = auto()
    #: Experiment interrupted.
    EXPINTERRUPT = auto()
    #: Experiment status.
    EXPSTATUS = auto()
    #: Experiment info, including task protocol start and end.
    EXPINFO = auto()
    #: Alyx token.
    ALYX = auto()

    __version__ = PROTOCOL_VERSION

    @classmethod
    def any(cls) -> 'ExpMessage':
        """Return enumeration comprising all possible messages.

        NB: This doesn't include the null ExpMessage (0), used to indicate API errors.
        """
        return reduce(or_, cls)

    @staticmethod
    def validate(event, allow_bitwise=True):
        """
        Validate an event message, returning a corresponding enumeration if valid and raising an
        exception if not.

        Parameters
        ----------
        event : str, int, ExpMessage
            An event message to validate.
        allow_bitwise : bool
            If false, raise if event is the result of a bitwise operation.

        Returns
        -------
        ExpMessage:
            The corresponding event enumeration.

        Raises
        ------
        TypeError
            event is neither a string, integer nor enumeration.
        ValueError
            event does not correspond to any ExpMessage enumeration, neither in its integer form
            nor its string name (integers with bits outside the defined messages, or negative
            integers, are rejected), or `allow_bitwise` is false and value is combination of events.

        Examples
        --------
        >>> ExpMessage.validate('expstart')
        ExpMessage.EXPSTART

        >>> ExpMessage.validate(1)
        ExpMessage.EXPINIT

        >>> ExpMessage.validate(ExpMessage.EXPEND)
        ExpMessage.EXPEND
        """
        if not isinstance(event, ExpMessage):
            try:
                if isinstance(event, str):
                    event = ExpMessage[event.strip().upper()]
                elif isinstance(event, int):
                    # IntFlag silently accepts unknown bits, so reject them explicitly
                    if event < 0 or event & ~int(ExpMessage.any()):
                        raise ValueError(f'Unrecognized event "{event}". '
                                         f'Choices: {tuple(ExpMessage.__members__.keys())}')
                    event = ExpMessage(event)
                else:
                    raise TypeError(f'Unknown event type {type(event)}')
            except KeyError:
                raise ValueError(f'Unrecognized event "{event}". '
                                 f'Choices: {tuple(ExpMessage.__members__.keys())}')
        if not allow_bitwise and event not in list(ExpMessage):
            raise ValueError('Compound (bitwise) events not permitted. '
                             f'Choices: {tuple(ExpMessage)}')
        return event

    def __iter__(self):  # py3.11 remove this method
        """Iterate over the individual flags set in the enumeration.

        NB: This method mirrors Python 3.11, which supports iteration of Flag objects. Negative
        values (e.g. the result of ``~`` on Python < 3.11) are first masked to the defined flags;
        otherwise the bit-clearing loop below would never terminate.
        """
        cls = type(self)
        num = self.value
        if num < 0:
            num &= int(cls.any())
        while num:
            b = num & (~num + 1)  # lowest set bit
            yield cls(b)
            num ^= b
class ExpStatus(IntEnum):
    """A set of standard statuses for communicating between rigs."""

    # NB: `#:` comments document the member that follows (a bare string would attach to the
    # preceding member in Sphinx).
    #: Service is connected.
    CONNECTED = 0
    #: Service is initialized.
    INITIALIZED = 10
    #: Service is running.
    RUNNING = 20
    #: Experiment has stopped.
    STOPPED = 30

    __version__ = PROTOCOL_VERSION
class Service(ABC):
    """An abstract base class for auxiliary experiment services."""

    __version__ = PROTOCOL_VERSION
    __slots__ = 'name'

    @abstractmethod
    def init(self, data=None):
        """
        Initialize an experiment.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPINIT
            The EXPINIT event.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPINIT, data

    @abstractmethod
    def start(self, exp_ref, data=None):
        """
        Start an experiment.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        exp_ref : str, dict
            An experiment reference string in the form yyyy-mm-dd_n_subject, or a dict with the
            keys 'date', 'sequence' and 'subject' from which such a string is built.
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPSTART
            The EXPSTART event.
        str, None
            The experiment reference string (None if `exp_ref` was empty).
        any, None
            Optional extra data.
        """
        exp_ref = exp_ref or None
        if isinstance(exp_ref, dict):
            exp_ref = '_'.join(map(str, (exp_ref['date'], int(exp_ref['sequence']), exp_ref['subject'])))
        return ExpMessage.EXPSTART, exp_ref, data

    @abstractmethod
    def stop(self, data=None, immediately=False):
        """
        Stop an experiment.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.
        immediately : bool
            If True, an EXPINTERRUPT message is returned.

        Returns
        -------
        ExpMessage.EXPINTERRUPT, ExpMessage.EXPEND
            The EXPEND event, or EXPINTERRUPT if immediately is True.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPINTERRUPT if immediately else ExpMessage.EXPEND, data

    @abstractmethod
    def status(self, status):
        """
        Report experiment status.

        NB: This is intended to be lightweight. For more detail and custom data use the info method.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        status : ExpStatus, int, str
            The experiment status enumeration, its integer value or its name. Names are matched
            case-insensitively, ignoring surrounding whitespace (as in ExpMessage.validate).

        Returns
        -------
        ExpMessage.EXPSTATUS
            The EXPSTATUS event.
        ExpStatus
            The validated experiment status.

        Raises
        ------
        ValueError
            An integer status does not correspond to any ExpStatus value.
        KeyError
            A status name does not correspond to any ExpStatus member.
        """
        if not isinstance(status, ExpStatus):
            if isinstance(status, int):
                status = ExpStatus(status)
            elif isinstance(status, str):
                status = ExpStatus[status.strip().upper()]
            else:
                status = ExpStatus[status]
        return ExpMessage.EXPSTATUS, status

    @abstractmethod
    def info(self, status, data=None):
        """
        Report experiment information.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        status : ExpStatus
            The experiment status enumeration.
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPINFO
            The EXPINFO event.
        ExpStatus
            The validated experiment status.
        any, None
            Optional extra data.
        """
        # Call the base implementation explicitly so subclass overrides of status aren't triggered
        return ExpMessage.EXPINFO, Service.status(self, status)[1], data

    @abstractmethod
    def cleanup(self, data=None):
        """
        Clean up an experiment.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        data : any
            Optional extra data to send to the remote server.

        Returns
        -------
        ExpMessage.EXPCLEANUP
            The EXPCLEANUP event.
        any, None
            Optional extra data.
        """
        return ExpMessage.EXPCLEANUP, data

    @abstractmethod
    def alyx(self, alyx):
        """
        Request/send Alyx token.

        This is intended to specify the expected message signature. The subclassed method should
        serialize the returned values and pass them to the transport layer.

        Parameters
        ----------
        alyx : one.webclient.AlyxClient
            Optional instance of Alyx to send.

        Returns
        -------
        ExpMessage.ALYX
            The ALYX event.
        str
            The Alyx database URL (None if no instance given).
        dict
            The Alyx token in the form {user: token} (empty if not logged in).
        """
        base_url = alyx.base_url if alyx else None
        token = {alyx.user: alyx._token} if alyx and alyx.is_logged_in else {}
        return ExpMessage.ALYX, base_url, token
class Communicator(Service):
    """A base class for communicating between experimental rigs.

    Attributes
    ----------
    name : str
        An arbitrary label for the remote host
    server_uri : str
        The full URI of the remote device, e.g. udp://192.168.0.1:1001
    """

    # NB: 'name' is already a slot of Service; re-declaring it here would shadow the base slot
    __slots__ = ('server_uri', '_callbacks', 'logger')

    def __init__(self, server_uri, name=None, logger=None):
        self.server_uri = validate_uri(server_uri)
        self.name = name or server_uri
        # str() as server_uri may be an ipaddress object, which getLogger rejects
        self.logger = logger or logging.getLogger(str(self.name))
        # Init callbacks map of ExpMessage -> list, including null ExpMessage for processing callback errors
        self._callbacks = {evt: [] for evt in (ExpMessage(0), *ExpMessage.__members__.values())}

    def assign_callback(self, event, callback):
        """
        Assign a callback to be called when an event occurs.

        NB: Unlike with futures, an assigned callback may be triggered multiple times, whereas
        futures may only be set once after which they are cleared.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event for which the callback is registered. If a compound (bitwise) event is
            given, the callback is also passed the event that triggered it.
        callback : function, asyncio.Future
            A function or Future to trigger when an event occurs.

        Raises
        ------
        TypeError
            The callback is neither callable nor a Future.

        See Also
        --------
        EchoProtocol.receive
            The method that processes the callbacks upon receiving a message.
        """
        event = ExpMessage.validate(event)
        if not (callable(callback) or isfuture(callback)):
            raise TypeError('Callback must be callable or a Future')
        if not event:  # the null ExpMessage(0), used for processing remote callback errors
            self._callbacks.setdefault(event, []).append((callback, False))
            return
        # NB: `event in ExpMessage` is True for composite flags on Python >= 3.12, so compare
        # against the canonical single members instead
        return_event = event not in tuple(ExpMessage)
        for e in event:  # iterate over enum as bitwise ops permitted
            self._callbacks.setdefault(e, []).append((callback, return_event))

    def clear_callbacks(self, event, callback=None, cancel_futures=True):
        """
        For a given event, remove the provided callback, or all callbacks if none were provided.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event for which the callback was registered.
        callback : function, asyncio.Future
            The callback or future to remove.
        cancel_futures : bool
            If True and callback is a Future, cancel before removing.

        Returns
        -------
        int
            The number of callbacks removed.
        """
        i = 0
        event = ExpMessage.validate(event)
        # Iterate as bitwise enums permitted, e.g. ~ExpMessage.ALYX. The null ExpMessage(0) has
        # no bits set so it is treated as a single event.
        events = list(event) or [event]
        if callback is not None:  # clear specific callback
            # Wrapped callables have an id attribute containing the hash of the inner function
            cb_id = getattr(callback, 'id', None) or hash(callback)
            for evt in events:
                callbacks = self._callbacks.get(evt, [])
                remaining = []
                for item in callbacks:
                    cb = item[0]
                    if (getattr(cb, 'id', None) or hash(cb)) != cb_id:
                        remaining.append(item)
                        continue
                    if cancel_futures and isfuture(cb):
                        cb.cancel()
                    i += 1
                callbacks[:] = remaining
        else:  # clear all callbacks for event
            for evt in events:
                callbacks = self._callbacks.get(evt, [])
                if cancel_futures:
                    for cb, _ in callbacks:
                        if isfuture(cb):
                            cb.cancel()
                i += len(callbacks)
                callbacks.clear()
        self.logger.debug('[%s] %i callbacks cleared', self.name, i)
        return i

    async def on_event(self, event):
        """
        Await an event from the remote host.

        Parameters
        ----------
        event : str, int, iblutil.io.net.base.ExpMessage
            The event to wait on.

        Returns
        -------
        any
            The response data returned by the remote host.

        Examples
        --------
        >>> data, addr = await com.on_event('EXPSTART')

        >>> task = asyncio.create_task(com.on_event('EXPSTART'))
        >>> ...
        >>> data, addr = await task

        Await more than one event

        >>> data, addr, event = await com.on_event(ExpMessage.EXPEND | ExpMessage.EXPINTERRUPT)
        """
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self.assign_callback(event, fut)
        try:
            return await fut
        finally:
            # Deregister from every event the future was assigned to: for compound events only the
            # triggering event's list is otherwise cleaned, and a cancelled await would leak it
            self._discard_future(fut)

    @property
    def port(self) -> int:
        """int: the remote port"""
        return int(urlparse(self.server_uri).port)

    @property
    def hostname(self) -> str:
        """str: the remote hostname or IP address"""
        return urlparse(self.server_uri).hostname

    @property
    def protocol(self) -> str:
        """str: the protocol scheme, e.g. udp, ws"""
        return urlparse(self.server_uri).scheme

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """bool: True if the remote device is connected"""
        pass

    @abstractmethod
    def send(self, data, addr=None):
        """Serialize and pass data to the transport layer"""
        pass

    def _discard_future(self, fut):
        """Remove a future from the callback list of every event it is registered for."""
        for callbacks in self._callbacks.values():
            callbacks[:] = [item for item in callbacks if item[0] is not fut]

    def _receive(self, data, addr):
        """
        Process data received from remote host and notify event listeners.

        This is called by the transport layer when a message is received and should not be called
        by the user.

        Parameters
        ----------
        data : bytes
            The serialized data received by the transport layer.
        addr : (str, int)
            The source address as (hostname, port)

        Warns
        -----
        RuntimeWarning
            Expects the deserialized data to be a tuple where the first element is an ExpMessage.
            A warning is thrown if the data is not a tuple, or has fewer than two elements.

        TODO Perhaps for every event only the first future should be set.
        """
        data = self.decode(data)
        if not (isinstance(data, (list, tuple)) and len(data) > 1):
            warnings.warn(f'Expected list, got {data}', RuntimeWarning)
            return
        event, *data = data
        event = ExpMessage.validate(event, allow_bitwise=False) if event else ExpMessage(0)
        if not event:  # An error in the remote callback function occurred
            err, evt = data
            self.logger.error('Callback for %s on %s://%s:%i failed with %s',
                              ExpMessage(evt).name, self.protocol, *addr, err)
        for f, return_event in self._callbacks[event].copy():
            if isfuture(f):
                # Check cancelled first: a cancelled future is also done, but is not an error
                if f.cancelled():
                    pass
                elif f.done():
                    self.logger.warning('Future %s already resolved', f)
                else:
                    f.set_result((data, addr, event) if return_event else (data, addr))
                self._discard_future(f)  # Futures are single-use; remove from all events
            else:
                try:
                    f(data, addr, event) if return_event else f(data, addr)
                except Exception as ex:
                    self.logger.error('Callback "%s" failed with error "%s"', f, ex)
                    # Notify the remote host via the null event; remaining callbacks are skipped
                    message = self.encode([0, f'{type(ex).__name__}: {ex}', event])
                    self.send(message, addr)
                    break

    @staticmethod
    def encode(data) -> bytes:
        """
        Serialize data for transmission.

        None-string or -bytes objects are encoded as JSON before converting to bytes.

        Parameters
        ----------
        data : any
            The data to serialize.

        Returns
        -------
        bytes
            The encoded data.
        """
        if isinstance(data, bytes):
            return data
        if not isinstance(data, str):
            data = json.dumps(data)
        return data.encode()

    @staticmethod
    def decode(data: bytes):
        """
        De-serialize and parse bytes data.

        This function attempts to decode the data as a JSON object. On failing that, returns a
        string.

        Parameters
        ----------
        data : bytes, str
            The data to de-serialize.

        Returns
        -------
        any
            Deserialized data.
        """
        try:
            return json.loads(data)
        except json.JSONDecodeError:
            warnings.warn('Failed to decode as JSON')
            # str input (accepted by json.loads) has no decode method
            return data.decode() if isinstance(data, (bytes, bytearray)) else data

    def close(self):
        """De-register all callbacks and cancel futures"""
        for callbacks in self._callbacks.values():
            for cb, _ in callbacks:
                if isfuture(cb):
                    cb.cancel('Close called on communicator')
            callbacks.clear()