Module brevettiai.data.tf_types
Classes that are serializable to tuples so they can be used in TensorFlow datasets
Expand source code
"""
Classes that are serializable to tuples so they can be used in TensorFlow datasets
"""
from dataclasses import dataclass
import tensorflow as tf
@dataclass(frozen=True, order=True)
class TfRange:
    """
    Immutable, orderable ``(start, end)`` pair used for slicing tensors.

    An ``end`` of 0 is treated as "no upper bound": slicing then runs from
    ``start`` to the end of the sequence.
    """
    start: int = 0
    end: int = 0

    def slice(self, sequence):
        """
        Return the part of ``sequence`` selected by this range.

        Both bounds are cast to ``tf.int32`` before indexing. When ``end``
        is 0 only the lower bound is applied.
        """
        lower = tf.cast(self.start, tf.int32)
        if self.end == 0:
            # end == 0 is the open-ended marker, not an empty slice
            return sequence[lower:]
        upper = tf.cast(self.end, tf.int32)
        return sequence[lower:upper]

    @classmethod
    def build(cls, x):
        """Create a range from the first two items of the indexable ``x``."""
        first, second = x[0], x[1]
        return cls(first, second)

    @classmethod
    def build_single_frame(cls, frame):
        """Create a range ``(frame, frame + 1)`` after converting ``frame`` with ``int``."""
        index = int(frame)
        return cls(index, index + 1)

    def __iter__(self):
        """Yield ``start`` and then ``end``, so the range unpacks like a 2-tuple."""
        yield self.start
        yield self.end

    def __str__(self):
        """Render as ``SequenceRange(start, end)``, showing ``None`` for an ``end`` of 0."""
        shown_end = None if self.end == 0 else self.end
        return f"SequenceRange{(self.start, shown_end)}"
Classes
class TfRange (start: int = 0, end: int = 0)
-
An object for slicing tensors
Expand source code
class TfRange:
    """
    ``(start, end)`` pair used for slicing tensors.

    An ``end`` of 0 is treated as "no upper bound": slicing then runs from
    ``start`` to the end of the sequence.
    """
    start: int = 0
    end: int = 0

    def slice(self, sequence):
        """
        Return the part of ``sequence`` selected by this range.

        Both bounds are cast to ``tf.int32`` before indexing. When ``end``
        is 0 only the lower bound is applied.
        """
        lower = tf.cast(self.start, tf.int32)
        if self.end == 0:
            # end == 0 is the open-ended marker, not an empty slice
            return sequence[lower:]
        upper = tf.cast(self.end, tf.int32)
        return sequence[lower:upper]

    @classmethod
    def build(cls, x):
        """Create a range from the first two items of the indexable ``x``."""
        first, second = x[0], x[1]
        return cls(first, second)

    @classmethod
    def build_single_frame(cls, frame):
        """Create a range ``(frame, frame + 1)`` after converting ``frame`` with ``int``."""
        index = int(frame)
        return cls(index, index + 1)

    def __iter__(self):
        """Yield ``start`` and then ``end``, so the range unpacks like a 2-tuple."""
        yield self.start
        yield self.end

    def __str__(self):
        """Render as ``SequenceRange(start, end)``, showing ``None`` for an ``end`` of 0."""
        shown_end = None if self.end == 0 else self.end
        return f"SequenceRange{(self.start, shown_end)}"
Class variables
var end : int
var start : int
Static methods
def build(x)
-
Expand source code
@classmethod def build(cls, x): return cls(x[0], x[1])
def build_single_frame(frame)
-
Expand source code
@classmethod def build_single_frame(cls, frame): frame = int(frame) return cls(frame, frame+1)
Methods
def slice(self, sequence)
-
Expand source code
def slice(self, sequence):
    """
    Return the part of ``sequence`` selected by this range.

    Both bounds are cast to ``tf.int32`` before indexing. When ``end``
    is 0 only the lower bound is applied.
    """
    lower = tf.cast(self.start, tf.int32)
    if self.end == 0:
        # end == 0 is the open-ended marker, not an empty slice
        return sequence[lower:]
    upper = tf.cast(self.end, tf.int32)
    return sequence[lower:upper]