Module brevettiai.utils.module

Expand source code
import inspect
import logging


log = logging.getLogger(__name__)


def get_parameter_type(parameter):
    """
    Infer the type associated with an ``inspect.Parameter``.

    An explicit annotation takes precedence. Without one, the type of the
    parameter's default value is used. A parameter with neither yields
    ``NoneType``.
    """
    annotation = parameter.annotation
    if annotation is not parameter.empty:
        return annotation
    default = parameter.default
    return type(None) if default is parameter.empty else type(default)


class Module:
    """
    Base class for serializable modules.

    ``get_config`` serializes a subclass by reading back, as attributes, the
    arguments accepted by ``__init__``. ``from_config`` rebuilds it by calling
    the constructor with the matching keys of a config mapping.
    ``__get_validators__`` / ``__modify_schema__`` / ``validator`` let the
    class act as a custom field type for a validation framework
    (NOTE(review): these hooks match the pydantic v1 custom-type protocol —
    confirm the version in use).
    """
    def get_config(self):
        """
        Serialize this module into a plain dict.

        Every ``__init__`` parameter that also exists as an attribute on the
        instance is included. Nested ``Module`` values are serialized
        recursively. Values exposing ``numpy()`` are converted with it, and
        values exposing ``tolist()`` are then converted with that.

        Returns:
            dict mapping parameter name to its serialized value.
        """
        # Bound method, so ``self`` is not among the signature's parameters.
        signature = inspect.signature(self.__init__)
        # Only parameters stored under an attribute of the same name can be recovered.
        config = {x: getattr(self, x) for x in signature.parameters.keys() if hasattr(self, x)}

        # Map sub modules and convert array-like values to plain Python data
        for k, v in config.items():
            if isinstance(v, Module):
                config[k] = v.get_config()
            else:
                if hasattr(v, "numpy"):
                    v = v.numpy()
                if hasattr(v, "tolist"):
                    v = v.tolist()
                config[k] = v

        return config

    @classmethod
    def from_config(cls, config):
        """
        Build an instance of ``cls`` from a config mapping.

        Keys matching ``__init__`` parameters are passed to the constructor.
        A parameter's type is its annotation, or else the type of its default.
        When that type is a ``Module`` subclass, the config value is rebuilt
        with that class's ``from_config``. A value that is already a ``Module``
        instance is passed through unchanged.

        If ``__init__`` takes ``**kwargs``, all remaining keys are forwarded as
        well. Any key still unused is dropped and reported with a warning.

        Args:
            config: mapping of parameter name to value, or ``None``.

        Returns:
            A new ``cls`` instance, or ``None`` if ``config`` is ``None``.
        """
        if config is None:
            return None
        valid_config: dict = {}
        signature = inspect.signature(cls.__init__)
        for name, parameter in signature.parameters.items():
            if name in config:
                value = config[name]
                ptype = get_parameter_type(parameter)
                try:
                    is_module_type = issubclass(ptype, Module)
                except TypeError:
                    # Annotations such as Optional[...], List[...] or string
                    # (postponed) annotations are not classes, and issubclass
                    # raises on them; treat the value as a plain value.
                    is_module_type = False
                if is_module_type and not isinstance(value, Module):
                    value = ptype.from_config(value)
                valid_config[name] = value
            if parameter.kind == inspect.Parameter.VAR_KEYWORD:
                # **kwargs accepts every key not already bound to a named parameter.
                for key, item in config.items():
                    valid_config.setdefault(key, item)

        unused = set(config) - set(valid_config)
        if unused:
            log.warning("Invalid config keys: %s", ", ".join(sorted(unused)))
        return cls(**valid_config)

    def copy(self):
        """Return a new instance rebuilt by round-tripping through get_config/from_config."""
        return self.from_config(self.get_config())

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.validator

    @classmethod
    def __modify_schema__(cls, field_schema):
        # __modify_schema__ should mutate the dict it receives in place,
        # the returned value will be ignored
        field_schema.update(
            type=cls.__name__
        )

    @classmethod
    def validator(cls, x):
        """Return ``x`` if it is already a ``cls`` instance, else build one via ``from_config``."""
        if isinstance(x, cls):
            return x
        return cls.from_config(x)

Functions

def get_parameter_type(parameter)
Expand source code
def get_parameter_type(parameter):
    """
    Infer the type associated with an ``inspect.Parameter``.

    An explicit annotation takes precedence. Without one, the type of the
    parameter's default value is used. A parameter with neither yields
    ``NoneType``.
    """
    annotation = parameter.annotation
    if annotation is not parameter.empty:
        return annotation
    default = parameter.default
    return type(None) if default is parameter.empty else type(default)

Classes

class Module

Base class for serializable modules

Expand source code
class Module:
    """
    Base class for serializable modules.

    ``get_config`` serializes a subclass by reading back, as attributes, the
    arguments accepted by ``__init__``. ``from_config`` rebuilds it by calling
    the constructor with the matching keys of a config mapping.
    ``__get_validators__`` / ``__modify_schema__`` / ``validator`` let the
    class act as a custom field type for a validation framework
    (NOTE(review): these hooks match the pydantic v1 custom-type protocol —
    confirm the version in use).
    """
    def get_config(self):
        """
        Serialize this module into a plain dict.

        Every ``__init__`` parameter that also exists as an attribute on the
        instance is included. Nested ``Module`` values are serialized
        recursively. Values exposing ``numpy()`` are converted with it, and
        values exposing ``tolist()`` are then converted with that.

        Returns:
            dict mapping parameter name to its serialized value.
        """
        # Bound method, so ``self`` is not among the signature's parameters.
        signature = inspect.signature(self.__init__)
        # Only parameters stored under an attribute of the same name can be recovered.
        config = {x: getattr(self, x) for x in signature.parameters.keys() if hasattr(self, x)}

        # Map sub modules and convert array-like values to plain Python data
        for k, v in config.items():
            if isinstance(v, Module):
                config[k] = v.get_config()
            else:
                if hasattr(v, "numpy"):
                    v = v.numpy()
                if hasattr(v, "tolist"):
                    v = v.tolist()
                config[k] = v

        return config

    @classmethod
    def from_config(cls, config):
        """
        Build an instance of ``cls`` from a config mapping.

        Keys matching ``__init__`` parameters are passed to the constructor.
        A parameter's type is its annotation, or else the type of its default.
        When that type is a ``Module`` subclass, the config value is rebuilt
        with that class's ``from_config``. A value that is already a ``Module``
        instance is passed through unchanged.

        If ``__init__`` takes ``**kwargs``, all remaining keys are forwarded as
        well. Any key still unused is dropped and reported with a warning.

        Args:
            config: mapping of parameter name to value, or ``None``.

        Returns:
            A new ``cls`` instance, or ``None`` if ``config`` is ``None``.
        """
        if config is None:
            return None
        valid_config: dict = {}
        signature = inspect.signature(cls.__init__)
        for name, parameter in signature.parameters.items():
            if name in config:
                value = config[name]
                ptype = get_parameter_type(parameter)
                try:
                    is_module_type = issubclass(ptype, Module)
                except TypeError:
                    # Annotations such as Optional[...], List[...] or string
                    # (postponed) annotations are not classes, and issubclass
                    # raises on them; treat the value as a plain value.
                    is_module_type = False
                if is_module_type and not isinstance(value, Module):
                    value = ptype.from_config(value)
                valid_config[name] = value
            if parameter.kind == inspect.Parameter.VAR_KEYWORD:
                # **kwargs accepts every key not already bound to a named parameter.
                for key, item in config.items():
                    valid_config.setdefault(key, item)

        unused = set(config) - set(valid_config)
        if unused:
            log.warning("Invalid config keys: %s", ", ".join(sorted(unused)))
        return cls(**valid_config)

    def copy(self):
        """Return a new instance rebuilt by round-tripping through get_config/from_config."""
        return self.from_config(self.get_config())

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.validator

    @classmethod
    def __modify_schema__(cls, field_schema):
        # __modify_schema__ should mutate the dict it receives in place,
        # the returned value will be ignored
        field_schema.update(
            type=cls.__name__
        )

    @classmethod
    def validator(cls, x):
        """Return ``x`` if it is already a ``cls`` instance, else build one via ``from_config``."""
        if isinstance(x, cls):
            return x
        return cls.from_config(x)

Subclasses

Class methods

def from_config(config)
Expand source code
@classmethod
def from_config(cls, config):
    """
    Build an instance of ``cls`` from a config mapping.

    Keys matching ``__init__`` parameters are passed to the constructor.
    A parameter's type is its annotation, or else the type of its default.
    When that type is a ``Module`` subclass, the config value is rebuilt with
    that class's ``from_config``. A value that is already a ``Module``
    instance is passed through unchanged.

    If ``__init__`` takes ``**kwargs``, all remaining keys are forwarded as
    well. Any key still unused is dropped and reported with a warning.

    Args:
        config: mapping of parameter name to value, or ``None``.

    Returns:
        A new ``cls`` instance, or ``None`` if ``config`` is ``None``.
    """
    if config is None:
        return None
    valid_config: dict = {}
    signature = inspect.signature(cls.__init__)
    for name, parameter in signature.parameters.items():
        if name in config:
            value = config[name]
            ptype = get_parameter_type(parameter)
            try:
                is_module_type = issubclass(ptype, Module)
            except TypeError:
                # Annotations such as Optional[...], List[...] or string
                # (postponed) annotations are not classes, and issubclass
                # raises on them; treat the value as a plain value.
                is_module_type = False
            if is_module_type and not isinstance(value, Module):
                value = ptype.from_config(value)
            valid_config[name] = value
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs accepts every key not already bound to a named parameter.
            for key, item in config.items():
                valid_config.setdefault(key, item)

    unused = set(config) - set(valid_config)
    if unused:
        log.warning("Invalid config keys: %s", ", ".join(sorted(unused)))
    return cls(**valid_config)
def validator(x)
Expand source code
@classmethod
def validator(cls, x):
    """
    Coerce ``x`` into an instance of ``cls``.

    An existing instance is returned unchanged. Anything else is treated as a
    config and handed to ``cls.from_config``.
    """
    return x if isinstance(x, cls) else cls.from_config(x)

Methods

def copy(self)
Expand source code
def copy(self):
    """Rebuild this module by serializing it with get_config and feeding the result to from_config."""
    config = self.get_config()
    return self.from_config(config)
def get_config(self)
Expand source code
def get_config(self):
    """
    Serialize the module's constructor arguments into a plain dict.

    Each ``__init__`` parameter that is also readable as an attribute of the
    instance is collected. Nested ``Module`` values are serialized recursively.
    Other values are passed through ``numpy()`` and then ``tolist()`` when they
    expose those methods.
    """
    params = inspect.signature(self.__init__).parameters
    config = {}
    for name in params.keys():
        if hasattr(self, name):
            config[name] = getattr(self, name)

    for name, value in config.items():
        if isinstance(value, Module):
            config[name] = value.get_config()
            continue
        if hasattr(value, "numpy"):
            value = value.numpy()
        if hasattr(value, "tolist"):
            value = value.tolist()
        config[name] = value

    return config