这是一篇对手册性质的文章,如果你刚好从事AI开发,可以参考这文章来进行模型转换。
Keras转TFLite需要三个过程,
-
Keras 转 Tensorflow
-
固化 Tensorflow 网络到 PB(Protocol Buffer)
-
PB 转 TFLite
Keras 网络构成
Keras网络有一个文件(正常情况)
-
*.h5
它是HDF5格式文件,同时保存了网络结构和网络参数。
Tensorflow 网络的构成
Tensorflow 常见的描述网络结构文件是 ckpt,它有两个文件构成
-
model.ckpt
-
model.ckpt.meta
新版本的 Tensorflow 的 Saver 会默认使用新格式保存,新格式的文件是这几个
-
model.ckpt.data-00000-of-00001
-
model.ckpt.index
-
model.ckpt.meta
Tensorflow自从开源之后就经常有改动,目前还不确定新格式的三个文件是什么作用跟含义。
就暂时以最稳定的老版本格式来解释。
-
model.ckpt
这个文件记录了神经网络上节点的权重信息,也就是节点上 wx+b 的取值。
-
model.ckpt.meta
这个文件主要记录了图结构,也就是神经网络的节点结构。
一个完整的神经网络由这两部分构成,Tensorflow 在保存时除了这两个文件还会在目录下自动生成 checkpoint,
checkpoint的内容如下,它只记录了目录下有哪些网络。
model_checkpoint_path: "squeezenet_model.ckpt"
all_model_checkpoint_paths: "squeezenet_model.ckpt"
Keras 转 Tensorflow
转换过程需要先把网络结构和权重加载到model对象,
然后用 tf.train.Saver 来保存为 ckpt 文件。
目前代码是以V1为基础的,指定Saver版本可以在构建Saver的时候指定参数
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
saver.save(K.get_session(), './squeezenet_model.ckpt')
CKPT freeze 到 PB
ckpt的网络结构和权重还是分开的
需要先固化到PB,才能继续转成 tflite。
Tensorflow 提供了python脚本用来固化,位置在
/usr/local/lib/python3.6/site-packages/tensorflow/python/tools/freeze_graph.py
对于固化的过程需要关注这几个参数
-
input_meta_graph: meta 文件,也就是节点结构
-
input_checkpoint: ckpt 文件,保存权重
-
output_graph: 输出PB文件的名称
-
output_node_names: 网络输出节点
-
input_binary: 输入文件是否为二进制
下面的命令直接给出了如何转换,对于几个参数的意义比较难理解的是倒数第二个,文章后面再给出对它的解释。
python3 freeze_graph.py \
--input_meta_graph=model.ckpt.meta \
--input_checkpoint=model.ckpt \
--output_graph=model.pb \
--output_node_names="final_result" \
--input_binary=true
PB 到 Tensorflow Lite
Tensorflow 提供了 TOCO 工具用来做转换,
必填的参数有下面这些,
toco --graph_def_file=squeezenet_model.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=model.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=final_result \
--input_sahpes=1,227,227,3
参数中需要解释的有这几个,
--input_shapes: 输入数据的维度,跟你的网络输入有关。比如1,227,227,3,代表的是1个227*227的3通道图片。
--output_arrays 和 --input_arrays:
这两个参数跟网络的输入输出有关。而 output_arrays 跟转换成 PB 时的参数 --output_node_names 是一样的。
也就是说这两个参数必须在查看网络之后才能确定