PyTorch to TFLite in practice

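The world is full of odd little headaches, so we end up doing all sorts of conversions, like today's: a PyTorch model that has to be turned into TFLite. Below, using the pytorch-ssd model as an example, I walk through a PyTorch-to-TFLite conversion.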

  • Converting the pth model to ONNX
    The first step converts the model saved with torch.save() into an ONNX model. The code is as follows:
    import torch
    from vision.ssd.mobilenet_v3_ssd_lite import create_mobilenetv3_ssd_lite
    # Build the network and load the trained weights (a state_dict saved with torch.save())
    model = create_mobilenetv3_ssd_lite(num_classes=2)
    model.load_state_dict(torch.load("Epoch-85-Loss-0.4889--Epoch-45-Loss-0.4090.pth", map_location='cpu'))
    model.eval()  # inference mode, so BatchNorm/Dropout are exported with their eval behavior
    # Fixed-size NCHW dummy input used to trace the graph
    dummy_input = torch.randn(1, 3, 300, 300)
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(model, dummy_input, "ssd_Epoch-45.onnx", verbose=True,
                      input_names=input_names, output_names=output_names, opset_version=11)
    
  • Converting ONNX to a TensorFlow pb model
    The second step converts the ONNX model into a TensorFlow pb model with onnx-tensorflow:
    git clone https://github.com/onnx/onnx-tensorflow.git
    cd onnx-tensorflow
    git checkout v1.6.0-tf-1.15
    pip install -e .
    onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb
    

    After step two, the pb model has been generated.
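To check whether the pb produced in step two still contains NCHW ops, one can scan its GraphDef for nodes whose `data_format` attribute is `NCHW`. The sketch below assumes TensorFlow 1.15 or `tf.compat.v1`, and `find_nchw_ops` is a hypothetical helper name. So that it runs without the real file, it builds a tiny NCHW graph in memory (never executed, since CPU kernels may reject NCHW). For the real model you would parse output.pb into a `GraphDef` first, as the step-three script does.

```python
import numpy as np
import tensorflow as tf

if not tf.__version__.startswith('1'):
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()


def find_nchw_ops(graph_def):
    """Hypothetical helper: (name, op type) of every node whose data_format attr is NCHW."""
    hits = []
    for node in graph_def.node:
        if "data_format" in node.attr and node.attr["data_format"].s == b"NCHW":
            hits.append((node.name, node.op))
    return hits


# Stand-in graph instead of output.pb: NCHW conv + max-pool, built but never run
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, [1, 3, 32, 32], name="input")
    filt = tf.constant(np.zeros((3, 3, 3, 8), np.float32))  # HWIO filter
    y = tf.nn.conv2d(x, filt, strides=[1, 1, 1, 1], padding="SAME", data_format="NCHW")
    tf.nn.max_pool(y, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2], padding="VALID", data_format="NCHW")

hits = find_nchw_ops(g.as_graph_def())
for name, op_type in hits:
    print(op_type, name)
```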

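  • Converting the NCHW pb model to an NHWC pb model
    Both the pth and ONNX models use the NCHW layout, and converting to pb does not change it, whereas TFLite and TensorFlow models use the NHWC layout. So one more step is needed to convert the NCHW pb model into an NHWC one. The idea is simply to insert Transpose ops: each Conv2D, MaxPool and AvgPool node is replaced by an NHWC version whose input is transposed from NCHW to NHWC (perm [0, 2, 3, 1]) and whose output is transposed back (perm [0, 3, 1, 2]), with the strides and ksize attributes permuted to match. The code is as follows: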
    import tensorflow as tf
    if not tf.__version__.startswith('1'):
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()  # graph mode and TF1-style Tensor comparison, which the code below relies on
    graph_def_file = "../output.pb"  # path to the pb from step two
    tf.reset_default_graph()
    graph_def = tf.GraphDef()
    with tf.Session() as sess:
        # Read binary pb graph from file
        with tf.gfile.Open(graph_def_file, "rb") as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        # Collect the ops that run in NCHW
        conv_nodes = [n for n in sess.graph.get_operations() if n.type in ['Conv2D', 'MaxPool', 'AvgPool']]
        for n_org in conv_nodes:
            # Transpose input NCHW -> NHWC
            assert len(n_org.inputs) in (1, 2)
            inp_tens = tf.transpose(n_org.inputs[0], [0, 2, 3, 1], name=n_org.name + '_transp_input')
            op_inputs = [inp_tens]
            # Conv2D filters are HWIO in both layouts, so they are not transposed
            if n_org.type == 'Conv2D':
                op_inputs.append(n_org.inputs[1])
            # Copy attributes except data_format; NHWC is the default
            atts = {key: n_org.node_def.attr[key] for key in list(n_org.node_def.attr.keys()) if key != 'data_format'}
            # Permute per-dimension attributes from NCHW order to NHWC order
            perm_keys = ['strides']
            if n_org.type == 'Conv2D' and 'dilations' in atts:
                perm_keys.append('dilations')
            if n_org.type in ['MaxPool', 'AvgPool']:
                perm_keys.append('ksize')
            for key in perm_keys:
                st = atts[key].list.i
                atts[key] = tf.AttrValue(list=tf.AttrValue.ListValue(i=[st[0], st[2], st[3], st[1]]))
            # Create the NHWC op and transpose its output back NHWC -> NCHW
            op = sess.graph.create_op(op_type=n_org.type, inputs=op_inputs, name=n_org.name + '_new',
                                      dtypes=[tf.float32], attrs=atts)
            out_trans = tf.transpose(op.outputs[0], [0, 3, 1, 2], name=n_org.name + '_transp_out')
            assert out_trans.shape.is_compatible_with(n_org.outputs[0].shape)
            # Rewire every consumer of the old output to the transposed new output
            out_nodes = [n for n in sess.graph.get_operations() if n_org.outputs[0] in n.inputs]
            for out in out_nodes:
                for j, inp in enumerate(out.inputs):
                    if inp == n_org.outputs[0]:
                        out._update_input(j, out_trans)  # private TF API
        # Delete old nodes
        graph_def = sess.graph.as_graph_def()
        for on in conv_nodes:
            graph_def.node.remove(on.node_def)
        # Write graph
        tf.io.write_graph(graph_def, "", graph_def_file.rsplit('.', 1)[0] + '_toco.pb', as_text=False)
    

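    Step three produces output_toco.pb, the NHWC pb model: its convolution and pooling ops now run in NHWC, while the graph's input and output keep their NCHW shapes.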
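Why the rewrite preserves the results: a convolution computed in NCHW gives the same numbers as transposing the input to NHWC, convolving in NHWC with the strides permuted as [st[0], st[2], st[3], st[1]], and transposing the result back with [0, 3, 1, 2]. The filter is shared unchanged because TensorFlow's Conv2D filter layout is [height, width, in_channels, out_channels] in both data formats. The NumPy sketch below checks this with naive VALID-padding convolutions; `conv2d_nchw` and `conv2d_nhwc` are hypothetical illustration helpers, not TensorFlow code.

```python
import numpy as np


def conv2d_nhwc(x, w, strides):
    """Naive VALID convolution. x: [N, H, W, C], w: [KH, KW, C, O], strides: [1, sh, sw, 1]."""
    _, sh, sw, _ = strides
    n, h, wd, _ = x.shape
    kh, kw, _, o = w.shape
    oh, ow = (h - kh) // sh + 1, (wd - kw) // sw + 1
    out = np.zeros((n, oh, ow, o), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]  # [N, KH, KW, C]
            out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out


def conv2d_nchw(x, w, strides):
    """Same convolution computed directly in NCHW. x: [N, C, H, W], strides: [1, 1, sh, sw]."""
    _, _, sh, sw = strides
    n, _, h, wd = x.shape
    kh, kw, _, o = w.shape
    oh, ow = (h - kh) // sh + 1, (wd - kw) // sw + 1
    out = np.zeros((n, o, oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            patch = x[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]  # [N, C, KH, KW]
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [2, 0, 1]))
    return out


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((1, 3, 9, 9))
w = rng.standard_normal((3, 3, 3, 4))  # [KH, KW, C_in, C_out], shared by both paths

strides_nchw = [1, 1, 2, 2]
st = strides_nchw
strides_nhwc = [st[0], st[2], st[3], st[1]]  # same permutation as in the rewrite script

reference = conv2d_nchw(x_nchw, w, strides_nchw)

x_nhwc = np.transpose(x_nchw, [0, 2, 3, 1])     # like the '_transp_input' node
y_nhwc = conv2d_nhwc(x_nhwc, w, strides_nhwc)   # like the NHWC '_new' op
rewritten = np.transpose(y_nhwc, [0, 3, 1, 2])  # like the '_transp_out' node

print(strides_nhwc, reference.shape, np.allclose(reference, rewritten))
```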

  • Converting the NHWC pb model to a TFLite model