Module brevettiai.model.losses
Expand source code
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.losses import LossFunctionWrapper
from tensorflow.python.keras.utils import losses_utils
from pydantic import BaseModel
from typing import List, Union
from typing_extensions import Literal
from pydantic import Field
def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights,
                  zero_channels_weighted=True, **kwargs):
    """Compute a per-channel base loss and re-weight it per sample and per output.

    Args:
        y_true: Target tensor; its last axis is treated as the channel axis.
            (Assumed layout ``(batch, *spatial, channels)`` -- TODO confirm with callers.)
        y_pred: Prediction tensor, presumably with the same shape as ``y_true``.
        baseloss: Loss identifier resolved with ``tf.keras.losses.get``
            (e.g. ``"binary_crossentropy"``).
        sample_weights: Optional matrix. ``y_true`` is contracted with it over the last
            axis of ``y_true`` and the first axis of the matrix, giving an element-wise
            weight for the loss. ``None`` disables sample weighting.
        sample_weights_bias: Optional vector added to the sample weights before they are
            clipped at zero. ``None`` is treated as a zero bias.
        output_weights: Optional vector. The weighted loss is contracted with it over its
            last axis. ``None`` keeps one loss value per channel.
        zero_channels_weighted: ``True`` (default) keeps every channel. If it is ``False``
            or a per-channel sequence of bools, a channel's loss (and bias) is zeroed for
            samples where that channel's (weighted) target is zero over every axis
            between the first and the last, unless its flag is ``True``.
        **kwargs: Forwarded to the base loss (e.g. ``from_logits``, ``label_smoothing``).

    Returns:
        The weighted loss tensor, in the dtype produced by the base loss.
    """
    baseloss_fn = tf.keras.losses.get(baseloss)
    # The extra trailing axis makes the base loss, which reduces its last axis, return
    # one value per channel instead of combining the channels.
    loss = baseloss_fn(y_true[..., None], y_pred[..., None], **kwargs)
    dtype = loss.dtype

    # Keep all weighting arithmetic in the loss dtype. Mixing y_true's dtype with a
    # hard-coded float32 breaks float16/float64/integer targets.
    non_zeros_samples = tf.cast(y_true, dtype)
    bias = tf.constant(0, dtype=dtype)
    if sample_weights is not None:
        non_zeros_samples = tf.tensordot(non_zeros_samples, tf.cast(sample_weights, dtype), axes=[-1, 0])
        # tf.cast(None, ...) raises, so an unset bias must be skipped, i.e. kept at zero.
        if sample_weights_bias is not None:
            bias = tf.cast(sample_weights_bias, dtype)

    if not tf.reduce_all(zero_channels_weighted):
        # For each sample and channel: is any entry non-zero over the axes between
        # the first and the last? keepdims preserves broadcastability with `loss`.
        nonzero_channels = tf.reduce_any(non_zeros_samples != 0,
                                         tf.range(1, tf.rank(non_zeros_samples) - 1),
                                         keepdims=True)
        if not isinstance(zero_channels_weighted, bool):
            # Reshape the per-channel flags to [1, ..., 1, -1] so they broadcast
            # against nonzero_channels along the last axis.
            reshaper = tf.cast(tf.pad(tf.ones(tf.rank(y_true) - 1), [[0, 1]], constant_values=-1), tf.int32)
            zero_channels_weighted = tf.reshape(tf.constant(zero_channels_weighted), reshaper)
        weight = tf.cast(nonzero_channels | zero_channels_weighted, dtype)
        loss = loss * weight
        bias = bias * weight

    if sample_weights is not None:
        # Negative combined weights (e.g. caused by a negative bias) are clipped to zero.
        ww = tf.clip_by_value(non_zeros_samples + bias, 0, np.inf)
        loss = loss * ww

    if output_weights is not None:
        loss = tf.tensordot(loss, tf.cast(output_weights, dtype), axes=1)
    return loss
class WeightedLoss(LossFunctionWrapper):
    """Keras loss object whose wrapped function is ``weighted_loss``.

    Every keyword argument (e.g. ``reduction``, ``name`` and the ``weighted_loss``
    options) is handed unchanged to ``LossFunctionWrapper``.
    """

    def __init__(self, **kwargs):
        wrapped_fn = weighted_loss
        super(WeightedLoss, self).__init__(wrapped_fn, **kwargs)

    @property
    def fn_kwargs(self):
        """Read-only view of the private ``_fn_kwargs`` attribute.

        Presumably these are the keyword arguments ``LossFunctionWrapper`` forwards
        to the wrapped function -- confirm against the TensorFlow version in use.
        """
        stored_kwargs = self._fn_kwargs
        return stored_kwargs
class WeightedLossFactory(BaseModel):
    """Pydantic configuration from which a ``WeightedLoss`` is built.

    ``get_loss`` passes every field below as a keyword argument to the loss.
    """
    # Optional matrix contracted with the targets over their last axis.
    sample_weights: List[List[float]] = Field(default=None)
    # Optional bias added to the sample weights before clipping at zero.
    sample_weights_bias: List[float] = Field(default=None)
    # Optional weights used to combine the per-channel losses.
    output_weights: List[float] = Field(default=None)
    # True, or one flag per channel, controlling whether all-zero channels contribute.
    zero_channels_weighted: Union[bool, List[bool]] = Field(default=True)
    # Forwarded to the base loss.
    label_smoothing: float = 0.0
    # Keras loss identifier used as the per-channel base loss.
    baseloss: str = "binary_crossentropy"

    def get_loss(self, from_logits: bool = False,
                 reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO",
                 name: str = 'weighted_loss'):
        """Instantiate a ``WeightedLoss`` from this configuration.

        Args:
            from_logits: Forwarded to the base loss through the loss keyword arguments.
            reduction: Name of a ``losses_utils.ReductionV2`` member.
            name: Name given to the loss instance.

        Returns:
            A ``WeightedLoss`` carrying all fields of this model as keyword arguments.
        """
        reduction_mode = getattr(losses_utils.ReductionV2, reduction)
        config = self.dict()
        return WeightedLoss(name=name, reduction=reduction_mode, from_logits=from_logits, **config)
Functions
def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights, zero_channels_weighted=True, **kwargs)-
Expand source code
def weighted_loss(y_true, y_pred, baseloss, sample_weights, sample_weights_bias, output_weights, zero_channels_weighted=True, **kwargs): baseloss_fn = tf.keras.losses.get(baseloss) loss = baseloss_fn(y_true[..., None], y_pred[..., None], **kwargs) non_zeros_samples = y_true bias = tf.constant(0, dtype=y_true.dtype) if sample_weights is not None: non_zeros_samples = tf.tensordot(y_true, tf.cast(sample_weights, y_true.dtype), axes=[-1, 0]) bias = tf.cast(sample_weights_bias, y_true.dtype) if not tf.reduce_all(zero_channels_weighted): nonzero_channels = tf.reduce_any(non_zeros_samples != 0, tf.range(1, tf.rank(non_zeros_samples) - 1), keepdims=True) if not isinstance(zero_channels_weighted, bool): # make broadcastable reshaper = tf.cast(tf.pad(tf.ones(tf.rank(y_true)-1), [[0, 1]], constant_values=-1), tf.int32) zero_channels_weighted = tf.reshape(tf.constant(zero_channels_weighted), reshaper) weight = nonzero_channels | zero_channels_weighted loss = loss * tf.cast(weight, loss.dtype) bias = tf.cast(bias, tf.float32) * tf.cast(weight, tf.float32) if sample_weights is not None: ww = non_zeros_samples + bias ww = tf.clip_by_value(ww, 0, np.inf) loss = loss * ww if output_weights is not None: loss = tf.tensordot(loss, tf.cast(output_weights, loss.dtype), axes=1) return loss
Classes
class WeightedLoss (**kwargs)-
Wraps a loss function in the `Loss` class.
Initializes the `LossFunctionWrapper` class.
Args:
- `fn` – The loss function to wrap, with signature `fn(y_true, y_pred, **kwargs)`.
- `reduction` – Type of `tf.keras.losses.Reduction` to apply to the loss. The default value is `AUTO`. `AUTO` indicates that the reduction option will be determined by the usage context. In almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy` outside of built-in training loops such as `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. See the TensorFlow custom training tutorial for more details.
- `name` – Optional name for the instance.
- `**kwargs` – The keyword arguments that are passed on to `fn`.
Expand source code
class WeightedLoss(LossFunctionWrapper): def __init__(self, **kwargs): super().__init__(weighted_loss, **kwargs) @property def fn_kwargs(self): return self._fn_kwargsAncestors
- tensorflow.python.keras.losses.LossFunctionWrapper
- tensorflow.python.keras.losses.Loss
Instance variables
var fn_kwargs-
Expand source code
@property def fn_kwargs(self): return self._fn_kwargs
class WeightedLossFactory (**data: Any)-
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be parsed to form a valid model.
Expand source code
class WeightedLossFactory(BaseModel): sample_weights: List[List[float]] = Field(default=None) sample_weights_bias: List[float] = Field(default=None) output_weights: List[float] = Field(default=None) zero_channels_weighted: Union[bool, List[bool]] = Field(default=True) label_smoothing: float = 0.0 baseloss: str = "binary_crossentropy" def get_loss(self, from_logits: bool = False, reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO", name: str = 'weighted_loss'): return WeightedLoss( name=name, reduction=getattr(losses_utils.ReductionV2, reduction), from_logits=from_logits, **self.dict())Ancestors
- pydantic.main.BaseModel
- pydantic.utils.Representation
Class variables
var baseloss : str
var label_smoothing : float
var output_weights : List[float]
var sample_weights : List[List[float]]
var sample_weights_bias : List[float]
var zero_channels_weighted : Union[bool, List[bool]]
Methods
def get_loss(self, from_logits: bool = False, reduction: typing_extensions.Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = 'AUTO', name: str = 'weighted_loss')-
Expand source code
def get_loss(self, from_logits: bool = False, reduction: Literal['AUTO', 'NONE', 'SUM', 'SUM_OVER_BATCH_SIZE'] = "AUTO", name: str = 'weighted_loss'): return WeightedLoss( name=name, reduction=getattr(losses_utils.ReductionV2, reduction), from_logits=from_logits, **self.dict())