# Source code for iblutil.util

import uuid
from itertools import takewhile
from os import scandir
from pathlib import Path
import collections
import colorlog
import copy
import logging
import sys
from typing import Union, Iterable, Sequence

import numpy as np

# BUG FIX: the original called logging.getLogger('__name__') with the name
# quoted, creating a logger literally named "__name__" rather than one tied to
# this module's position in the package hierarchy.
log = logging.getLogger(__name__)

# Shared format/date strings used by the handlers created below.
LOG_FORMAT_STR = u'%(asctime)s %(levelname)-8s %(filename)s:%(lineno)-4d %(message)s'
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
# Colour per level name, consumed by colorlog.ColoredFormatter in setup_logger.
LOG_COLORS = {
    'DEBUG': 'green',
    'INFO': 'cyan',
    'WARNING': 'bold_yellow',
    'ERROR': 'bold_red',
    'CRITICAL': 'bold_purple'}


def Listable(t):
    """Return a type alias accepting either a single `t` or a sequence of `t`."""
    return Union[t, Sequence[t]]
[docs] class Bunch(dict): """A subclass of dictionary with an additional dot syntax.""" def __init__(self, *args, **kwargs): super(Bunch, self).__init__(*args, **kwargs) self.__dict__ = self
[docs] def copy(self, deep=False): """Return a new Bunch instance which is a copy of the current Bunch instance. Parameters ---------- deep : bool If True perform a deep copy (see notes). By default a shallow copy is returned. Returns ------- Bunch A new copy of the Bunch. Notes ----- - A shallow copy constructs a new Bunch object and then (to the extent possible) inserts references into it to the objects found in the original. - A deep copy constructs a new Bunch and then, recursively, inserts copies into it of the objects found in the original. """ return copy.deepcopy(self) if deep else Bunch(super(Bunch, self).copy())
[docs] def save(self, npz_file, compress=False): """ Saves a npz file containing the arrays of the bunch. :param npz_file: output file :param compress: bool (False) use compression :return: None """ if compress: np.savez_compressed(npz_file, **self) else: np.savez(npz_file, **self)
[docs] @staticmethod def load(npz_file): """ Loads a npz file containing the arrays of the bunch. :param npz_file: output file :return: Bunch """ if not Path(npz_file).exists(): raise FileNotFoundError(f"{npz_file}") return Bunch(np.load(npz_file))
def _iflatten(x): result = [] for el in x: if isinstance(el, collections.abc.Iterable) and not ( isinstance(el, str) or isinstance(el, dict)): result.extend(_iflatten(el)) else: result.append(el) return result def _gflatten(x): def iselement(e): return not (isinstance(e, collections.abc.Iterable) and not ( isinstance(el, str) or isinstance(el, dict))) for el in x: if iselement(el): yield el else: yield from _gflatten(el)
def flatten(x, generator=False):
    """
    Flatten a nested Iterable excluding strings and dicts.

    Converts nested Iterable into flat list. Will not iterate through strings or
    dicts.

    :return: Flattened list or generator object.
    :rtype: list or generator
    """
    if generator:
        return _gflatten(x)
    return _iflatten(x)
def range_str(values: Iterable) -> str:
    """
    Given a list of integers, returns a terse string expressing the unique values.

    Example:
        indices = [0, 1, 2, 3, 4, 7, 8, 11, 15, 20]
        range_str(indices)
        >> '0-4, 7-8, 11, 15 & 20'

    :param values: An iterable of ints
    :return: A string of unique value ranges
    """
    # BUG FIX: the original relied on `list(set(values))` coming out in
    # ascending order, which is a CPython small-int accident rather than a
    # guarantee; sort explicitly so arbitrary ints always produce ordered runs.
    values = sorted(set(values))
    if not values:
        return ''
    # Collapse consecutive integers into (start, end) runs.
    runs = []
    start = prev = values[0]
    for v in values[1:]:
        if v - prev != 1:
            runs.append((start, prev))
            start = v
        prev = v
    runs.append((start, prev))
    parts = [str(a) if a == b else f'{a}-{b}' for a, b in runs]
    if len(parts) == 1:
        return parts[0]
    # The final separator is an ampersand rather than a comma.
    return f"{', '.join(parts[:-1])} & {parts[-1]}"
def setup_logger(name='ibl', level=logging.NOTSET, file=None, no_color=False):
    """Set up a log for IBL packages.

    Uses date time, calling function and distinct colours for levels.
    Sets the name if not set already and add a stream handler.
    If the stream handler already exists, does not duplicate.
    The naming/level allows not to interfere with third-party libraries when setting level.

    Parameters
    ----------
    name : str
        Log name, should be set to the root package name for consistent logging throughout the app.
    level : str, int
        The logging level (defaults to NOTSET, which inherits the parent log level)
    file : bool, str, pathlib.Path
        If True, a file handler is added with the default file location, otherwise a log file path
        may be passed.
    no_color : bool
        If true the colour log is deactivated. May be useful when directing the std out to a file.

    Returns
    -------
    logging.Logger, logging.RootLogger
        The configured log.
    """
    log = logging.getLogger() if not name else logging.getLogger(name)
    log.setLevel(level)
    fkwargs = {'no_color': True} if no_color else {'log_colors': LOG_COLORS}
    # check existence of stream handlers before adding another
    if not any(map(lambda x: x.name == f'{name}_auto', log.handlers)):
        # need to remove any previous default Stream handler configured on stderr
        # to not duplicate output.
        # BUG FIX: iterate over a *copy* of log.handlers — removeHandler mutates
        # the list and looping over the live list can skip a handler.
        for h in list(log.handlers):
            if isinstance(h, logging.StreamHandler) and h.stream.name == '<stderr>' and h.level == 0 and h.name is None:
                log.removeHandler(h)
        stream_handler = logging.StreamHandler(stream=sys.stdout)
        stream_handler.setFormatter(
            colorlog.ColoredFormatter('%(log_color)s' + LOG_FORMAT_STR, LOG_DATE_FORMAT, **fkwargs))
        stream_handler.name = f'{name}_auto'
        log.addHandler(stream_handler)
    # add the file handler if requested, but check for duplicates
    if not any(map(lambda x: x.name == f'{name}_file', log.handlers)):
        if file is True:
            log_to_file(log=name)
        elif file:
            log_to_file(filename=file, log=name)
    return log
def log_to_file(log='ibl', filename=None):
    """
    Save log information to a given filename in '.ibl_logs' folder (in home directory).

    Parameters
    ----------
    log : str, logging.Logger
        The log (name or object) to add file handler to.
    filename : str, Pathlib.Path
        The name of the log file to save to.

    Returns
    -------
    logging.Logger
        The log with the file handler attached.
    """
    if isinstance(log, str):
        log = logging.getLogger(log)
    if filename is None:
        filename = Path.home().joinpath('.ibl_logs', log.name)
    elif not Path(filename).is_absolute():
        filename = Path.home().joinpath('.ibl_logs', filename)
    # BUG FIX: an absolute path passed as a str previously stayed a str and
    # `filename.parent` below raised AttributeError; always coerce to Path.
    filename = Path(filename)
    # parents=True so an absolute path into a not-yet-existing tree also works
    filename.parent.mkdir(parents=True, exist_ok=True)
    file_handler = logging.FileHandler(filename, encoding='utf-8')
    file_format = logging.Formatter(LOG_FORMAT_STR, LOG_DATE_FORMAT)
    file_handler.setFormatter(file_format)
    file_handler.name = f'{log.name}_file'
    log.addHandler(file_handler)
    log.info(f'File log initiated {file_handler.name}')
    return log
def rrmdir(folder: Path, levels: int = 0):
    """
    Recursively remove a folder and its parents up to a defined level - if they are empty.

    Parameters
    ----------
    folder : pathlib.Path
        The path to a folder at which to start the recursion.
    levels : int
        Recursion level, i.e. the number of parents to delete, relative to `folder`.
        Defaults to 0 - which has the same effect as `pathlib.Path.rmdir` except that it won't
        raise an OSError if the directory is not empty.

    Returns
    -------
    list of pathlib.Path
        A list of folders that were recursively removed.

    Raises
    ------
    FileNotFoundError
        If `folder` does not exist.
    PermissionError
        Insufficient privileges or folder in use by another process.
    NotADirectoryError
        The folder provided is most likely a file.
    """
    try:
        # deepest path first, then its parents up to `levels` (py >= 3.9)
        candidates = (folder, *folder.parents[:levels])
    except TypeError:
        # pathlib parents do not support slicing on py <= 3.8
        candidates = (folder, *[folder.parents[n] for n in range(levels)])
    removed = []
    for current in candidates:
        # stop climbing at the first directory that still has contents
        if any(current.iterdir()):
            break
        current.rmdir()
        removed.append(current)
    return removed
def dir_size(directory: Union[str, Path], follow_symlinks: bool = False) -> int:
    """
    Calculate the total size of a directory including all its subdirectories and files.

    Parameters
    ----------
    directory : str | Path
        The path to the directory for which the size needs to be calculated.
    follow_symlinks : bool, optional
        Whether to follow symbolic links when calculating the size. Default is False.

    Returns
    -------
    int
        The total size of the directory in bytes.
    """
    # Iterative traversal with an explicit stack of directories still to visit.
    total_bytes = 0
    pending = [directory]
    while pending:
        with scandir(pending.pop()) as entries:
            for entry in entries:
                if entry.is_symlink() and not follow_symlinks:
                    continue  # skip links unless explicitly requested
                if entry.is_dir():
                    pending.append(entry.path)
                elif entry.is_file():
                    total_bytes += entry.stat().st_size
    return total_bytes
def get_mac() -> str:
    """
    Fetch the machine's unique MAC address formatted according to IEEE 802 specifications.

    Returns
    -------
    str
        The MAC address of the device formatted in six groups of two hexadecimal digits
        separated by hyphens in transmission order (e.g., 'BA-DB-AD-C0-FF-EE').
    """
    # getnode() yields a 48-bit integer; render each octet as upper-case hex.
    octets = uuid.getnode().to_bytes(6, 'big')
    return '-'.join(format(octet, '02X') for octet in octets)
def ensure_list(value, exclude_type=(str, dict)):
    """Ensure input is a list.

    Wraps `value` in a list if not already an iterator or if it is a member of specific iterable
    classes.

    To allow users the option of passing a single value or multiple values, this function will
    wrap the former in a list and by default will consider str and dict instances as a single
    value. This function is useful because it treats tuples, lists, sets, and generators all as
    'lists', but not dictionaries and strings.

    Parameters
    ----------
    value : any
        Input to ensure list.
    exclude_type : tuple, optional
        A list of iterable classes to wrap in a list.

    Returns
    -------
    Iterable
        Either `value` if iterable and not in `exclude_type` list, or `value` wrapped in a list.
    """
    # Members of excluded classes are treated as scalars even though iterable.
    if isinstance(value, exclude_type):
        return [value]
    # Non-iterables are scalars; anything else passes through untouched.
    if not isinstance(value, Iterable):
        return [value]
    return value