Module brevettiai.utils.module
Expand source code
import inspect
import logging
# Module-level logger; Module.from_config uses it to warn about config keys it drops.
log = logging.getLogger(__name__)
def get_parameter_type(parameter):
    """
    Infer the type of an ``inspect.Parameter``.

    An explicit annotation wins. Otherwise the type of the default value is
    used. A parameter with neither yields ``NoneType``.

    :param parameter: an ``inspect.Parameter`` (e.g. from ``inspect.signature``)
    :return: the annotation object exactly as written (not necessarily a class),
        else ``type(default)``, else ``type(None)``
    """
    missing = parameter.empty
    if parameter.annotation is missing:
        fallback = parameter.default
        return type(None) if fallback is missing else type(fallback)
    return parameter.annotation
class Module:
    """
    Base class for serializable modules.

    ``get_config`` turns an instance into a dict keyed by the names of its
    ``__init__`` parameters. Values are read from same-named instance
    attributes. ``from_config`` rebuilds an instance from such a dict.
    Attributes that are themselves ``Module`` instances are (de)serialized
    recursively.

    The ``__get_validators__`` / ``__modify_schema__`` / ``validator`` hooks
    presumably let a ``Module`` subclass be used as a pydantic (v1
    custom-type protocol) field type -- confirm against callers.
    """

    def get_config(self):
        """
        Serialize this module into a dict of constructor arguments.

        Only ``__init__`` parameters that the instance also exposes as
        attributes of the same name are included. Nested ``Module`` values are
        replaced by their own config. Other values are passed through
        ``.numpy()`` and then ``.tolist()`` when those methods exist
        (presumably framework tensors / numpy arrays).

        :return: dict mapping parameter name to its serialized value
        """
        signature = inspect.signature(self.__init__)
        # Extract parameters that are mirrored as instance attributes
        config = {name: getattr(self, name) for name in signature.parameters if hasattr(self, name)}
        # Map sub modules and array-likes. Only existing keys are reassigned,
        # so updating the dict while iterating over it is safe.
        for name, value in config.items():
            if isinstance(value, Module):
                config[name] = value.get_config()
                continue
            if hasattr(value, "numpy"):
                value = value.numpy()
            if hasattr(value, "tolist"):
                value = value.tolist()
            config[name] = value
        return config

    @classmethod
    def from_config(cls, config):
        """
        Construct an instance of ``cls`` from a config dict.

        Each keyword-passable ``__init__`` parameter present in ``config`` is
        forwarded. When the parameter's type is a ``Module`` subclass, the
        value is rebuilt with that subclass's ``from_config``. The type comes
        from ``get_parameter_type``: the annotation, else the type of the
        default. If ``__init__`` takes ``**kwargs``, all remaining keys are
        forwarded too. Keys that cannot be forwarded are dropped with a
        warning.

        :param config: mapping of parameter name to value, or None
        :return: a new ``cls`` instance, or None when ``config`` is None
        """
        if config is None:
            return None
        valid_config: dict = {}
        parameters = list(inspect.signature(cls.__init__).parameters.values())
        # Parameter 0 of the unbound ``__init__`` is the instance (``self``).
        # It is never taken from the config: ``cls(self=...)`` raises TypeError.
        self_name = parameters[0].name if parameters else None
        for param in parameters[1:]:
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                # ``**kwargs`` absorbs every key not already claimed
                for key, value in config.items():
                    if key != self_name:
                        valid_config.setdefault(key, value)
            elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL):
                # Cannot be supplied by keyword in ``cls(**valid_config)``
                continue
            elif param.name in config:
                value = config[param.name]
                ptype = get_parameter_type(param)
                try:
                    # Annotations need not be classes ("Module", List[int],
                    # list[int], Optional[X]); issubclass() raises TypeError on those.
                    is_module = inspect.isclass(ptype) and issubclass(ptype, Module)
                except TypeError:
                    is_module = False
                valid_config[param.name] = ptype.from_config(value) if is_module else value
        invalid = set(config) - set(valid_config)
        if invalid:
            # Sorted (as str) for a deterministic message even with non-str keys
            log.warning("Invalid config keys: %s", ", ".join(sorted(map(str, invalid))))
        return cls(**valid_config)

    def copy(self):
        """Return a new instance built by round-tripping through ``get_config`` / ``from_config``."""
        return self.from_config(self.get_config())

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.validator

    @classmethod
    def __modify_schema__(cls, field_schema):
        # __modify_schema__ should mutate the dict it receives in place,
        # the returned value will be ignored
        field_schema.update(
            type=cls.__name__
        )

    @classmethod
    def validator(cls, x):
        """Return ``x`` unchanged if it is already a ``cls`` instance, else ``cls.from_config(x)``."""
        if isinstance(x, cls):
            return x
        return cls.from_config(x)
Functions
def get_parameter_type(parameter)
Expand source code
def get_parameter_type(parameter):
    """
    Infer the type of an ``inspect.Parameter``.

    An explicit annotation wins. Otherwise the type of the default value is
    used. A parameter with neither yields ``NoneType``.

    :param parameter: an ``inspect.Parameter`` (e.g. from ``inspect.signature``)
    :return: the annotation object exactly as written (not necessarily a class),
        else ``type(default)``, else ``type(None)``
    """
    missing = parameter.empty
    if parameter.annotation is missing:
        fallback = parameter.default
        return type(None) if fallback is missing else type(fallback)
    return parameter.annotation
Classes
class Module
Base class for serializable modules
Expand source code
class Module:
    """
    Base class for serializable modules.

    ``get_config`` turns an instance into a dict keyed by the names of its
    ``__init__`` parameters. Values are read from same-named instance
    attributes. ``from_config`` rebuilds an instance from such a dict.
    Attributes that are themselves ``Module`` instances are (de)serialized
    recursively.

    The ``__get_validators__`` / ``__modify_schema__`` / ``validator`` hooks
    presumably let a ``Module`` subclass be used as a pydantic (v1
    custom-type protocol) field type -- confirm against callers.
    """

    def get_config(self):
        """
        Serialize this module into a dict of constructor arguments.

        Only ``__init__`` parameters that the instance also exposes as
        attributes of the same name are included. Nested ``Module`` values are
        replaced by their own config. Other values are passed through
        ``.numpy()`` and then ``.tolist()`` when those methods exist
        (presumably framework tensors / numpy arrays).

        :return: dict mapping parameter name to its serialized value
        """
        signature = inspect.signature(self.__init__)
        # Extract parameters that are mirrored as instance attributes
        config = {name: getattr(self, name) for name in signature.parameters if hasattr(self, name)}
        # Map sub modules and array-likes. Only existing keys are reassigned,
        # so updating the dict while iterating over it is safe.
        for name, value in config.items():
            if isinstance(value, Module):
                config[name] = value.get_config()
                continue
            if hasattr(value, "numpy"):
                value = value.numpy()
            if hasattr(value, "tolist"):
                value = value.tolist()
            config[name] = value
        return config

    @classmethod
    def from_config(cls, config):
        """
        Construct an instance of ``cls`` from a config dict.

        Each keyword-passable ``__init__`` parameter present in ``config`` is
        forwarded. When the parameter's type is a ``Module`` subclass, the
        value is rebuilt with that subclass's ``from_config``. The type comes
        from ``get_parameter_type``: the annotation, else the type of the
        default. If ``__init__`` takes ``**kwargs``, all remaining keys are
        forwarded too. Keys that cannot be forwarded are dropped with a
        warning.

        :param config: mapping of parameter name to value, or None
        :return: a new ``cls`` instance, or None when ``config`` is None
        """
        if config is None:
            return None
        valid_config: dict = {}
        parameters = list(inspect.signature(cls.__init__).parameters.values())
        # Parameter 0 of the unbound ``__init__`` is the instance (``self``).
        # It is never taken from the config: ``cls(self=...)`` raises TypeError.
        self_name = parameters[0].name if parameters else None
        for param in parameters[1:]:
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                # ``**kwargs`` absorbs every key not already claimed
                for key, value in config.items():
                    if key != self_name:
                        valid_config.setdefault(key, value)
            elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL):
                # Cannot be supplied by keyword in ``cls(**valid_config)``
                continue
            elif param.name in config:
                value = config[param.name]
                ptype = get_parameter_type(param)
                try:
                    # Annotations need not be classes ("Module", List[int],
                    # list[int], Optional[X]); issubclass() raises TypeError on those.
                    is_module = inspect.isclass(ptype) and issubclass(ptype, Module)
                except TypeError:
                    is_module = False
                valid_config[param.name] = ptype.from_config(value) if is_module else value
        invalid = set(config) - set(valid_config)
        if invalid:
            # Sorted (as str) for a deterministic message even with non-str keys
            log.warning("Invalid config keys: %s", ", ".join(sorted(map(str, invalid))))
        return cls(**valid_config)

    def copy(self):
        """Return a new instance built by round-tripping through ``get_config`` / ``from_config``."""
        return self.from_config(self.get_config())

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.validator

    @classmethod
    def __modify_schema__(cls, field_schema):
        # __modify_schema__ should mutate the dict it receives in place,
        # the returned value will be ignored
        field_schema.update(
            type=cls.__name__
        )

    @classmethod
    def validator(cls, x):
        """Return ``x`` unchanged if it is already a ``cls`` instance, else ``cls.from_config(x)``."""
        if isinstance(x, cls):
            return x
        return cls.from_config(x)
# Subclasses
Class methods
def from_config(config)
Expand source code
@classmethod
def from_config(cls, config):
    """
    Construct an instance of ``cls`` from a config dict.

    Each keyword-passable ``__init__`` parameter present in ``config`` is
    forwarded. When the parameter's type is a ``Module`` subclass, the value
    is rebuilt with that subclass's ``from_config``. The type comes from
    ``get_parameter_type``: the annotation, else the type of the default.
    If ``__init__`` takes ``**kwargs``, all remaining keys are forwarded too.
    Keys that cannot be forwarded are dropped with a warning.

    :param config: mapping of parameter name to value, or None
    :return: a new ``cls`` instance, or None when ``config`` is None
    """
    if config is None:
        return None
    valid_config: dict = {}
    parameters = list(inspect.signature(cls.__init__).parameters.values())
    # Parameter 0 of the unbound ``__init__`` is the instance (``self``).
    # It is never taken from the config: ``cls(self=...)`` raises TypeError.
    self_name = parameters[0].name if parameters else None
    for param in parameters[1:]:
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # ``**kwargs`` absorbs every key not already claimed
            for key, value in config.items():
                if key != self_name:
                    valid_config.setdefault(key, value)
        elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL):
            # Cannot be supplied by keyword in ``cls(**valid_config)``
            continue
        elif param.name in config:
            value = config[param.name]
            ptype = get_parameter_type(param)
            try:
                # Annotations need not be classes ("Module", List[int],
                # list[int], Optional[X]); issubclass() raises TypeError on those.
                is_module = inspect.isclass(ptype) and issubclass(ptype, Module)
            except TypeError:
                is_module = False
            valid_config[param.name] = ptype.from_config(value) if is_module else value
    invalid = set(config) - set(valid_config)
    if invalid:
        # Sorted (as str) for a deterministic message even with non-str keys
        log.warning("Invalid config keys: %s", ", ".join(sorted(map(str, invalid))))
    return cls(**valid_config)
# def validator(x)
Expand source code
@classmethod
def validator(cls, x):
    """Return ``x`` unchanged if it is already a ``cls`` instance, else ``cls.from_config(x)``."""
    return x if isinstance(x, cls) else cls.from_config(x)
Methods
def copy(self)
Expand source code
def copy(self):
    """Return a new instance built by round-tripping through ``get_config`` / ``from_config``."""
    snapshot = self.get_config()
    return self.from_config(snapshot)
# def get_config(self)
Expand source code
def get_config(self):
    """
    Serialize this module into a dict of constructor arguments.

    Only ``__init__`` parameters that the instance also exposes as attributes
    of the same name are included. Nested ``Module`` values are replaced by
    their own config. Other values are passed through ``.numpy()`` and then
    ``.tolist()`` when those methods exist (presumably framework tensors /
    numpy arrays).

    :return: dict mapping parameter name to its serialized value
    """
    signature = inspect.signature(self.__init__)
    # Extract parameters that are mirrored as instance attributes
    config = {name: getattr(self, name) for name in signature.parameters if hasattr(self, name)}
    # Map sub modules and array-likes. Only existing keys are reassigned,
    # so updating the dict while iterating over it is safe.
    for name, value in config.items():
        if isinstance(value, Module):
            config[name] = value.get_config()
            continue
        if hasattr(value, "numpy"):
            value = value.numpy()
        if hasattr(value, "tolist"):
            value = value.tolist()
        config[name] = value
    return config