Module brevettiai.data.image.segmentation_loader
Expand source code
import logging
import numpy as np
import tensorflow as tf
from brevettiai.interfaces import vue_schema_utils as vue
from brevettiai.data.image import utils, ImageKeys
import json
# Module-level logger named after this module; used to report dropped label mappings.
log = logging.getLogger(__name__)
class SegmentationLoader(vue.VueSettingsModule):
    """Load segmentation annotations and view them through an image pipeline.

    When called with a sample dict, the loader reads the segmentation path(s)
    stored under ``input_key``, loads them with ``utils.load_segmentation``,
    applies the ROI selection and view transform configured on the attached
    image pipeline, and stores the result under ``output_key``.
    """

    def __init__(self, classes: list, mapping: dict = None,
                 image_pipeline=None, sparse=False,
                 input_key="segmentation_path", output_key="segmentation"):
        """Store the loader configuration.

        Args:
            classes: Ordered class names. The position of a name selects its
                channel (dense output) or its integer index (sparse output).
            mapping: Optional dict from label key to class name(s). When it is
                empty or None every class name maps to itself.
            image_pipeline: Pipeline object supplying ROIs, target size and IO.
                It may also be attached later with ``set_image_pipeline``.
            sparse: If True a single channel holding the class index is used,
                otherwise one channel per class.
            input_key: Key of the segmentation path(s) in the sample dict.
            output_key: Key under which the loaded segmentations are stored.
        """
        self.input_key = input_key
        self.output_key = output_key
        self.classes = classes
        self.mapping = mapping
        self.sparse = sparse
        self._ip = image_pipeline

    def set_image_pipeline(self, image_pipeline):
        """Attach (or replace) the image pipeline read by ``load_segmentations``."""
        self._ip = image_pipeline

    @classmethod
    def to_schema(cls, builder, name, ptype, default, **kwargs):
        """Customize the schema entry generated for one constructor argument.

        ``input_key``, ``output_key`` and ``image_pipeline`` return None without
        calling the parent implementation (presumably leaving them out of the
        schema - confirm against ``VueSettingsModule``). ``classes`` and
        ``mapping`` get explicit labels; everything is then delegated to the
        parent class.
        """
        if name in {"input_key", "output_key", "image_pipeline"}:
            return
        if name == "classes":
            kwargs["label"] = "Segmentation classes"
        if name == "mapping":
            kwargs["label"] = "Segmentation mapping"
        return super().to_schema(builder, name, ptype, default, **kwargs)

    def build_label_space(self, sparse=None):
        """Build the lookup from label key to target value.

        Args:
            sparse: Overrides ``self.sparse`` when not None. True yields the
                class index as an array of shape ``(1,)``; False yields a
                vector of length ``len(self.classes)`` that is the sum of the
                one-hot rows of the mapped classes.

        Returns:
            dict from label key to numpy array. Keys whose class name(s) are
            not found in ``self.classes`` are left out (and logged).
        """
        # An `or` fallback would discard an explicit sparse=False whenever
        # self.sparse is True, so test for None instead.
        sparse = self.sparse if sparse is None else sparse

        # Setup possible values
        if sparse:
            def get_output(x, output=np.arange(len(self.classes))[:, None]):
                try:
                    return output[self.classes.index(x)]
                except ValueError:
                    return None
        else:
            def get_output(x, output=np.eye(len(self.classes))):
                try:
                    # "a|b" assigns the label to several classes at once
                    if isinstance(x, str):
                        x = x.split("|")
                    return sum(output[self.classes.index(v)] for v in x)
                except ValueError:
                    return None

        # Build mapping
        if self.mapping:
            label_space = {k: get_output(v) for k, v in self.mapping.items()}
        else:
            label_space = {k: get_output(k) for k in self.classes}

        # Report dropped entries only when there are any
        invalid = {k: v for k, v in label_space.items() if v is None}
        if invalid:
            log.info("Invalid classes in map: %s", json.dumps(invalid))
        return {k: v for k, v in label_space.items() if v is not None}

    def load_segmentations(self, paths, input_image_shape, metadata):
        """Load the segmentation for every path and apply the pipeline's view.

        Args:
            paths: String tensor of segmentation paths; an empty string marks
                a sample without segmentation.
            input_image_shape: Per-sample image shape; elements 0 and 1 are
                used as the first two dimensions of the loaded segmentation.
            metadata: Per-sample metadata forwarded to
                ``utils.load_segmentation``.

        Returns:
            The result of mapping the loader over all samples with
            ``tf.map_fn``, structured according to the pipeline's output spec.
        """
        ip = self._ip
        crops_joiner, output_dtype = ip.get_output_spec(ip.rois, ip.roi_mode, dtype=tf.float32)
        label_space = self.build_label_space()
        segmentation_channels = 1 if self.sparse else len(self.classes)

        @tf.function
        def _load_segmentation(x):
            path, shape, metadata = x
            if tf.strings.length(path) > 0:
                # Load
                img = utils.load_segmentation(path, metadata, shape=(shape[0], shape[1], segmentation_channels),
                                              label_space=label_space, io=ip._io)
                # Apply ROIs
                crops = utils.roi_selection(img, rois=ip.rois, crops_joiner=crops_joiner)
                # Transform crops; nearest-neighbour resizing avoids blending label values
                crops = [utils.image_view_transform(crop, target_size=ip.target_size,
                                                    resize_method="nearest",
                                                    keep_aspect_ratio=ip.keep_aspect_ratio,
                                                    antialias=ip.antialias,
                                                    padding_mode=ip.padding_mode) for crop in crops]
                return tuple(crops) if isinstance(output_dtype, tuple) else crops[0]
            else:
                # NOTE(review): placeholder for samples without a path; it is a
                # single tensor even when output_dtype is a tuple - confirm this
                # combination is supported.
                return tf.zeros((1, 1, segmentation_channels))

        segmentations = tf.map_fn(_load_segmentation, [paths, input_image_shape, metadata], dtype=output_dtype)
        return segmentations

    def __call__(self, x, *args, **kwargs):
        """Add the loaded segmentations to the sample dict ``x`` and return it.

        Reads ``x[self.input_key]`` and ``x["_image_file_shape"]``, forwards the
        bounding box and zoom entries when present, and writes the result to
        ``x[self.output_key]`` in place.
        """
        metakeys = {ImageKeys.BOUNDING_BOX, ImageKeys.ZOOM}
        metadata = {k: x[k] for k in metakeys if k in x}
        segmentations = self.load_segmentations(x[self.input_key], x["_image_file_shape"], metadata)
        x[self.output_key] = segmentations
        return x
Classes
class SegmentationLoader (classes: list, mapping: dict = None, image_pipeline=None, sparse=False, input_key='segmentation_path', output_key='segmentation')
-
Base class for serializable modules
Expand source code
class SegmentationLoader(vue.VueSettingsModule):
    """Load segmentation annotations and view them through an image pipeline.

    When called with a sample dict, the loader reads the segmentation path(s)
    stored under ``input_key``, loads them with ``utils.load_segmentation``,
    applies the ROI selection and view transform configured on the attached
    image pipeline, and stores the result under ``output_key``.
    """

    def __init__(self, classes: list, mapping: dict = None,
                 image_pipeline=None, sparse=False,
                 input_key="segmentation_path", output_key="segmentation"):
        """Store the loader configuration.

        Args:
            classes: Ordered class names. The position of a name selects its
                channel (dense output) or its integer index (sparse output).
            mapping: Optional dict from label key to class name(s). When it is
                empty or None every class name maps to itself.
            image_pipeline: Pipeline object supplying ROIs, target size and IO.
                It may also be attached later with ``set_image_pipeline``.
            sparse: If True a single channel holding the class index is used,
                otherwise one channel per class.
            input_key: Key of the segmentation path(s) in the sample dict.
            output_key: Key under which the loaded segmentations are stored.
        """
        self.input_key = input_key
        self.output_key = output_key
        self.classes = classes
        self.mapping = mapping
        self.sparse = sparse
        self._ip = image_pipeline

    def set_image_pipeline(self, image_pipeline):
        """Attach (or replace) the image pipeline read by ``load_segmentations``."""
        self._ip = image_pipeline

    @classmethod
    def to_schema(cls, builder, name, ptype, default, **kwargs):
        """Customize the schema entry generated for one constructor argument.

        ``input_key``, ``output_key`` and ``image_pipeline`` return None without
        calling the parent implementation (presumably leaving them out of the
        schema - confirm against ``VueSettingsModule``). ``classes`` and
        ``mapping`` get explicit labels; everything is then delegated to the
        parent class.
        """
        if name in {"input_key", "output_key", "image_pipeline"}:
            return
        if name == "classes":
            kwargs["label"] = "Segmentation classes"
        if name == "mapping":
            kwargs["label"] = "Segmentation mapping"
        return super().to_schema(builder, name, ptype, default, **kwargs)

    def build_label_space(self, sparse=None):
        """Build the lookup from label key to target value.

        Args:
            sparse: Overrides ``self.sparse`` when not None. True yields the
                class index as an array of shape ``(1,)``; False yields a
                vector of length ``len(self.classes)`` that is the sum of the
                one-hot rows of the mapped classes.

        Returns:
            dict from label key to numpy array. Keys whose class name(s) are
            not found in ``self.classes`` are left out (and logged).
        """
        # An `or` fallback would discard an explicit sparse=False whenever
        # self.sparse is True, so test for None instead.
        sparse = self.sparse if sparse is None else sparse

        # Setup possible values
        if sparse:
            def get_output(x, output=np.arange(len(self.classes))[:, None]):
                try:
                    return output[self.classes.index(x)]
                except ValueError:
                    return None
        else:
            def get_output(x, output=np.eye(len(self.classes))):
                try:
                    # "a|b" assigns the label to several classes at once
                    if isinstance(x, str):
                        x = x.split("|")
                    return sum(output[self.classes.index(v)] for v in x)
                except ValueError:
                    return None

        # Build mapping
        if self.mapping:
            label_space = {k: get_output(v) for k, v in self.mapping.items()}
        else:
            label_space = {k: get_output(k) for k in self.classes}

        # Report dropped entries only when there are any
        invalid = {k: v for k, v in label_space.items() if v is None}
        if invalid:
            log.info("Invalid classes in map: %s", json.dumps(invalid))
        return {k: v for k, v in label_space.items() if v is not None}

    def load_segmentations(self, paths, input_image_shape, metadata):
        """Load the segmentation for every path and apply the pipeline's view.

        Args:
            paths: String tensor of segmentation paths; an empty string marks
                a sample without segmentation.
            input_image_shape: Per-sample image shape; elements 0 and 1 are
                used as the first two dimensions of the loaded segmentation.
            metadata: Per-sample metadata forwarded to
                ``utils.load_segmentation``.

        Returns:
            The result of mapping the loader over all samples with
            ``tf.map_fn``, structured according to the pipeline's output spec.
        """
        ip = self._ip
        crops_joiner, output_dtype = ip.get_output_spec(ip.rois, ip.roi_mode, dtype=tf.float32)
        label_space = self.build_label_space()
        segmentation_channels = 1 if self.sparse else len(self.classes)

        @tf.function
        def _load_segmentation(x):
            path, shape, metadata = x
            if tf.strings.length(path) > 0:
                # Load
                img = utils.load_segmentation(path, metadata, shape=(shape[0], shape[1], segmentation_channels),
                                              label_space=label_space, io=ip._io)
                # Apply ROIs
                crops = utils.roi_selection(img, rois=ip.rois, crops_joiner=crops_joiner)
                # Transform crops; nearest-neighbour resizing avoids blending label values
                crops = [utils.image_view_transform(crop, target_size=ip.target_size,
                                                    resize_method="nearest",
                                                    keep_aspect_ratio=ip.keep_aspect_ratio,
                                                    antialias=ip.antialias,
                                                    padding_mode=ip.padding_mode) for crop in crops]
                return tuple(crops) if isinstance(output_dtype, tuple) else crops[0]
            else:
                # NOTE(review): placeholder for samples without a path; it is a
                # single tensor even when output_dtype is a tuple - confirm this
                # combination is supported.
                return tf.zeros((1, 1, segmentation_channels))

        segmentations = tf.map_fn(_load_segmentation, [paths, input_image_shape, metadata], dtype=output_dtype)
        return segmentations

    def __call__(self, x, *args, **kwargs):
        """Add the loaded segmentations to the sample dict ``x`` and return it.

        Reads ``x[self.input_key]`` and ``x["_image_file_shape"]``, forwards the
        bounding box and zoom entries when present, and writes the result to
        ``x[self.output_key]`` in place.
        """
        metakeys = {ImageKeys.BOUNDING_BOX, ImageKeys.ZOOM}
        metadata = {k: x[k] for k in metakeys if k in x}
        segmentations = self.load_segmentations(x[self.input_key], x["_image_file_shape"], metadata)
        x[self.output_key] = segmentations
        return x
Ancestors
Methods
def build_label_space(self, sparse=None)
-
Expand source code
def build_label_space(self, sparse=None):
    """Build the lookup from label key to target value.

    Args:
        sparse: Overrides ``self.sparse`` when not None. True yields the class
            index as an array of shape ``(1,)``; False yields a vector of
            length ``len(self.classes)`` that is the sum of the one-hot rows
            of the mapped classes.

    Returns:
        dict from label key to numpy array. Keys whose class name(s) are not
        found in ``self.classes`` are left out (and logged).
    """
    # An `or` fallback would discard an explicit sparse=False whenever
    # self.sparse is True, so test for None instead.
    sparse = self.sparse if sparse is None else sparse

    # Setup possible values
    if sparse:
        def get_output(x, output=np.arange(len(self.classes))[:, None]):
            try:
                return output[self.classes.index(x)]
            except ValueError:
                return None
    else:
        def get_output(x, output=np.eye(len(self.classes))):
            try:
                # "a|b" assigns the label to several classes at once
                if isinstance(x, str):
                    x = x.split("|")
                return sum(output[self.classes.index(v)] for v in x)
            except ValueError:
                return None

    # Build mapping
    if self.mapping:
        label_space = {k: get_output(v) for k, v in self.mapping.items()}
    else:
        label_space = {k: get_output(k) for k in self.classes}

    # Report dropped entries only when there are any
    invalid = {k: v for k, v in label_space.items() if v is None}
    if invalid:
        log.info("Invalid classes in map: %s", json.dumps(invalid))
    return {k: v for k, v in label_space.items() if v is not None}
def load_segmentations(self, paths, input_image_shape, metadata)
-
Expand source code
def load_segmentations(self, paths, input_image_shape, metadata):
    """Load, crop and resize the segmentation behind every path.

    Args:
        paths: String tensor of segmentation paths; an empty string marks a
            sample without segmentation.
        input_image_shape: Per-sample image shape; elements 0 and 1 are used
            as the first two dimensions of the loaded segmentation.
        metadata: Per-sample metadata forwarded to ``utils.load_segmentation``.

    Returns:
        The result of mapping the per-sample loader over all samples with
        ``tf.map_fn``, structured according to the pipeline's output spec.
    """
    pipeline = self._ip
    joiner, output_dtype = pipeline.get_output_spec(pipeline.rois, pipeline.roi_mode, dtype=tf.float32)
    # Decided once in Python; the traced function only sees the constant
    returns_tuple = isinstance(output_dtype, tuple)
    label_space = self.build_label_space()
    channel_count = 1 if self.sparse else len(self.classes)

    @tf.function
    def _load_segmentation(sample):
        path, shape, meta = sample
        if tf.strings.length(path) > 0:
            segmentation = utils.load_segmentation(
                path, meta,
                shape=(shape[0], shape[1], channel_count),
                label_space=label_space,
                io=pipeline._io,
            )
            selected = utils.roi_selection(segmentation, rois=pipeline.rois, crops_joiner=joiner)
            # Nearest-neighbour resizing avoids blending label values
            views = [
                utils.image_view_transform(
                    crop,
                    target_size=pipeline.target_size,
                    resize_method="nearest",
                    keep_aspect_ratio=pipeline.keep_aspect_ratio,
                    antialias=pipeline.antialias,
                    padding_mode=pipeline.padding_mode,
                )
                for crop in selected
            ]
            return tuple(views) if returns_tuple else views[0]
        else:
            # Placeholder for samples that have no segmentation path
            return tf.zeros((1, 1, channel_count))

    return tf.map_fn(_load_segmentation, [paths, input_image_shape, metadata], dtype=output_dtype)
def set_image_pipeline(self, image_pipeline)
-
Expand source code
def set_image_pipeline(self, image_pipeline):
    """Attach (or replace) the image pipeline read by ``load_segmentations``.

    Args:
        image_pipeline: Pipeline object stored on the loader as ``_ip``.
    """
    self._ip = image_pipeline
Inherited members