相关文章推荐
精明的大白菜  ·  200 OK - HTTP | MDN·  6 月前    · 
酷酷的作业本  ·  Xcode ...·  1 年前    · 
咆哮的梨子  ·  用 pyecharts ...·  1 年前    · 

C++ 部署深度学习模型、STB进行图像预处理

7 个月前 · 来自专栏 MegEngine 实操解析

在上一篇文章中我们完成了深度学习的训练,得到了LeNet训练权重文件。在这一章我们将使用训练权重文件导出静态图模型。并使用C++调用模型完成实际部署的模拟。上篇文章地址为


准备工作

在上一章中,我们提到有四种保存模型的方法,如下表所示。为了训练方便起见,上一章保存了checkpoint文件。但在实际部署中我们经常使用静态图模型,所以我们首先要完成静态图导出。

方法 优劣
保存/加载整个模型 任何情况都不推荐
保存加载模型状态字典 适用于推理,不满足恢复训练要求
保存加载检查点 适用于推理或恢复训练
导出静态图模型 适用于推理,追求高性能部署

导出静态图在MegEngine中有较完整的教程,请参考 导出序列化模型文件(Dump) 。主要分为三步:

  1. 将循环内的前向计算、反向传播和参数优化代码提取成单独的函数,如下面例子中的 train_func()
  2. 将网络所需输入作为训练函数的参数,并返回任意你需要的结果(如输出结果、损失函数值等);
  3. 用 jit 模块中的 trace 装饰器来装饰这个函数,将其中的代码变为静态图代码。

在上一章最后的附录train.py中有dump静态图的方法,代码如下:

from megengine import jit
def dump_mge(pkl_path = "checkpoint.pkl"):
    model = LeNet()
    check_point = megengine.load(pkl_path)
    model.load_state_dict(check_point["state_dict"])
    model.eval()
    @jit.trace(symbolic=True, capture_as_const=True)
    def infer_func(input, *, model):
        pred  = model(input)
        pred_normalized = F.softmax(pred)
        return pred_normalized
    input = megengine.Tensor(np.random.randn(1, 1, 32, 32))
    output = infer_func(input, model=model)
    infer_func.dump("./lenet.mge", arg_names=["input"])

调用dump_mge方法即可完成静态图导出。

C++推理代码

代码的主逻辑为:

  1. 创建Network
  2. 使用load_model()载入模型
  3. 使用stb预处理图片(加载和resize),然后归一化,载入进input tensor
  4. 使用network->forward()和network->wait()完成推理逻辑。
  5. 获取模型输出tensor,并对其进行处理。

推理代码为:

//inference.cpp
#include <iostream>
#include <stdlib.h>
#define STB_IMAGE_IMPLEMENTATION
#include "stb/stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb/stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb/stb_image_resize.h"
#include "lite/network.h"
#include "lite/tensor.h"
//注意在这里修改测试图片与所用模型
#define IMAGE_PATH "./test.png"
#define MODEL_PATH "./lenet.mge"
void preprocess_image(std::string pic_path, std::shared_ptr<lite::Tensor> tensor) {
    int width, height, channel;
    uint8_t* image = stbi_load(pic_path.c_str(), &width, &height, &channel, 0);
    printf("Input image %s with height=%d, width=%d, channel=%d\n", pic_path.c_str(),
           width, height, channel);
    auto layout = tensor->get_layout();
    auto pixels = layout.shapes[2] * layout.shapes[3];
    size_t image_size = width * height * channel;
    size_t gray_image_size = width * height * 1;
    unsigned char *gray_image = (unsigned char *)malloc(gray_image_size);
    for(unsigned char *p=image, *pg=gray_image; p!=image+image_size; p+=channel,pg++)
        *pg = uint8_t(*p + *(p+1) + *(p+2))/3.0;
    //! resize to tensor shape
    std::shared_ptr<std::vector<uint8_t>> resize_int8 =
            std::make_shared<std::vector<uint8_t>>(pixels * 1);
    stbir_resize_uint8(
            gray_image, width, height, 0, resize_int8->data(), layout.shapes[2],
            layout.shapes[3], 0, 1);
    free(gray_image);
    stbi_image_free(image);
    //! 减去均值,归一化
    unsigned int sum = 0;
    for(unsigned char *p=gray_image; p!=gray_image+gray_image_size;p++){
    sum += *p;
    sum /= gray_image_size;
    float* in_data = static_cast<float*>(tensor->get_memory_ptr());
    for (size_t i = 0; i < pixels; i++) {
        in_data[i] = resize_int8->at(i)-sum;     
int main()
    //创建网络
    std::shared_ptr<lite::Network> network = std::make_shared<lite::Network>();
    //加载模型
    network->load_model(MODEL_PATH);
    std::shared_ptr<lite::Tensor> input_tensor = network->get_io_tensor("input");
    preprocess_image(IMAGE_PATH, input_tensor);
    //将图片转为Tensor
    network->forward();
    network->wait();
    std::shared_ptr<lite::Tensor> output_tensor = network->get_output_tensor(0);
    float* predict_ptr = static_cast<float*>(output_tensor->get_memory_ptr());
    float max_prob = predict_ptr[0];
    size_t number = 0;
    //寻找最大的标签
    for(size_t i=0; i<10; i++)
        float cur_prob = predict_ptr[i];
        if(cur_prob>max_prob)
            max_prob = cur_prob;
            number = i;
    std::cout << "the predict number is :" << number << std::endl;
    return 0;