Module brevettiai.io.tf_recorder
Expand source code
import tensorflow as tf
from brevettiai import Module
def _bytes_feature(value):
    """Wrap *value* in a ``tf.train.Feature`` holding a one-element BytesList.

    Eager tensors are unwrapped to their numpy value first, because BytesList
    won't unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a single float/double in a ``tf.train.Feature`` holding a one-element FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a single bool/enum/int/uint in a ``tf.train.Feature`` holding a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def serialize_composite_structure(value):
    """Serialize a tensor-like value, or a dict of them, into protobuf feature messages.

    A dict is mapped key-by-key to a ``tf.train.Features`` message whose entries
    are the serialized values. Any other value is serialized with
    ``tf.io.serialize_tensor`` and wrapped in a bytes ``tf.train.Feature``.

    Args:
        value: A tensor-convertible value, or a dict of such values.

    Returns:
        ``tf.train.Features`` for a dict, ``tf.train.Feature`` otherwise.

    Raises:
        Whatever ``tf.io.serialize_tensor`` raises for a non-dict value that
        cannot be converted to a tensor. Previously such values were silently
        turned into ``None``.
    """
    if isinstance(value, dict):
        # A dict cannot become a single tensor; serialize each entry instead.
        # NOTE(review): only one level works - a nested dict yields a Features
        # message, which is not a valid entry of the outer Features map.
        return tf.train.Features(
            feature={k: serialize_composite_structure(v) for k, v in value.items()})
    return _bytes_feature(tf.io.serialize_tensor(value))
def generate_dtype_structure(value):
    """Return a structure of ``tf.DType`` objects mirroring *value*.

    Dicts are recursed into, keeping their keys. Eager tensors report their own
    dtype. Any other value is converted with ``tf.constant`` just to infer its
    dtype, so it must be tensor-convertible.
    """
    if isinstance(value, dict):
        return {key: generate_dtype_structure(item) for key, item in value.items()}
    if isinstance(value, type(tf.constant(0))):
        return value.dtype
    return tf.constant(value).dtype
class TfRecorder(Module):
    """Write dict records of tensors to a TFRecord file and read them back as a dataset.

    Each record is a ``tf.train.Example`` whose features hold the
    ``tf.io.serialize_tensor`` bytes of each value of a dict (see
    ``serialize_composite_structure``). ``structure`` mirrors one record with a
    ``tf.DType`` at every leaf and is required to parse records back.

    Usage::

        recorder = TfRecorder("data.tfrecord")
        with recorder:
            recorder.write({"image": image, "label": label})
        ds = recorder.get_dataset()

    Args:
        filenames: Path given to ``tf.io.TFRecordWriter`` when writing, and
            ``filenames`` of ``tf.data.TFRecordDataset`` when reading.
        structure: Structure of ``tf.DType`` describing one record, or None to
            infer it from the first value serialized.
        compression_type: TFRecord compression used for writing and reading.
    """

    def __init__(self, filenames, structure=None, compression_type="GZIP"):
        self.filenames = filenames
        self.structure = structure
        self.compression_type = compression_type
        # Open tf.io.TFRecordWriter between __enter__ and __exit__, else None.
        self.writer = None

    def set_structure_from_example(self, value):
        """Infer ``structure`` from an example record, store it and return it."""
        self.structure = generate_dtype_structure(value)
        return self.structure

    @property
    def feature_description(self):
        """Feature spec for ``tf.io.parse_single_example``.

        Same layout as ``structure``, with every leaf replaced by a scalar
        ``tf.string`` ``FixedLenFeature``, because each value is stored as the
        bytes of a serialized tensor.
        """
        n_leaves = len(tf.nest.flatten(self.structure))
        specs = [tf.io.FixedLenFeature((), tf.string)] * n_leaves
        return tf.nest.pack_sequence_as(flat_sequence=specs, structure=self.structure)

    def __enter__(self):
        """Open a ``tf.io.TFRecordWriter`` on ``filenames`` and return self."""
        options = tf.io.TFRecordOptions(compression_type=self.compression_type)
        self.writer = tf.io.TFRecordWriter(self.filenames, options)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the writer opened by ``__enter__``.

        Safe to call when no writer is open. ``writer`` is reset to None even
        if closing raises.
        """
        writer, self.writer = self.writer, None
        if writer is not None:
            writer.__exit__(exc_type, exc_val, exc_tb)

    def write(self, value):
        """Serialize *value* and append it to the open TFRecord file.

        Raises:
            ValueError: if no writer is open (outside the context manager).
        """
        if self.writer is None:
            raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
        self.writer.write(self.serialize(value).SerializeToString())

    def serialize(self, value):
        """Convert *value* to a ``tf.train.Example``.

        If ``structure`` is not yet known, it is inferred from this value.
        """
        if self.structure is None:
            self.set_structure_from_example(value)
        return tf.train.Example(features=serialize_composite_structure(value))

    def get_dataset(self, *args, **kwargs):
        """Return a ``tf.data.Dataset`` of parsed records read from ``filenames``.

        Extra positional and keyword arguments are forwarded to
        ``tf.data.TFRecordDataset``, following ``filenames`` and
        ``compression_type``.

        Raises:
            ValueError: if ``structure`` is not known.
        """
        if self.structure is None:
            raise ValueError("Structure not known, set before loading dataset")
        # Pass filenames/compression_type positionally so that extra positional
        # args fill the following parameters instead of colliding with
        # ``filenames=`` (which raised TypeError before).
        ds = tf.data.TFRecordDataset(self.filenames, self.compression_type, *args, **kwargs)
        return ds.map(self.parse_dataset)

    def parse_dataset(self, x):
        """Parse one serialized ``tf.train.Example`` back into tensors laid out like ``structure``."""
        example = tf.io.parse_single_example(x, self.feature_description)
        # ``example`` has the same keys as ``structure`` (via feature_description),
        # so the flattened dtypes and serialized blobs line up.
        dtypes = tf.nest.flatten(self.structure)
        blobs = tf.nest.flatten(example)
        parsed = [tf.io.parse_tensor(blob, dtype) for dtype, blob in zip(dtypes, blobs)]
        return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=example)

    def get_config(self):
        """Return the config with dtypes replaced by their names (e.g. "float32").

        A missing or None ``structure`` is left as-is instead of crashing.
        """
        cfg = super().get_config()
        if cfg.get("structure") is not None:
            cfg["structure"] = tf.nest.map_structure(lambda dtype: dtype.name, cfg["structure"])
        return cfg

    @classmethod
    def from_config(cls, config):
        """Build an instance from a config whose ``structure`` holds dtype names.

        The caller's dict is copied, not modified. A missing or None
        ``structure`` is passed through unchanged.
        """
        config = dict(config)
        if config.get("structure") is not None:
            config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
        return super().from_config(config)
Functions
def generate_dtype_structure(value)
-
Expand source code
def generate_dtype_structure(value):
    """Return a structure of ``tf.DType`` objects mirroring *value*.

    Dicts are recursed into, keeping their keys. Eager tensors report their own
    dtype. Any other value is converted with ``tf.constant`` just to infer its
    dtype, so it must be tensor-convertible.
    """
    if isinstance(value, dict):
        return {key: generate_dtype_structure(item) for key, item in value.items()}
    if isinstance(value, type(tf.constant(0))):
        return value.dtype
    return tf.constant(value).dtype
def serialize_composite_structure(value)
-
Expand source code
def serialize_composite_structure(value):
    """Serialize a tensor-like value, or a dict of them, into protobuf feature messages.

    A dict is mapped key-by-key to a ``tf.train.Features`` message whose entries
    are the serialized values. Any other value is serialized with
    ``tf.io.serialize_tensor`` and wrapped in a bytes ``tf.train.Feature``.

    Args:
        value: A tensor-convertible value, or a dict of such values.

    Returns:
        ``tf.train.Features`` for a dict, ``tf.train.Feature`` otherwise.

    Raises:
        Whatever ``tf.io.serialize_tensor`` raises for a non-dict value that
        cannot be converted to a tensor. Previously such values were silently
        turned into ``None``.
    """
    if isinstance(value, dict):
        # A dict cannot become a single tensor; serialize each entry instead.
        # NOTE(review): only one level works - a nested dict yields a Features
        # message, which is not a valid entry of the outer Features map.
        return tf.train.Features(
            feature={k: serialize_composite_structure(v) for k, v in value.items()})
    return _bytes_feature(tf.io.serialize_tensor(value))
Classes
class TfRecorder (filenames, structure=None, compression_type='GZIP')
-
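Serializable `Module` subclass that writes dict records of tensors to a TFRecord file and reads them back as a `tf.data.Dataset`. (The description inherited from `Module` is "Base class for serializable modules".)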
Expand source code
class TfRecorder(Module):
    """Write dict records of tensors to a TFRecord file and read them back as a dataset.

    Each record is a ``tf.train.Example`` whose features hold the
    ``tf.io.serialize_tensor`` bytes of each value of a dict (see
    ``serialize_composite_structure``). ``structure`` mirrors one record with a
    ``tf.DType`` at every leaf and is required to parse records back.

    Usage::

        recorder = TfRecorder("data.tfrecord")
        with recorder:
            recorder.write({"image": image, "label": label})
        ds = recorder.get_dataset()

    Args:
        filenames: Path given to ``tf.io.TFRecordWriter`` when writing, and
            ``filenames`` of ``tf.data.TFRecordDataset`` when reading.
        structure: Structure of ``tf.DType`` describing one record, or None to
            infer it from the first value serialized.
        compression_type: TFRecord compression used for writing and reading.
    """

    def __init__(self, filenames, structure=None, compression_type="GZIP"):
        self.filenames = filenames
        self.structure = structure
        self.compression_type = compression_type
        # Open tf.io.TFRecordWriter between __enter__ and __exit__, else None.
        self.writer = None

    def set_structure_from_example(self, value):
        """Infer ``structure`` from an example record, store it and return it."""
        self.structure = generate_dtype_structure(value)
        return self.structure

    @property
    def feature_description(self):
        """Feature spec for ``tf.io.parse_single_example``.

        Same layout as ``structure``, with every leaf replaced by a scalar
        ``tf.string`` ``FixedLenFeature``, because each value is stored as the
        bytes of a serialized tensor.
        """
        n_leaves = len(tf.nest.flatten(self.structure))
        specs = [tf.io.FixedLenFeature((), tf.string)] * n_leaves
        return tf.nest.pack_sequence_as(flat_sequence=specs, structure=self.structure)

    def __enter__(self):
        """Open a ``tf.io.TFRecordWriter`` on ``filenames`` and return self."""
        options = tf.io.TFRecordOptions(compression_type=self.compression_type)
        self.writer = tf.io.TFRecordWriter(self.filenames, options)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the writer opened by ``__enter__``.

        Safe to call when no writer is open. ``writer`` is reset to None even
        if closing raises.
        """
        writer, self.writer = self.writer, None
        if writer is not None:
            writer.__exit__(exc_type, exc_val, exc_tb)

    def write(self, value):
        """Serialize *value* and append it to the open TFRecord file.

        Raises:
            ValueError: if no writer is open (outside the context manager).
        """
        if self.writer is None:
            raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
        self.writer.write(self.serialize(value).SerializeToString())

    def serialize(self, value):
        """Convert *value* to a ``tf.train.Example``.

        If ``structure`` is not yet known, it is inferred from this value.
        """
        if self.structure is None:
            self.set_structure_from_example(value)
        return tf.train.Example(features=serialize_composite_structure(value))

    def get_dataset(self, *args, **kwargs):
        """Return a ``tf.data.Dataset`` of parsed records read from ``filenames``.

        Extra positional and keyword arguments are forwarded to
        ``tf.data.TFRecordDataset``, following ``filenames`` and
        ``compression_type``.

        Raises:
            ValueError: if ``structure`` is not known.
        """
        if self.structure is None:
            raise ValueError("Structure not known, set before loading dataset")
        # Pass filenames/compression_type positionally so that extra positional
        # args fill the following parameters instead of colliding with
        # ``filenames=`` (which raised TypeError before).
        ds = tf.data.TFRecordDataset(self.filenames, self.compression_type, *args, **kwargs)
        return ds.map(self.parse_dataset)

    def parse_dataset(self, x):
        """Parse one serialized ``tf.train.Example`` back into tensors laid out like ``structure``."""
        example = tf.io.parse_single_example(x, self.feature_description)
        # ``example`` has the same keys as ``structure`` (via feature_description),
        # so the flattened dtypes and serialized blobs line up.
        dtypes = tf.nest.flatten(self.structure)
        blobs = tf.nest.flatten(example)
        parsed = [tf.io.parse_tensor(blob, dtype) for dtype, blob in zip(dtypes, blobs)]
        return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=example)

    def get_config(self):
        """Return the config with dtypes replaced by their names (e.g. "float32").

        A missing or None ``structure`` is left as-is instead of crashing.
        """
        cfg = super().get_config()
        if cfg.get("structure") is not None:
            cfg["structure"] = tf.nest.map_structure(lambda dtype: dtype.name, cfg["structure"])
        return cfg

    @classmethod
    def from_config(cls, config):
        """Build an instance from a config whose ``structure`` holds dtype names.

        The caller's dict is copied, not modified. A missing or None
        ``structure`` is passed through unchanged.
        """
        config = dict(config)
        if config.get("structure") is not None:
            config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
        return super().from_config(config)
Ancestors
- Module (imported from `brevettiai`)
Static methods
def from_config(config)
-
Expand source code
@classmethod
def from_config(cls, config):
    """Build an instance from a config whose ``structure`` holds dtype names.

    The caller's dict is copied, not modified. A missing or None ``structure``
    is passed through unchanged instead of crashing in ``tf.dtypes.as_dtype``.
    """
    config = dict(config)
    if config.get("structure") is not None:
        config["structure"] = tf.nest.map_structure(tf.dtypes.as_dtype, config["structure"])
    return super().from_config(config)
Instance variables
var feature_description
-
Expand source code
@property
def feature_description(self):
    """Feature spec for ``tf.io.parse_single_example``.

    Same layout as ``structure``, with every leaf replaced by a scalar
    ``tf.string`` ``FixedLenFeature``.
    """
    n_leaves = len(tf.nest.flatten(self.structure))
    specs = [tf.io.FixedLenFeature((), tf.string)] * n_leaves
    return tf.nest.pack_sequence_as(flat_sequence=specs, structure=self.structure)
Methods
def get_config(self)
-
Expand source code
def get_config(self):
    """Return the config with dtypes replaced by their names (e.g. "float32").

    A missing or None ``structure`` is left as-is instead of crashing on
    ``None.name``.
    """
    cfg = super().get_config()
    if cfg.get("structure") is not None:
        cfg["structure"] = tf.nest.map_structure(lambda dtype: dtype.name, cfg["structure"])
    return cfg
def get_dataset(self, *args, **kwargs)
-
Expand source code
def get_dataset(self, *args, **kwargs):
    """Return a ``tf.data.Dataset`` of parsed records read from ``filenames``.

    Extra positional and keyword arguments are forwarded to
    ``tf.data.TFRecordDataset``, following ``filenames`` and ``compression_type``.

    Raises:
        ValueError: if ``structure`` is not known.
    """
    if self.structure is None:
        raise ValueError("Structure not known, set before loading dataset")
    # Positional filenames/compression_type keep extra positional args from
    # colliding with ``filenames=`` (which raised TypeError before).
    ds = tf.data.TFRecordDataset(self.filenames, self.compression_type, *args, **kwargs)
    return ds.map(self.parse_dataset)
def parse_dataset(self, x)
-
Expand source code
def parse_dataset(self, x):
    """Parse one serialized ``tf.train.Example`` back into tensors laid out like ``structure``."""
    example = tf.io.parse_single_example(x, self.feature_description)
    # ``example`` has the same keys as ``structure``, so flattened leaves line up.
    dtypes = tf.nest.flatten(self.structure)
    blobs = tf.nest.flatten(example)
    parsed = [tf.io.parse_tensor(blob, dtype) for dtype, blob in zip(dtypes, blobs)]
    return tf.nest.pack_sequence_as(flat_sequence=parsed, structure=example)
def serialize(self, value)
-
Expand source code
def serialize(self, value):
    """Convert *value* to a ``tf.train.Example``.

    If ``structure`` is not yet known, it is inferred from this value.
    """
    if self.structure is None:
        self.set_structure_from_example(value)
    features = serialize_composite_structure(value)
    return tf.train.Example(features=features)
def set_structure_from_example(self, value)
-
Expand source code
def set_structure_from_example(self, value):
    """Infer ``structure`` from an example record, store it and return it."""
    inferred = generate_dtype_structure(value)
    self.structure = inferred
    return self.structure
def write(self, value)
-
Expand source code
def write(self, value):
    """Serialize *value* and append it to the open TFRecord file.

    Raises:
        ValueError: if no writer is open (outside the context manager).
    """
    if self.writer is None:
        raise ValueError("writer not open, use context manager or __enter__ / __exit__ functions")
    payload = self.serialize(value).SerializeToString()
    self.writer.write(payload)