Module brevettiai.io.onnx

Expand source code
import json
import tensorflow as tf
import tf2onnx
from tf2onnx import optimizer
from brevettiai.utils.profiling import profile_graph
from tensorflow.python.keras.saving import saving_utils as _saving_utils
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def input_output_quantization(function, dtype=tf.uint8, output_scaling=255):
    """Wrap a traced model function so it takes and returns quantized tensors.

    The wrapper's input signature copies the shapes and names of
    ``function.input_signature`` but uses ``dtype``. At call time the input is
    cast to float32 (without rescaling) and passed to ``function``. Each output
    is then multiplied by ``output_scaling`` and cast to ``dtype``. Rank-4
    outputs are additionally transposed with permutation ``[0, 3, 1, 2]``;
    other ranks are returned untransposed.

    Args:
        function: A ``tf.function`` with an ``input_signature`` describing a
            single tensor argument, returning a dict of output tensors.
        dtype: Dtype of the wrapper's input and of every output.
        output_scaling: Factor applied to each output before the cast.

    Returns:
        A ``tf.function`` returning a dict with the same keys as ``function``.
    """
    dtype = tf.as_dtype(dtype)

    def _quantize(v):
        # Scale and cast a single output tensor to ``dtype``.
        v = v * output_scaling
        if dtype.is_integer:
            # Saturate: casting out-of-range floats to an integer dtype is not
            # well defined (values may wrap), which corrupts quantized outputs.
            v = tf.clip_by_value(v, dtype.min, dtype.max)
        v = tf.cast(v, dtype=dtype)
        # Only rank-4 (NHWC) tensors can take the 4-element NCHW permutation.
        if v.shape.rank == 4:
            v = tf.transpose(v, [0, 3, 1, 2])
        return v

    @tf.function(input_signature=[tf.TensorSpec(shape=s.shape, dtype=dtype, name=s.name)
                                  for s in function.input_signature])
    def _input_output_quant_wrapper(x):
        x = tf.cast(x, tf.float32)
        yhat = function(x)
        return {k: _quantize(v) for k, v in yhat.items()}

    return _input_output_quant_wrapper


def input_as_nchw(function):
    """Adapt an NHWC model function so it accepts NCHW input.

    Each spec of ``function.input_signature`` is re-declared with its
    dimensions reordered as (0, 3, 1, 2). When called, the NCHW input is
    transposed back to NHWC before being fed to ``function``. Rank-4 entries of
    the returned dict are transposed to NCHW; all other entries pass through
    unchanged.
    """
    nhwc_to_nchw = [0, 3, 1, 2]
    nchw_to_nhwc = [0, 2, 3, 1]

    nchw_signature = []
    for spec in function.input_signature:
        reordered = tuple(spec.shape[axis] for axis in nhwc_to_nchw)
        nchw_signature.append(tf.TensorSpec(shape=reordered, dtype=spec.dtype, name=spec.name))

    @tf.function(input_signature=nchw_signature)
    def _input_output_wrapper(x):
        outputs = function(tf.transpose(x, nchw_to_nhwc))
        converted = {}
        for key, value in outputs.items():
            if len(value.shape) == 4:
                converted[key] = tf.transpose(value, nhwc_to_nchw)
            else:
                converted[key] = value
        return converted

    return _input_output_wrapper


def export_model(model, output_file=None, inputs_as_nchw: (list, bool) = None, shape_override=None,
                 meta_data: dict = None):
    """Convert a Keras model to an ONNX model (opset 11).

    The model is traced and its variables are frozen into constants. The
    resulting TensorFlow graph is profiled for float operations, converted
    with tf2onnx and optimized. ``meta_data`` entries are stored as ONNX
    ``metadata_props``.

    Args:
        model: Keras model to export. ``model.name`` is embedded in the ONNX
            producer string. ``model.input_spec`` supplies dtype and name for
            each input when ``shape_override`` is given.
        output_file: Optional path. When truthy, the ONNX protobuf is written
            there with ``tf2onnx.utils.save_protobuf``.
        inputs_as_nchw: When truthy (``True`` or a non-empty list), the model is
            wrapped with ``input_as_nchw``, so ALL inputs are taken as NCHW and
            rank-4 outputs are returned as NCHW. A list is only checked for
            truthiness; the names in it are not used individually.
        shape_override: Sequence of input shapes, zipped in order with
            ``model.input_spec`` to build the traced input signature. Unset
            spec dtypes default to ``tf.float32``.
        meta_data: A dict, or an object exposing ``.json()`` that returns a
            JSON string (presumably a pydantic model; confirm with callers).
            Non-string values are JSON-encoded. ``"total_float_ops"`` is always
            added. The caller's dict is not modified.

    Returns:
        ``output_file`` when it is truthy, otherwise the ONNX ``ModelProto``.

    Raises:
        ValueError: If the traced function could not be converted to a GraphDef.
    """
    meta_data = meta_data or {}

    # Ensure metadata is json serializable. A dict is copied so that adding
    # "total_float_ops" below does not leak into the caller's object.
    if isinstance(meta_data, dict):
        meta_data = dict(meta_data)
    else:
        meta_data = json.loads(meta_data.json())

    # Create graph representation
    if shape_override:
        input_signature = [tf.TensorSpec(sh, dtype=spec.dtype or tf.float32, name=spec.name)
                           for sh, spec in zip(shape_override, model.input_spec)]
    else:
        input_signature = None

    function = _saving_utils.trace_model_call(model, input_signature)

    # NOTE: input_output_quantization(function, dtype=tf.uint8, output_scaling=255)
    # can be applied here to export uint8 I/O; it is currently disabled.
    if inputs_as_nchw:
        # The layout change is done in TensorFlow before freezing, so tf2onnx's
        # own inputs_as_nchw option is not passed to process_tf_graph below.
        function = input_as_nchw(function)

    # Get concrete function and freeze variables into constants
    concrete_func = function.get_concrete_function()
    concrete_func = convert_variables_to_constants_v2(concrete_func,
                                                      lower_control_flow=False,
                                                      aggressive_inlining=True)

    # Resource-typed tensors (variable handles) are not real graph inputs/outputs
    input_names = [input_tensor.name for input_tensor in concrete_func.inputs
                   if input_tensor.dtype != tf.dtypes.resource]
    output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                    if output_tensor.dtype != tf.dtypes.resource]

    graph_def = tf2onnx.tf_loader.from_function(concrete_func, input_names, output_names)
    if graph_def is None:
        # Fail explicitly instead of continuing without a converted graph.
        raise ValueError("Could not convert model {!r} to a TensorFlow GraphDef".format(model.name))

    # Convert graph to ONNX
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')

        flops = profile_graph(tf_graph)
        meta_data["total_float_ops"] = flops.total_float_ops

        # Shape overrides are already baked into the traced input signature,
        # so tf2onnx's shape_override option is not passed either.
        with tf2onnx.tf_loader.tf_session(graph=tf_graph):
            g = tf2onnx.tfonnx.process_tf_graph(tf_graph,
                                                opset=11,
                                                input_names=input_names,
                                                output_names=output_names,
                                                const_node_values=None,
                                                initialized_tables=None)

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model("converted from {}".format(model.name), external_tensor_storage=None)

    # ONNX metadata values are strings; JSON-encode anything else
    for key, value in meta_data.items():
        meta = model_proto.metadata_props.add()
        meta.key = key
        meta.value = value if isinstance(value, str) else json.dumps(value)

    # Export graph
    if output_file:
        tf2onnx.utils.save_protobuf(output_file, model_proto)
        return output_file
    return model_proto

Functions

def export_model(model, output_file=None, inputs_as_nchw: (list, bool) = None, shape_override=None, meta_data: dict = None)
Expand source code
def export_model(model, output_file=None, inputs_as_nchw: (list, bool) = None, shape_override=None,
                 meta_data: dict = None):
    """Convert a Keras model to an ONNX model (opset 11).

    The model is traced and its variables are frozen into constants. The
    resulting TensorFlow graph is profiled for float operations, converted
    with tf2onnx and optimized. ``meta_data`` entries are stored as ONNX
    ``metadata_props``.

    Args:
        model: Keras model to export. ``model.name`` is embedded in the ONNX
            producer string. ``model.input_spec`` supplies dtype and name for
            each input when ``shape_override`` is given.
        output_file: Optional path. When truthy, the ONNX protobuf is written
            there with ``tf2onnx.utils.save_protobuf``.
        inputs_as_nchw: When truthy (``True`` or a non-empty list), the model is
            wrapped with ``input_as_nchw``, so ALL inputs are taken as NCHW and
            rank-4 outputs are returned as NCHW. A list is only checked for
            truthiness; the names in it are not used individually.
        shape_override: Sequence of input shapes, zipped in order with
            ``model.input_spec`` to build the traced input signature. Unset
            spec dtypes default to ``tf.float32``.
        meta_data: A dict, or an object exposing ``.json()`` that returns a
            JSON string (presumably a pydantic model; confirm with callers).
            Non-string values are JSON-encoded. ``"total_float_ops"`` is always
            added. The caller's dict is not modified.

    Returns:
        ``output_file`` when it is truthy, otherwise the ONNX ``ModelProto``.

    Raises:
        ValueError: If the traced function could not be converted to a GraphDef.
    """
    meta_data = meta_data or {}

    # Ensure metadata is json serializable. A dict is copied so that adding
    # "total_float_ops" below does not leak into the caller's object.
    if isinstance(meta_data, dict):
        meta_data = dict(meta_data)
    else:
        meta_data = json.loads(meta_data.json())

    # Create graph representation
    if shape_override:
        input_signature = [tf.TensorSpec(sh, dtype=spec.dtype or tf.float32, name=spec.name)
                           for sh, spec in zip(shape_override, model.input_spec)]
    else:
        input_signature = None

    function = _saving_utils.trace_model_call(model, input_signature)

    # NOTE: input_output_quantization(function, dtype=tf.uint8, output_scaling=255)
    # can be applied here to export uint8 I/O; it is currently disabled.
    if inputs_as_nchw:
        # The layout change is done in TensorFlow before freezing, so tf2onnx's
        # own inputs_as_nchw option is not passed to process_tf_graph below.
        function = input_as_nchw(function)

    # Get concrete function and freeze variables into constants
    concrete_func = function.get_concrete_function()
    concrete_func = convert_variables_to_constants_v2(concrete_func,
                                                      lower_control_flow=False,
                                                      aggressive_inlining=True)

    # Resource-typed tensors (variable handles) are not real graph inputs/outputs
    input_names = [input_tensor.name for input_tensor in concrete_func.inputs
                   if input_tensor.dtype != tf.dtypes.resource]
    output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                    if output_tensor.dtype != tf.dtypes.resource]

    graph_def = tf2onnx.tf_loader.from_function(concrete_func, input_names, output_names)
    if graph_def is None:
        # Fail explicitly instead of continuing without a converted graph.
        raise ValueError("Could not convert model {!r} to a TensorFlow GraphDef".format(model.name))

    # Convert graph to ONNX
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')

        flops = profile_graph(tf_graph)
        meta_data["total_float_ops"] = flops.total_float_ops

        # Shape overrides are already baked into the traced input signature,
        # so tf2onnx's shape_override option is not passed either.
        with tf2onnx.tf_loader.tf_session(graph=tf_graph):
            g = tf2onnx.tfonnx.process_tf_graph(tf_graph,
                                                opset=11,
                                                input_names=input_names,
                                                output_names=output_names,
                                                const_node_values=None,
                                                initialized_tables=None)

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model("converted from {}".format(model.name), external_tensor_storage=None)

    # ONNX metadata values are strings; JSON-encode anything else
    for key, value in meta_data.items():
        meta = model_proto.metadata_props.add()
        meta.key = key
        meta.value = value if isinstance(value, str) else json.dumps(value)

    # Export graph
    if output_file:
        tf2onnx.utils.save_protobuf(output_file, model_proto)
        return output_file
    return model_proto
def input_as_nchw(function)
Expand source code
def input_as_nchw(function):
    """Adapt an NHWC model function so it accepts NCHW input.

    Each spec of ``function.input_signature`` is re-declared with its
    dimensions reordered as (0, 3, 1, 2). When called, the NCHW input is
    transposed back to NHWC before being fed to ``function``. Rank-4 entries of
    the returned dict are transposed to NCHW; all other entries pass through
    unchanged.
    """
    nhwc_to_nchw = [0, 3, 1, 2]
    nchw_to_nhwc = [0, 2, 3, 1]

    nchw_signature = []
    for spec in function.input_signature:
        reordered = tuple(spec.shape[axis] for axis in nhwc_to_nchw)
        nchw_signature.append(tf.TensorSpec(shape=reordered, dtype=spec.dtype, name=spec.name))

    @tf.function(input_signature=nchw_signature)
    def _input_output_wrapper(x):
        outputs = function(tf.transpose(x, nchw_to_nhwc))
        converted = {}
        for key, value in outputs.items():
            if len(value.shape) == 4:
                converted[key] = tf.transpose(value, nhwc_to_nchw)
            else:
                converted[key] = value
        return converted

    return _input_output_wrapper
def input_output_quantization(function, dtype=tf.uint8, output_scaling=255)
Expand source code
def input_output_quantization(function, dtype=tf.uint8, output_scaling=255):
    """Wrap a traced model function so it takes and returns quantized tensors.

    The wrapper's input signature copies the shapes and names of
    ``function.input_signature`` but uses ``dtype``. At call time the input is
    cast to float32 (without rescaling) and passed to ``function``. Each output
    is then multiplied by ``output_scaling`` and cast to ``dtype``. Rank-4
    outputs are additionally transposed with permutation ``[0, 3, 1, 2]``;
    other ranks are returned untransposed.

    Args:
        function: A ``tf.function`` with an ``input_signature`` describing a
            single tensor argument, returning a dict of output tensors.
        dtype: Dtype of the wrapper's input and of every output.
        output_scaling: Factor applied to each output before the cast.

    Returns:
        A ``tf.function`` returning a dict with the same keys as ``function``.
    """
    dtype = tf.as_dtype(dtype)

    def _quantize(v):
        # Scale and cast a single output tensor to ``dtype``.
        v = v * output_scaling
        if dtype.is_integer:
            # Saturate: casting out-of-range floats to an integer dtype is not
            # well defined (values may wrap), which corrupts quantized outputs.
            v = tf.clip_by_value(v, dtype.min, dtype.max)
        v = tf.cast(v, dtype=dtype)
        # Only rank-4 (NHWC) tensors can take the 4-element NCHW permutation.
        if v.shape.rank == 4:
            v = tf.transpose(v, [0, 3, 1, 2])
        return v

    @tf.function(input_signature=[tf.TensorSpec(shape=s.shape, dtype=dtype, name=s.name)
                                  for s in function.input_signature])
    def _input_output_quant_wrapper(x):
        x = tf.cast(x, tf.float32)
        yhat = function(x)
        return {k: _quantize(v) for k, v in yhat.items()}

    return _input_output_quant_wrapper