深度学习模型部署 C++学习经历

前言

在部署大规模深度学习应用的时候,要想满足应用需求或者压榨模型的性能,C++可能是比python更好的选择方案。基于此,特地记录最近的C++的学习经历。其实以终为始来思考为什么学习C++,首先是为了能够很好地提升模型的性能,满足应用场景中的高可用,高并发,低时延等要求。为了提升模型的性能,需要用到一些推理框架,如TensorRT、NCNN或者Openvino(本文中以TensorRT作为案例)。TensorRT在8.0以上的版本都支持Python的API了,但还是有必要学习C++。另外在模型压缩的时候也会考虑用到C++。

确定推理框架后,然后确定这个框架需要什么格式的模型?这里面可能要提到模型集大成者ONNX,因此需要学习ONNX模型。学习如何将Tensorflow、Pytorch或者keras的模型转换成ONNX,它支持什么算子,这些都是需要学习的。

最后总结这个路线是Tensorflow或pytorch模型转换成ONNX,然后ONNX对模型进行优化,转换成TensorRT模型优化以及C++的推理。

  1. ONNX模型转换和优化
  2. TensorRT转换和优化
  3. C++对ONNX模型的推理和TensorRT推理实现。

ONNX

关于ONNX的转换可以参考我的git仓库:onnx模型转换。当中包括Tensorflow和Pytorch的模型转换Demo。同时包括用onnxruntime进行推理的过程。在这一块基本的转换过程已经转换了,后续需要更加深化,了解支持的算子,如何转换复杂模型,甚至如何写算子等都要学会。

C++

C++需要学习基础知识,这些都不在话下了。看下书,学习视频,以下用一个C++调用ONNX模型推理作为例子,具体代码可以参考: ONNX C++

在上面博客中使用的是OpenCV自带的DNN推理框架,有时间比较下各种推理框架的优势与劣势。YOLOX的推理主要完成三个步骤:模型加载、图像预处理以及结果后处理。定义如下的头文件:

ONNX 推理头文件:

#include <assert.h>
#include<onnxruntime_cxx_api.h>
#include<ctime>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/videoio.hpp>
#include <opencv2/highgui.hpp>
class yoloxmodelinference {
public:
	yoloxmodelinference(const wchar_t* onnx_model_path);
	float* predict_test(std::vector<float>input_tensor_values, int batch_size = 1);
	cv::Mat predict(cv::Mat& input_tensor, int batch_size = 1, int index = 0);
	std::vector<float> predict(std::vector<float>& input_data, int batch_size = 1, int index = 0);
private:
	Ort::Env env;
	Ort::Session session;
	Ort::AllocatorWithDefaultOptions allocator;
	std::vector<const char*>input_node_names;
	std::vector<const char*>output_node_names;
	std::vector<int64_t> input_node_dims;
	std::vector<int64_t> output_node_dims;
	std::size_t num_output_nodes;
	std::size_t num_input_nodes;
	const int netWidth = 640;
	const int netHeight = 640;
	const int strideSize = 3;//stride size
	float boxThreshold = 0.25;
#endif // !yoloxmodel

DNN 推理的头文件:

#pragma once
#include<iostream>
#include<opencv2/opencv.hpp>
struct Output {
	int id;
	//置信度
	float confidence;
	//矩形框
	cv::Rect box;
class YOLO {
public:
	YOLO() {
	~YOLO(){}
	bool initModel(cv::dnn::Net& net, std::string& netPath, bool isCuda);
	std::vector<Output>& Detect(cv::Mat& image, cv::dnn::Net& net);
private:
	//网络输入的shape
	const int netWidth = 640;   //ONNX图片输入宽度
	const int netHeight = 640;  //ONNX图片输入高度
	const int strideSize = 3;   //stride size
	float boxThreshold = 0.25;
	float classThreshold = 0.25;
	float nmsThreshold = 0.45;
	float nmsScoreThreshold = boxThreshold * classThreshold;
};

头文件可以看作是一个“配置文件”,里面声明函数和一些固定的参数。

DNN的读取文件非常简单,如下所示。同时也非常清晰看到DNN是可以使用cuda的。

bool YOLO::initModel(Net& net, string& netPath, bool isCuda)
	try {
		net = readNet(netPath);
	catch (const exception& e) {
		cout << e.what() << std::endl;
		return false;
	//cuda
	//if (isCuda) {
	//	net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
	//	net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
	net.setPreferableBackend(DNN_BACKEND_DEFAULT);
	net.setPreferableTarget(DNN_TARGET_CPU);
	return true;
}

ONNX 会比较麻烦一点,说麻烦一点,其实是没有认真学习当中的 API

yoloxmodelinference::yoloxmodelinference(const wchar_t* onnx_model_path):session(nullptr), env(nullptr) {
    //初始化环境,每个进程一个环境,环境保留了线程池和其他状态信息
    this->env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "yolox");
    //初始化Session选项
    Ort::SessionOptions session_options;
    session_options.SetInterOpNumThreads(1);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    // 创建Session并把模型加载到内存中
    this->session = Ort::Session(env, onnx_model_path, session_options);
    //输入输出节点数量和名称
    this->num_input_nodes = session.GetInputCount();
    this->num_output_nodes = session.GetOutputCount();
    for (int i = 0; i < this->num_input_nodes; i++)
        auto input_node_name = session.GetInputName(i, allocator);
        this->input_node_names.push_back(input_node_name);
        Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        ONNXTensorElementDataType type = tensor_info.GetElementType();
        this->input_node_dims = tensor_info.GetShape();