Module brevettiai.model.factory.mobilenetv2_backbone

Expand source code
from functools import partial
from typing import List

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2
from brevettiai.model.factory import ModelFactory


def remap_backbone(bn_momentum, default_regularizer, exchange_padding_on):
    def _remap_backbone(layer):
        if isinstance(layer, tf.keras.Model):
            return tf.keras.models.clone_model(layer, input_tensors=None, clone_function=_remap_backbone)
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.momentum = bn_momentum
        elif isinstance(layer, layers.Conv2D):
            layer.kernel_regularizer = default_regularizer
            if layer.name in exchange_padding_on:
                # Exchange possible asymmetric zero padding on selected layers with
                # explicit symmetric padding followed by a "valid" convolution
                layer.padding = "valid"
                return tf.keras.Sequential([
                    layers.ZeroPadding2D(1),
                    # layers.Lambda(zero_pad_to_even, name="zero_pad_to_even"),
                    layer.__class__.from_config(layer.get_config())
                ])
        return layer.__class__.from_config(layer.get_config())

    return _remap_backbone


class MobileNetV2SegmentationBackbone(ModelFactory):
    output_layers: List[str]
    weights: str = 'imagenet'
    alpha: float = 1
    bn_momentum: float = 0.9
    l1_regularization: float = 0
    l2_regularization: float = 0

    @staticmethod
    def custom_objects():
        return {
            "relu6": tf.nn.relu6
        }

    def build(self, input_shape, *args, **kwargs):
        if self.weights is not None and input_shape[-1] != 3:
            bb_source = MobileNetV2(input_shape=(*input_shape[:-1], 3),
                                    include_top=False, weights=self.weights, alpha=self.alpha)
            backbone = MobileNetV2(input_shape=input_shape, include_top=False,
                                   weights=None, alpha=self.alpha)

            # Exchange layer 1 weights: pretrained weights exist only for 3-channel input,
            # so collapse the RGB kernel over its channel axis, tile it to the requested
            # channel count and multiply by small random noise to break symmetry
            w = bb_source.get_weights()
            w[0] = w[0].sum(axis=2, keepdims=True)
            w[0] = np.tile(w[0], (1, 1, input_shape[-1], 1)) * np.random.randn(*w[0].shape) * 0.05
            backbone.set_weights(w)

        else:
            backbone = MobileNetV2(input_shape=input_shape,
                                   include_top=False,
                                   weights=self.weights,
                                   alpha=self.alpha)

        backbone = tf.keras.Model(backbone.input, [backbone.get_layer(l).output for l in self.output_layers],
                                  name=f"MobilenetV2_a{self.alpha}")

        if self.l1_regularization != 0 or self.l2_regularization != 0:
            default_regularizer = tf.keras.regularizers.l1_l2(l1=self.l1_regularization, l2=self.l2_regularization)
        else:
            default_regularizer = None

        map_backbone = remap_backbone(bn_momentum=self.bn_momentum,
                                      default_regularizer=default_regularizer,
                                      exchange_padding_on={"Conv1"})

        backbone_clone = tf.keras.models.clone_model(backbone, clone_function=map_backbone)
        backbone_clone.set_weights(backbone.get_weights())
        return backbone_clone


lightning_segmentation_backbone = partial(
    MobileNetV2SegmentationBackbone,
    output_layers=['block_2_add', 'block_5_add', 'block_9_add'],
    alpha=0.35)

thunder_segmentation_backbone = partial(
    MobileNetV2SegmentationBackbone,
    output_layers=['expanded_conv_project', 'block_2_add', 'block_5_add', 'block_9_add', 'block_15_add'],
    alpha=0.35)
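
A minimal usage sketch of the preconfigured backbones above; the input shape is an illustrative assumption, while the output layers and alpha come from the lightning_segmentation_backbone partial:

from brevettiai.model.factory.mobilenetv2_backbone import lightning_segmentation_backbone

# Build the "lightning" backbone for a 3-channel input (pretrained weights load directly)
factory = lightning_segmentation_backbone()
backbone = factory.build(input_shape=(224, 224, 3))

# One output tensor per entry in output_layers
print([t.shape for t in backbone.outputs])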

Functions

def remap_backbone(bn_momentum, default_regularizer, exchange_padding_on)
Expand source code
def remap_backbone(bn_momentum, default_regularizer, exchange_padding_on):
    def _remap_backbone(layer):
        if isinstance(layer, tf.keras.Model):
            return tf.keras.models.clone_model(layer, input_tensors=None, clone_function=_remap_backbone)
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.momentum = bn_momentum
        elif isinstance(layer, layers.Conv2D):
            layer.kernel_regularizer = default_regularizer
            if layer.name in exchange_padding_on:
                # Exchange possible asymmetric zero padding on selected layers with
                # explicit symmetric padding followed by a "valid" convolution
                layer.padding = "valid"
                return tf.keras.Sequential([
                    layers.ZeroPadding2D(1),
                    # layers.Lambda(zero_pad_to_even, name="zero_pad_to_even"),
                    layer.__class__.from_config(layer.get_config())
                ])
        return layer.__class__.from_config(layer.get_config())

    return _remap_backbone
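
As a rough sketch of how this clone function is applied (mirroring the call in MobileNetV2SegmentationBackbone.build); the regularizer strength and input shape here are illustrative assumptions:

import tensorflow as tf
from brevettiai.model.factory.mobilenetv2_backbone import remap_backbone

base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                         include_top=False, weights=None)

clone_fn = remap_backbone(bn_momentum=0.9,
                          default_regularizer=tf.keras.regularizers.l2(1e-4),
                          exchange_padding_on={"Conv1"})

# Clone the model, remapping BatchNormalization momentum, Conv2D regularizers
# and the padding of the "Conv1" layer, then copy the original weights across
clone = tf.keras.models.clone_model(base, clone_function=clone_fn)
clone.set_weights(base.get_weights())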

Classes

class MobileNetV2SegmentationBackbone (**data: Any)

Abstract model factory class

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

Expand source code
class MobileNetV2SegmentationBackbone(ModelFactory):
    output_layers: List[str]
    weights: str = 'imagenet'
    alpha: float = 1
    bn_momentum: float = 0.9
    l1_regularization: float = 0
    l2_regularization: float = 0

    @staticmethod
    def custom_objects():
        return {
            "relu6": tf.nn.relu6
        }

    def build(self, input_shape, *args, **kwargs):
        if self.weights is not None and input_shape[-1] != 3:
            bb_source = MobileNetV2(input_shape=(*input_shape[:-1], 3),
                                    include_top=False, weights=self.weights, alpha=self.alpha)
            backbone = MobileNetV2(input_shape=input_shape, include_top=False,
                                   weights=None, alpha=self.alpha)

            # Exchange layer 1 weights: pretrained weights exist only for 3-channel input,
            # so collapse the RGB kernel over its channel axis, tile it to the requested
            # channel count and multiply by small random noise to break symmetry
            w = bb_source.get_weights()
            w[0] = w[0].sum(axis=2, keepdims=True)
            w[0] = np.tile(w[0], (1, 1, input_shape[-1], 1)) * np.random.randn(*w[0].shape) * 0.05
            backbone.set_weights(w)

        else:
            backbone = MobileNetV2(input_shape=input_shape,
                                   include_top=False,
                                   weights=self.weights,
                                   alpha=self.alpha)

        backbone = tf.keras.Model(backbone.input, [backbone.get_layer(l).output for l in self.output_layers],
                                  name=f"MobilenetV2_a{self.alpha}")

        if self.l1_regularization != 0 or self.l2_regularization != 0:
            default_regularizer = tf.keras.regularizers.l1_l2(l1=self.l1_regularization, l2=self.l2_regularization)
        else:
            default_regularizer = None

        map_backbone = remap_backbone(bn_momentum=self.bn_momentum,
                                      default_regularizer=default_regularizer,
                                      exchange_padding_on={"Conv1"})

        backbone_clone = tf.keras.models.clone_model(backbone, clone_function=map_backbone)
        backbone_clone.set_weights(backbone.get_weights())
        return backbone_clone

Ancestors

  • ModelFactory
  • abc.ABC
  • pydantic.main.BaseModel
  • pydantic.utils.Representation

Class variables

var alpha : float
var bn_momentum : float
var l1_regularization : float
var l2_regularization : float
var output_layers : List[str]
var weights : str

Inherited members
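
A minimal usage sketch based on the build() source above; the configuration values are illustrative, and the single-channel input shape is chosen to exercise the first-layer weight adaptation:

from brevettiai.model.factory.mobilenetv2_backbone import MobileNetV2SegmentationBackbone

# Illustrative configuration; layer names follow the Keras MobileNetV2 naming scheme
factory = MobileNetV2SegmentationBackbone(
    output_layers=["block_2_add", "block_5_add", "block_9_add"],
    alpha=0.35,
    l2_regularization=1e-4)

# input_shape[-1] != 3 triggers the first-layer weight adaptation in build()
backbone = factory.build(input_shape=(224, 224, 1))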