view piecrust/configuration.py @ 1167:c0c00dc1eac7

core: Remove deprecation warning about collections.abc.
author Ludovic Chabant <ludovic@chabant.com>
date Fri, 04 Oct 2019 10:07:38 -0700
parents 7e51d14097cb
children
line wrap: on
line source

import re
import logging
import collections
import collections.abc
import yaml
from yaml.constructor import ConstructorError
try:
    from yaml import CSafeLoader as SafeLoader
except ImportError:
    from yaml import SafeLoader


logger = logging.getLogger(__name__)

default_allowed_types = (dict, list, tuple, float, int, bool, str)


MERGE_NEW_VALUES = 0
MERGE_OVERWRITE_VALUES = 1
MERGE_PREPEND_LISTS = 2
MERGE_APPEND_LISTS = 4
MERGE_ALL = MERGE_OVERWRITE_VALUES | MERGE_PREPEND_LISTS


class ConfigurationError(Exception):
    pass


class Configuration(collections.abc.MutableMapping):
    def __init__(self, values=None, validate=True):
        if values is not None:
            self.setAll(values, validate=validate)
        else:
            self._values = None

    def __getitem__(self, key):
        self._ensureLoaded()
        try:
            return get_dict_value(self._values, key)
        except KeyError:
            raise KeyError("No such item: %s" % key)

    def __setitem__(self, key, value):
        self._ensureLoaded()
        value = self._validateValue(key, value)
        set_dict_value(self._values, key, value)

    def __delitem__(self, key):
        raise NotImplementedError()

    def __iter__(self):
        self._ensureLoaded()
        return iter(self._values)

    def __len__(self):
        self._ensureLoaded()
        return len(self._values)

    def has(self, key):
        return key in self

    def set(self, key, value):
        self[key] = value

    def setAll(self, values, validate=False):
        if validate:
            values = self._validateAll(values)
        self._values = values

    def getAll(self):
        self._ensureLoaded()
        return self._values

    def merge(self, other, mode=MERGE_ALL):
        self._ensureLoaded()

        if isinstance(other, dict):
            other_values = other
        elif isinstance(other, Configuration):
            other_values = other._values
        else:
            raise Exception(
                "Unsupported value type to merge: %s" % type(other))

        merge_dicts(self._values, other_values,
                    mode=mode,
                    validator=self._validateValue)

    def validateTypes(self, allowed_types=default_allowed_types):
        self._validateDictTypesRecursive(self._values, allowed_types)

    def _validateDictTypesRecursive(self, d, allowed_types):
        for k, v in d.items():
            if not isinstance(k, str):
                raise ConfigurationError("Key '%s' is not a string." % k)
            self._validateTypeRecursive(v, allowed_types)

    def _validateListTypesRecursive(self, l, allowed_types):
        for v in l:
            self._validateTypeRecursive(v, allowed_types)

    def _validateTypeRecursive(self, v, allowed_types):
        if v is None:
            return
        if not isinstance(v, allowed_types):
            raise ConfigurationError(
                "Value '%s' is of forbidden type: %s" % (v, type(v)))
        if isinstance(v, dict):
            self._validateDictTypesRecursive(v, allowed_types)
        elif isinstance(v, list):
            self._validateListTypesRecursive(v, allowed_types)

    def _ensureLoaded(self):
        if self._values is None:
            self._load()

    def _load(self):
        self._values = self._validateAll({})

    def _validateAll(self, values):
        return values

    def _validateValue(self, key_path, value):
        return value


def get_dict_value(d, key):
    bits = key.split('/')
    cur = d
    for b in bits:
        cur = cur[b]
    return cur


def get_dict_values(*args):
    for d, key in args:
        try:
            return get_dict_value(d, key)
        except KeyError:
            continue
    raise KeyError()


def try_get_dict_value(d, key, *, default=None):
    try:
        return get_dict_value(d, key)
    except KeyError:
        return default


def try_get_dict_values(*args, default=None):
    for d, key in args:
        try:
            return get_dict_value(d, key)
        except KeyError:
            continue
    return default


def set_dict_value(d, key, value):
    bits = key.split('/')
    bitslen = len(bits)
    cur = d
    for i, b in enumerate(bits):
        if i == bitslen - 1:
            cur[b] = value
        else:
            if b not in cur:
                cur[b] = {}
            cur = cur[b]


def merge_dicts(source, merging, *args,
                validator=None, mode=MERGE_ALL):
    _recurse_merge_dicts(source, merging, None, validator, mode)
    for other in args:
        _recurse_merge_dicts(source, other, None, validator, mode)
    return source


def _recurse_merge_dicts(local_cur, incoming_cur, parent_path,
                         validator, mode):
    for k, v in incoming_cur.items():
        key_path = k
        if parent_path is not None:
            key_path = parent_path + '/' + k

        local_v = local_cur.get(k)
        if local_v is not None:
            if isinstance(v, dict) and isinstance(local_v, dict):
                _recurse_merge_dicts(local_v, v, key_path,
                                     validator, mode)
            elif isinstance(v, list) and isinstance(local_v, list):
                if mode & MERGE_PREPEND_LISTS:
                    local_cur[k] = v + local_v
                elif mode & MERGE_APPEND_LISTS:
                    local_cur[k] = local_v + v
            else:
                if mode & MERGE_OVERWRITE_VALUES:
                    if validator is not None:
                        v = validator(key_path, v)
                    local_cur[k] = v
        else:
            if ((mode & (MERGE_PREPEND_LISTS | MERGE_APPEND_LISTS)) or
                    not isinstance(v, list)):
                if validator is not None:
                    v = validator(key_path, v)
                local_cur[k] = v


def visit_dict(subject, visitor):
    _recurse_visit_dict(subject, None, visitor)


def _recurse_visit_dict(cur, parent_path, visitor):
    for k, v in cur.items():
        key_path = k
        if parent_path is not None:
            key_path = parent_path + '/' + k

        visitor(key_path, v, cur, k)
        if isinstance(v, dict):
            _recurse_visit_dict(v, key_path, visitor)


header_regex = re.compile(
    r'(---\s*\n)(?P<header>(.*\n)*?)^(---\s*\n)', re.MULTILINE)


def parse_config_header(text):
    m = header_regex.match(text)
    if m is not None:
        header = str(m.group('header'))
        config = yaml.load(header, Loader=ConfigurationLoader)
        offset = m.end()
    else:
        config = {}
        offset = 0
    return config, offset


class ConfigurationLoader(SafeLoader):
    """ A YAML loader that loads mappings into ordered dictionaries,
        and supports sexagesimal notations for timestamps.
    """
    def __init__(self, *args, **kwargs):
        super(ConfigurationLoader, self).__init__(*args, **kwargs)

        self.add_constructor('tag:yaml.org,2002:map',
                             type(self).construct_yaml_map)
        self.add_constructor('tag:yaml.org,2002:omap',
                             type(self).construct_yaml_map)
        self.add_constructor('tag:yaml.org,2002:sexagesimal',
                             type(self).construct_yaml_time)

    def construct_yaml_map(self, node):
        data = collections.OrderedDict()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_mapping(self, node, deep=False):
        if not isinstance(node, yaml.MappingNode):
            raise ConstructorError(
                None, None,
                "expected a mapping node, but found %s" % node.id,
                node.start_mark)
        mapping = collections.OrderedDict()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if not isinstance(key, collections.abc.Hashable):
                raise ConstructorError(
                    "while constructing a mapping", node.start_mark,
                    "found unhashable key", key_node.start_mark)
            value = self.construct_object(value_node, deep=deep)
            mapping[key] = value
        return mapping

    time_regexp = re.compile(
        r'''^(?P<hour>[0-9][0-9]?)
                :(?P<minute>[0-9][0-9])
                (:(?P<second>[0-9][0-9])
                (\.(?P<fraction>[0-9]+))?)?$''', re.X)

    def construct_yaml_time(self, node):
        self.construct_scalar(node)
        match = self.time_regexp.match(node.value)
        values = match.groupdict()
        hour = int(values['hour'])
        minute = int(values['minute'])
        second = 0
        if values['second']:
            second = int(values['second'])
        usec = 0
        if values['fraction']:
            usec = float('0.' + values['fraction'])
        return second + minute * 60 + hour * 60 * 60 + usec


ConfigurationLoader.add_implicit_resolver(
    'tag:yaml.org,2002:sexagesimal',
    re.compile(r'''^[0-9][0-9]?:[0-9][0-9]
                    (:[0-9][0-9](\.[0-9]+)?)?$''', re.X),
    list('0123456789'))


# We need to add our `sexagesimal` resolver before the `int` one, which
# already supports sexagesimal notation in YAML 1.1 (but not 1.2). However,
# because we know we pretty much always want it for representing time, we
# need a simple `12:30` to mean 45000, not 750. So that's why we override
# the default behaviour.
for ch in list('0123456789'):
    ch_resolvers = ConfigurationLoader.yaml_implicit_resolvers[ch]
    ch_resolvers.insert(0, ch_resolvers.pop())


class ConfigurationDumper(yaml.SafeDumper):
    def represent_ordered_dict(self, data):
        # Not a typo: we're using `map` and not `omap` because we don't want
        # ugly type tags printed in the generated YAML markup, and because
        # we always load maps into `OrderedDicts` anyway.
        return self.represent_mapping('tag:yaml.org,2002:map', data)


ConfigurationDumper.add_representer(collections.OrderedDict,
                                    ConfigurationDumper.represent_ordered_dict)