Module brevettiai.model.losses

Expand source code
from typing import List, Optional, Union

import numpy as np
import tensorflow as tf
from pydantic import BaseModel
from pydantic import Field
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow.python.keras.utils import losses_utils
from typing_extensions import Literal


def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights, zero_channels_weighted=True, **kwargs):
    """Evaluate a Keras base loss with per-channel sample and output weighting.

    Parameters
    ----------
    y_true, y_pred
        Target and prediction tensors; the last axis is treated as the
        channel axis (a trailing singleton axis is added before the base
        loss, so the base loss reduces only over that singleton and the
        per-channel structure is preserved).
    baseloss
        Identifier or callable resolved via ``tf.keras.losses.get``
        (e.g. ``"binary_crossentropy"``).
    sample_weights
        Optional mixing matrix; contracted with ``y_true`` over the channel
        axis to produce element-wise loss weights. ``None`` disables
        sample weighting.
    sample_weights_bias
        Additive bias for the sample weights; only read when
        ``sample_weights`` is not ``None``.
    output_weights
        Optional per-channel vector; the loss is contracted with it over
        its last axis at the end. ``None`` keeps per-channel losses.
    zero_channels_weighted
        Bool or list of bools. If all True, all-zero channels are weighted
        like any other. Otherwise the loss (and bias) of channels whose
        targets are entirely zero is zeroed, unless that channel's entry
        here is True.
    **kwargs
        Forwarded to the base loss function (e.g. ``from_logits``,
        ``label_smoothing``).
    """
    baseloss_fn = tf.keras.losses.get(baseloss)
    # Trailing singleton axis keeps the channel axis intact through the base
    # loss's reduction so the weighting below can act per channel.
    loss = baseloss_fn(y_true[..., None], y_pred[..., None], **kwargs)

    non_zeros_samples = y_true
    bias = tf.constant(0, dtype=y_true.dtype)

    if sample_weights is not None:
        # Mix target channels by the weight matrix (contraction of the last
        # axis of y_true with the first axis of sample_weights).
        non_zeros_samples = tf.tensordot(y_true, tf.cast(sample_weights, y_true.dtype), axes=[-1, 0])
        bias = tf.cast(sample_weights_bias, y_true.dtype)

    # NOTE(review): this python-level `if` on tf.reduce_all works eagerly
    # because zero_channels_weighted is a bool or list here; it would not
    # trace correctly under tf.function if a tensor were passed — confirm.
    if not tf.reduce_all(zero_channels_weighted):
        # True per channel when any target is non-zero across all axes except
        # batch (axis 0) and channel (last axis); keepdims for broadcasting.
        nonzero_channels = tf.reduce_any(non_zeros_samples != 0,
                                         tf.range(1, tf.rank(non_zeros_samples) - 1),
                                         keepdims=True)

        if not isinstance(zero_channels_weighted, bool):
            # make broadcastable: reshape the per-channel flags to
            # (1, ..., 1, n_channels) to align with the channel axis.
            reshaper = tf.cast(tf.pad(tf.ones(tf.rank(y_true)-1), [[0, 1]], constant_values=-1), tf.int32)
            zero_channels_weighted = tf.reshape(tf.constant(zero_channels_weighted), reshaper)

        # A channel contributes if it has non-zero targets, or is explicitly
        # allowed to contribute while all-zero.
        weight = nonzero_channels | zero_channels_weighted
        loss = loss * tf.cast(weight, loss.dtype)
        # NOTE(review): bias is cast to float32 here while non_zeros_samples
        # keeps y_true's dtype — the addition below may mismatch dtypes for
        # non-float32 targets; verify against callers.
        bias = tf.cast(bias, tf.float32) * tf.cast(weight, tf.float32)

    if sample_weights is not None:
        # Element-wise weights = mixed targets + bias, clipped non-negative.
        ww = non_zeros_samples + bias
        ww = tf.clip_by_value(ww, 0, np.inf)
        loss = loss * ww

    if output_weights is not None:
        # Collapse the channel axis with the final per-channel output weights.
        loss = tf.tensordot(loss, tf.cast(output_weights, loss.dtype), axes=1)

    return loss


class WeightedLoss(LossFunctionWrapper):
    """tf.keras loss object that evaluates :func:`weighted_loss`.

    All constructor keyword arguments other than ``name``/``reduction`` are
    stored by the wrapper and forwarded to ``weighted_loss`` on each call.
    """

    def __init__(self, **kwargs):
        """Create the loss; every keyword argument is handed to the wrapper base."""
        super().__init__(fn=weighted_loss, **kwargs)

    @property
    def fn_kwargs(self):
        """Keyword arguments that will be forwarded to ``weighted_loss``."""
        return self._fn_kwargs


class WeightedLossFactory(BaseModel):
    """Serializable configuration that builds a :class:`WeightedLoss`.

    Fields mirror the parameters of :func:`weighted_loss`;
    ``label_smoothing`` and ``baseloss`` are forwarded to the base Keras
    loss via the wrapper's stored kwargs.
    """

    # Optional spelled out explicitly: pydantic v1 inferred it from the None
    # default, but pydantic v2 rejects a None default on a non-Optional field.
    sample_weights: Optional[List[List[float]]] = Field(default=None)
    sample_weights_bias: Optional[List[float]] = Field(default=None)
    output_weights: Optional[List[float]] = Field(default=None)
    zero_channels_weighted: Union[bool, List[bool]] = Field(default=True)
    label_smoothing: float = 0.0
    baseloss: str = "binary_crossentropy"

    def get_loss(self, from_logits: bool = False,
                 reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO",
                 name: str = 'weighted_loss'):
        """Instantiate the configured loss.

        Parameters
        ----------
        from_logits : bool
            Forwarded to the base loss; set True when ``y_pred`` holds
            unscaled logits.
        reduction : str
            Name of a ``tf.keras.losses.Reduction`` member applied to the
            computed loss.
        name : str
            Name given to the Keras loss object.

        Returns
        -------
        WeightedLoss
            Loss configured with all of this factory's fields.
        """
        return WeightedLoss(
            name=name,
            reduction=getattr(losses_utils.ReductionV2, reduction),
            from_logits=from_logits,
            **self.dict())

Functions

def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights, zero_channels_weighted=True, **kwargs)
Expand source code
def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights, zero_channels_weighted=True, **kwargs):
    """Evaluate a Keras base loss with per-channel sample and output weighting.

    The last axis of ``y_true``/``y_pred`` is treated as the channel axis.
    ``sample_weights`` (matrix) and ``sample_weights_bias`` produce
    element-wise weights from the targets; ``zero_channels_weighted``
    controls whether all-zero channels contribute; ``output_weights``
    collapses the channel axis at the end. ``**kwargs`` (e.g.
    ``from_logits``, ``label_smoothing``) go to the base loss resolved via
    ``tf.keras.losses.get``.
    """
    baseloss_fn = tf.keras.losses.get(baseloss)
    # Trailing singleton axis preserves the channel axis through the base loss.
    loss = baseloss_fn(y_true[..., None], y_pred[..., None], **kwargs)

    non_zeros_samples = y_true
    bias = tf.constant(0, dtype=y_true.dtype)

    if sample_weights is not None:
        # Mix target channels by the weight matrix (last axis of y_true
        # contracted with first axis of sample_weights).
        non_zeros_samples = tf.tensordot(y_true, tf.cast(sample_weights, y_true.dtype), axes=[-1, 0])
        bias = tf.cast(sample_weights_bias, y_true.dtype)

    # NOTE(review): python-level `if` on a reduced tensor — eager-mode only
    # unless zero_channels_weighted stays a plain bool/list; confirm.
    if not tf.reduce_all(zero_channels_weighted):
        # True per channel when any target is non-zero outside batch/channel axes.
        nonzero_channels = tf.reduce_any(non_zeros_samples != 0,
                                         tf.range(1, tf.rank(non_zeros_samples) - 1),
                                         keepdims=True)

        if not isinstance(zero_channels_weighted, bool):
            # make broadcastable: reshape flags to (1, ..., 1, n_channels)
            reshaper = tf.cast(tf.pad(tf.ones(tf.rank(y_true)-1), [[0, 1]], constant_values=-1), tf.int32)
            zero_channels_weighted = tf.reshape(tf.constant(zero_channels_weighted), reshaper)

        # Keep a channel's loss if it has non-zero targets or is explicitly allowed.
        weight = nonzero_channels | zero_channels_weighted
        loss = loss * tf.cast(weight, loss.dtype)
        # NOTE(review): bias becomes float32 while non_zeros_samples keeps
        # y_true's dtype — possible dtype mismatch in the addition below; verify.
        bias = tf.cast(bias, tf.float32) * tf.cast(weight, tf.float32)

    if sample_weights is not None:
        # Element-wise weights = mixed targets + bias, clipped non-negative.
        ww = non_zeros_samples + bias
        ww = tf.clip_by_value(ww, 0, np.inf)
        loss = loss * ww

    if output_weights is not None:
        # Collapse the channel axis with the per-channel output weights.
        loss = tf.tensordot(loss, tf.cast(output_weights, loss.dtype), axes=1)

    return loss

Classes

class WeightedLoss (**kwargs)

Wraps a loss function in the Loss class.

Initializes LossFunctionWrapper class.

Args

fn
The loss function to wrap, with signature fn(y_true, y_pred, **kwargs).
reduction
Type of tf.keras.losses.Reduction to apply to loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE. When used with tf.distribute.Strategy, outside of built-in training loops such as tf.keras compile and fit, using AUTO or SUM_OVER_BATCH_SIZE will raise an error. Please see this custom training tutorial for more details.
name
Optional name for the instance.
**kwargs
The keyword arguments that are passed on to fn.
Expand source code
class WeightedLoss(LossFunctionWrapper):
    """tf.keras loss object that evaluates :func:`weighted_loss`."""

    def __init__(self, **kwargs):
        # All keyword arguments (name, reduction, weighting options) are
        # handed to the wrapper base, which stores extras as fn kwargs.
        super().__init__(weighted_loss, **kwargs)

    @property
    def fn_kwargs(self):
        # Keyword arguments forwarded to weighted_loss on each call.
        return self._fn_kwargs

Ancestors

  • tensorflow.python.keras.losses.LossFunctionWrapper
  • tensorflow.python.keras.losses.Loss

Instance variables

var fn_kwargs
Expand source code
@property
def fn_kwargs(self):
    # Keyword arguments stored by the wrapper and forwarded to weighted_loss.
    return self._fn_kwargs
class WeightedLossFactory (**data: Any)

Create a new model by parsing and validating input data from keyword arguments.

Raises ValidationError if the input data cannot be parsed to form a valid model.

Expand source code
class WeightedLossFactory(BaseModel):
    """Serializable configuration that builds a :class:`WeightedLoss`."""

    # NOTE(review): None defaults on non-Optional annotations rely on
    # pydantic v1's implicit-Optional inference; pydantic v2 rejects this.
    sample_weights: List[List[float]] = Field(default=None)
    sample_weights_bias: List[float] = Field(default=None)
    output_weights: List[float] = Field(default=None)
    zero_channels_weighted: Union[bool, List[bool]] = Field(default=True)
    label_smoothing: float = 0.0
    baseloss: str = "binary_crossentropy"

    def get_loss(self, from_logits: bool = False,
                 reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO",
                 name: str = 'weighted_loss'):
        """Build a WeightedLoss from this factory's fields.

        ``reduction`` names a ``tf.keras.losses.Reduction`` member;
        ``from_logits`` is forwarded to the base loss.
        """
        return WeightedLoss(
            name=name,
            reduction=getattr(losses_utils.ReductionV2, reduction),
            from_logits=from_logits,
            **self.dict())

Ancestors

  • pydantic.main.BaseModel
  • pydantic.utils.Representation

Class variables

var baseloss : str
var label_smoothing : float
var output_weights : List[float]
var sample_weights : List[List[float]]
var sample_weights_bias : List[float]
var zero_channels_weighted : Union[bool, List[bool]]

Methods

def get_loss(self, from_logits: bool = False, reduction: typing_extensions.Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = 'AUTO', name: str = 'weighted_loss')
Expand source code
def get_loss(self, from_logits: bool = False,
             reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO",
             name: str = 'weighted_loss'):
    """Build a WeightedLoss from this factory's fields.

    ``reduction`` names a ``tf.keras.losses.Reduction`` member resolved via
    getattr; ``from_logits`` is forwarded to the base loss; every factory
    field is passed through as a weighted_loss keyword argument.
    """
    return WeightedLoss(
        name=name,
        reduction=getattr(losses_utils.ReductionV2, reduction),
        from_logits=from_logits,
        **self.dict())