深度学习部署过程中,有时候需要部署多个模型,每个模型的 接口相近 ,对于常规C++代码,每个模型对应一个类,调用过程中需要 手动 对每个类进行操作,对于程序员来讲,这种不智能的方法难以接收,学过java, python的人应该对 反射机制 有所了解,无奈C++暂还不支持这种方式,实属C++的一个缺陷吧!

但方法总比困难多,反射机制不支持,用别的方式来达到类似的目标,主要采用 设计模型 中的 单例+工厂模式 ,实现 字符串创建类 的方式。此外,由于多种模型接口相似,可以使用基类写法, 其他类继承该基类 ,实现自动操作。

二、原理分析

反射是程序可以访问、检测和修改它本身状态或行为的一种能力,简单地说, 通过字符串来创建类。

第一步:创建单例模板,用于单例工厂的创建

// 单例类模板
template<typename T>
class Singleton
public:
    static T* GetInstance()
        static T instance;
        return &instance;
    Singleton(T&&) = delete;
    Singleton(const T&) = delete;
    void operator= (const T&) = delete;
protected:
    Singleton() = default;
    virtual ~Singleton() = default;

第二步:定义函数指针类型:用于指向创建类实例的回调函数

using CreateObjectFunc = function<void*()>;

第三步:创建工厂类实现类与字符串的映射关系

// 创建对象的回调函数
struct CreateObjectFuncClass {
    explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
    CreateObjectFunc create_func;
// Object工厂类
class ObjectFactory : public Singleton<ObjectFactory> {
public:
    // 返回void *减少了代码的耦合
		// 提供给外部注册以及类创建
    void* CreateObject(const string& class_name) {
        CreateObjectFunc createobj = nullptr;
        if (create_funcs_.find(class_name) != create_funcs_.end())
            createobj = create_funcs_.find(class_name)->second->create_func;
        if (createobj == nullptr)
            return nullptr;
        // 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
        return createobj();
		// 保存类名字符串到类对象构造函数指针的映射
    void RegisterObject(const string& class_name, CreateObjectFunc func) {
        auto it = create_funcs_.find(class_name);
        if (it != create_funcs_.end())
            create_funcs_[class_name]->create_func = func;
            create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
    ~ObjectFactory() {
        for (auto it : create_funcs_)
            if (it.second != nullptr)
                delete it.second;
                it.second = nullptr;
        create_funcs_.clear();
private:
    // 缓存类名和生成类实例函数指针的map
    unordered_map<string, CreateObjectFuncClass* > create_funcs_;

第四步:宏定义,方便注册

#define REGISTERCLASS(className) \
class className##Helper { \
public: \
    className##Helper() \
        ObjectFactory::GetInstance()->RegisterObject(#className, []() \
            auto* obj = new className(); \
						// 这个可以指定默认的执行的函数
            // obj->SetModelName(#className); \
            return obj; \
        }); \
className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。 
  • 宏定义中 #:转为字符串##:连接两个字符串\:代码换行符
  • 第五步:定义各个模型基类,其他模型都继承该基类

    class BasicNet{
    public:
    		//构造函数至少要这个,因为register时候使用
    		BasicNet() {}
    		bool LoadModel(const std::string modelPath){
    			//code
    			return true;
    		void SetModelName(std::string modelName) {_modelNname=modelName;}
    		virtual ~BasicNet(){}
    		std::string GetModelName() const {return _modelNname;}
    protected:
    		std::string _modelNname;
    class AlexNet:public BasicNet{
    public:
    		~AlexNet(){}
    class LeNet:public BasicNet{
    public:
    		~LeNet(){}
    

    第六步:测试结果

    int main() {
    	REGISTERCLASS(AlexNet)
    	REGISTERCLASS(LeNet)
    	std::vector<std::string> models{"AlexNet","LeNet"};
    	for(auto model:models){
    		auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
    		alexNet->SetModelName(model);
    		std::string name = alexNet->GetModelName();
        	cout << name.c_str() << endl;
        	delete alexNet;
        return 0;
    ///////////////////// 结果
    AlexNet
    LeNet
    

    三、代码详解

    #include <iostream>
    #include <unordered_map>
    #include <functional>
    #include <vector>
    using namespace std;
    // 单例类模板
    template<typename T>
    class Singleton
    public:
        static T* GetInstance()
            static T instance;
            return &instance;
        Singleton(T&&) = delete;
        Singleton(const T&) = delete;
        void operator= (const T&) = delete;
    protected:
        Singleton() = default;
        virtual ~Singleton() = default;
    using CreateObjectFunc = function<void*()>;
    // 创建对象的回调函数
    struct CreateObjectFuncClass {
        explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
        CreateObjectFunc create_func;
    // Object工厂类
    class ObjectFactory : public Singleton<ObjectFactory> {
    public:
        // 返回void *减少了代码的耦合
    		// 提供给外部注册以及类创建
        void* CreateObject(const string& class_name) {
            CreateObjectFunc createobj = nullptr;
            if (create_funcs_.find(class_name) != create_funcs_.end())
                createobj = create_funcs_.find(class_name)->second->create_func;
            if (createobj == nullptr)
                return nullptr;
            // 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
            return createobj();
    		// 保存类名字符串到类对象构造函数指针的映射
        void RegisterObject(const string& class_name, CreateObjectFunc func) {
            auto it = create_funcs_.find(class_name);
            if (it != create_funcs_.end())
                create_funcs_[class_name]->create_func = func;
                create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
        ~ObjectFactory() {
            for (auto it : create_funcs_)
                if (it.second != nullptr)
                    delete it.second;
                    it.second = nullptr;
            create_funcs_.clear();
    private:
        // 缓存类名和生成类实例函数指针的map
        unordered_map<string, CreateObjectFuncClass* > create_funcs_;
    #define REGISTERCLASS(className) \
    class className##Helper { \
    public: \
        className##Helper() \
            ObjectFactory::GetInstance()->RegisterObject(#className, []() \
                auto* obj = new className(); \
                return obj; \
            }); \
    className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。 
    class BasicNet{
    public:
    		BasicNet() {}
    		bool LoadModel(const std::string modelPath){
    			//code
    			return true;
    		void SetModelName(std::string modelName) {_modelNname=modelName;}
    		virtual ~BasicNet(){}
    		std::string GetModelName() const {return _modelNname;}
    protected:
    		std::string _modelNname;
    class AlexNet:public BasicNet{
    public:
    		~AlexNet(){}
    class LeNet:public BasicNet{
    public:
    		~LeNet(){}
    int main() {
    	REGISTERCLASS(AlexNet)
    	REGISTERCLASS(LeNet)
    	std::vector<std::string> models{"AlexNet","LeNet"};
    	for(auto model:models){
    		auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
    		alexNet->SetModelName(model);
    		std::string name = alexNet->GetModelName();
        	cout << name.c_str() << endl;
        	delete alexNet;
        return 0;
    
  • https://zhuanlan.zhihu.com/p/232319083
  • https://blog.csdn.net/caesar1228/article/details/103549797
  • https://cloud.tencent.com/developer/article/1176520
  •