Module brevettiai.io.tf_recorder

Expand source code
import tensorflow as tf
from brevettiai import Module

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_composite_structure(value):
    try:
        return _bytes_feature(tf.io.serialize_tensor(value))
    except ValueError:
        if isinstance(value, dict):
            return tf.train.Features(feature={k: serialize_composite_structure(v) for k, v in value.items()})


def generate_dtype_structure(value):
    if isinstance(value, type(tf.constant(0))):
        return value.dtype
    if isinstance(value, dict):
        return {k: generate_dtype_structure(v) for k, v in value.items()}
    else:
        return tf.constant(value).dtype


class TfRecorder(Module):
    def __init__(self, filenames, structure=None, compression_type="GZIP"):
        self.filenames = filenames
        self.structure = structure
        self.compression_type = compression_type

        self.writer = None

    def set_structure_from_example(self, value):
        self.structure = generate_dtype_structure(value)
        return self.structure

    @property
    def feature_description(self):
        return tf.nest.pack_sequence_as(
            flat_sequence=[tf.io.FixedLenFeature((), tf.string)] * len(tf.nest.flatten(self.structure)),
            structure=self.structure)

    def __enter__(self):
        options = tf.io.TFRecordOptions(compression_type=self.compression_type)
        self.writer = tf.io.TFRecordWriter(self.filenames, options)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.writer.__exit__(exc_type, exc_val, exc_tb)
        self.writer = None

    def write(self, value):
        if self.writer is None:
            raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
        self.writer.write(self.serialize(value).SerializeToString())

    def serialize(self, value):
        if self.structure is None:
            self.set_structure_from_example(value)
        return tf.train.Example(features=serialize_composite_structure(value))

    def get_dataset(self, *args, **kwargs):
        if self.structure is None:
            raise ValueError("Structure not known, set before loading dataset")

        ds = tf.data.TFRecordDataset(filenames=self.filenames, compression_type=self.compression_type, *args, **kwargs)
        return ds.map(self.parse_dataset)

    def parse_dataset(self, x):
        x = tf.io.parse_single_example(x, self.feature_description)
        parsed = [tf.io.parse_tensor(x, dtype) for dtype, x in zip(tf.nest.flatten(self.structure), tf.nest.flatten(x))]
        return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=x)

    def get_config(self):
        cfg = super().get_config()
        cfg["structure"] = tf.nest.map_structure(lambda x: x.name, cfg["structure"])
        return cfg

    @classmethod
    def from_config(cls, config):
        config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
        return super().from_config(config)

Functions

def generate_dtype_structure(value)
Expand source code
def generate_dtype_structure(value):
    if isinstance(value, type(tf.constant(0))):
        return value.dtype
    if isinstance(value, dict):
        return {k: generate_dtype_structure(v) for k, v in value.items()}
    else:
        return tf.constant(value).dtype
def serialize_composite_structure(value)
Expand source code
def serialize_composite_structure(value):
    try:
        return _bytes_feature(tf.io.serialize_tensor(value))
    except ValueError:
        if isinstance(value, dict):
            return tf.train.Features(feature={k: serialize_composite_structure(v) for k, v in value.items()})

Classes

class TfRecorder (filenames, structure=None, compression_type='GZIP')

Base class for serializable modules

Expand source code
class TfRecorder(Module):
    def __init__(self, filenames, structure=None, compression_type="GZIP"):
        self.filenames = filenames
        self.structure = structure
        self.compression_type = compression_type

        self.writer = None

    def set_structure_from_example(self, value):
        self.structure = generate_dtype_structure(value)
        return self.structure

    @property
    def feature_description(self):
        return tf.nest.pack_sequence_as(
            flat_sequence=[tf.io.FixedLenFeature((), tf.string)] * len(tf.nest.flatten(self.structure)),
            structure=self.structure)

    def __enter__(self):
        options = tf.io.TFRecordOptions(compression_type=self.compression_type)
        self.writer = tf.io.TFRecordWriter(self.filenames, options)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.writer.__exit__(exc_type, exc_val, exc_tb)
        self.writer = None

    def write(self, value):
        if self.writer is None:
            raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
        self.writer.write(self.serialize(value).SerializeToString())

    def serialize(self, value):
        if self.structure is None:
            self.set_structure_from_example(value)
        return tf.train.Example(features=serialize_composite_structure(value))

    def get_dataset(self, *args, **kwargs):
        if self.structure is None:
            raise ValueError("Structure not known, set before loading dataset")

        ds = tf.data.TFRecordDataset(filenames=self.filenames, compression_type=self.compression_type, *args, **kwargs)
        return ds.map(self.parse_dataset)

    def parse_dataset(self, x):
        x = tf.io.parse_single_example(x, self.feature_description)
        parsed = [tf.io.parse_tensor(x, dtype) for dtype, x in zip(tf.nest.flatten(self.structure), tf.nest.flatten(x))]
        return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=x)

    def get_config(self):
        cfg = super().get_config()
        cfg["structure"] = tf.nest.map_structure(lambda x: x.name, cfg["structure"])
        return cfg

    @classmethod
    def from_config(cls, config):
        config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
        return super().from_config(config)

Ancestors

Static methods

def from_config(config)
Expand source code
@classmethod
def from_config(cls, config):
    config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
    return super().from_config(config)

Instance variables

var feature_description
Expand source code
@property
def feature_description(self):
    return tf.nest.pack_sequence_as(
        flat_sequence=[tf.io.FixedLenFeature((), tf.string)] * len(tf.nest.flatten(self.structure)),
        structure=self.structure)

Methods

def get_config(self)
Expand source code
def get_config(self):
    cfg = super().get_config()
    cfg["structure"] = tf.nest.map_structure(lambda x: x.name, cfg["structure"])
    return cfg
def get_dataset(self, *args, **kwargs)
Expand source code
def get_dataset(self, *args, **kwargs):
    if self.structure is None:
        raise ValueError("Structure not known, set before loading dataset")

    ds = tf.data.TFRecordDataset(filenames=self.filenames, compression_type=self.compression_type, *args, **kwargs)
    return ds.map(self.parse_dataset)
def parse_dataset(self, x)
Expand source code
def parse_dataset(self, x):
    x = tf.io.parse_single_example(x, self.feature_description)
    parsed = [tf.io.parse_tensor(x, dtype) for dtype, x in zip(tf.nest.flatten(self.structure), tf.nest.flatten(x))]
    return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=x)
def serialize(self, value)
Expand source code
def serialize(self, value):
    if self.structure is None:
        self.set_structure_from_example(value)
    return tf.train.Example(features=serialize_composite_structure(value))
def set_structure_from_example(self, value)
Expand source code
def set_structure_from_example(self, value):
    self.structure = generate_dtype_structure(value)
    return self.structure
def write(self, value)
Expand source code
def write(self, value):
    if self.writer is None:
        raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
    self.writer.write(self.serialize(value).SerializeToString())