PyTorch的开源图像分类算法框架pytorch-image-models,功能完善,集成大量数据增强方法和主流的网络框架,同时易用。

  • pytorch-image-models:https://github.com/rwightman/pytorch-image-models

使用方法:

  • 训练脚本:train.py
  • 数据集:训练train,验证val,标签文件夹+图像模式

训练脚本:

python3 train.py ./mydata/clz_dataset/ --dataset clz_dataset --train-split train --val-split val --num-classes 6 --batch-size 24 --input-size 3 336 336
  • 第1个参数是数据集路径;
  • --dataset是数据集名称
  • --train-split是训练集名称
  • --val-split是验证集名称
  • --num-classes是输出类别,默认与模型相同,例如默认resnet50的类别是ImageNet的1000个
  • --batch_size是batch_size,默认是128,显卡单卡2080+224x224大约是48
  • --input-size是输入尺寸,3维,例如修改为3 336 336
nohup python3 -u train.py ../MobileNetV3-Doc-Clz/mydata/document_dataset_v2_1/ --dataset document_dataset_v2_1 --train-split train --val-split val --num-classes 6 --batch-size 24 --input-size 3 336 336 > nohup.out &

输出模型,默认位于./output/train/中,显示当前模型效果:

测试逻辑如下:

  • 使用已训练的模型,如model_best_c6_20210914.pth.tar
  • 设置训练模型时相对应的网络结构,如resnet50,设置类别数
  • 下载预训练模型resnet50_ram-a26f946b.pth,放置于/Users/xxx/.cache/torch/hub/checkpoint
  • 预测结果,需要先转cpu()再转numpy(),避免在GPU环境下报错
  • 添加图像预处理的参数:图像尺寸input_size、差值方式interpolation、均值std和方差mean、裁剪参数crop_pct,与训练的参数保持一致。注意裁剪参数crop_pct设置为1.0,即不裁剪。

参考:https://github.com/rwightman/pytorch-image-models/blob/master/docs/models/resnet.md

添加pth.tar参数模型转换为pt模型的函数save_pt(),使用jit形式。pt模型 = pth.tar参数模型 + 网络结构(如resnet50)。使用pt模型,可以简化使用方式,同时也方便转换为trt模型,进行轻量级部署。在转换函数中,包含验证逻辑,保证转换前后的模型效果一致,即输出不变。

源码,面向对象的推理类:

#!/usr/bin/env python
# -- coding: utf-8 --
Created by C. L. Wang on 15.9.21
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
import timm
from myutils.project_utils import download_url_img, mkdir_if_not_exist
from root_dir import DATA_DIR
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class ImgPredictor(object):
    def __init__(self, model_path, base_net, num_classes):
        print('[Info] ------ 预测图像 ------')
        self.model_path = model_path
        self.num_classes = num_classes
        self.model, self.transform = self.load_model(self.model_path, base_net, self.num_classes)
        print('[Info] 模型路径: {}'.format(self.model_path))
        print('[Info] base_net: {}'.format(base_net))
        print('[Info] num_classes: {}'.format(num_classes))
    @staticmethod
    def load_model(model_path, base_net, num_classes):
        model = timm.create_model(model_name=base_net, pretrained=False,
                                  checkpoint_path=model_path, num_classes=num_classes)
        if torch.cuda.is_available():
            print('[Info] cuda on!!!')
            model = model.cuda()
        model.eval()
        config_dict = {
            "input_size": (3, 336, 336),
            "interpolation": "bicubic",
            "mean": (0.485, 0.456, 0.406),
            "std": (0.229, 0.224, 0.225),
            "crop_pct": 1.0  # 不进行Crop
        config = resolve_data_config(config_dict, model=model)
        print("[Info] config: {}".format(config))
        transform = create_transform(**config)
        return model, transform
    @staticmethod
    def preprocess_img(img_rgb, transform):
        预处理图像
        img_pil = Image.fromarray(img_rgb.astype('uint8')).convert('RGB')
        img_tensor = transform(img_pil).unsqueeze(0)  # transform and add batch dimension
        if torch.cuda.is_available():
            img_tensor = img_tensor.cuda()
        return img_tensor
    def predict_img(self, img_rgb):
        预测RGB图像
        print('[Info] 预测图像尺寸: {}'.format(img_rgb.shape))
        img_tensor = self.preprocess_img(img_rgb, self.transform)
        print('[Info] 模型输入: {}'.format(img_tensor.shape))
        with torch.no_grad():
            out = self.model(img_tensor)
        print('[Info] 模型结果raw: {}'.format(out))
        probabilities = F.softmax(out[0], dim=0)
        print('[Info] 模型结果: {}'.format(probabilities.shape))
        if self.num_classes >= 5:
            top_n = 5
        else:
            top_n = self.num_classes
        top_prob, top_catid = torch.topk(probabilities, top_n)
        top_catid = list(top_catid.cpu().numpy())
        top_prob = list(top_prob.cpu().numpy())
        top_prob = np.around(top_prob, 4)
        print('[Info] 预测类别: {}'.format(top_catid))
        print('[Info] 预测概率: {}'.format(top_prob))
        return top_catid, top_prob
    def predict_img_path(self, img_path):
        预测图像路径
        img_bgr = cv2.imread(img_path)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        top_catid, top_prob = self.predict_img(img_rgb)
        return top_catid, top_prob
    def predict_img_url(self, img_url):
        预测图像URL
        _, img_bgr = download_url_img(img_url)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        top_catid, top_prob = self.predict_img(img_rgb)
        return top_catid, top_prob
    @staticmethod
    def convert_catid_2_label(catid_list, label_list):
        预测类别id转换为str
        str_list = [label_list[int(ci)] for ci in catid_list]
        return str_list
    def save_pt(self, pt_folder_path):
        print('[Info] pt存储开始')
        mkdir_if_not_exist(pt_folder_path)
        model_name = self.model_path.split("/")[-1].split(".")[0]
        print('[Info] 模型名称: {}'.format(model_name))
        dummy_shape = (1, 3, 336, 336)  # 不影响模型
        print('[Info] dummy_shape: {}'.format(dummy_shape))
        if torch.cuda.is_available():
            model_type = "cuda"
        else:
            model_type = "cpu"
        print('[Info] model_type: {}'.format(model_type))
        dummy_input = torch.empty(dummy_shape,
                                  dtype=torch.float32,
                                  device=torch.device(model_type))
        traced = torch.jit.trace(self.model, dummy_input)
        pt_path = os.path.join(pt_folder_path, "{}_{}.pt".format(model_name, model_type))
        traced.save(pt_path)
        with torch.no_grad():
            standard_out = self.model(dummy_input)
        print('[Info] standard_out: {}'.format(standard_out))
        reload_script = torch.jit.load(pt_path)
        with torch.no_grad():
            script_output = reload_script(dummy_input)
        print('[Info] script_output: {}'.format(script_output))
        print('[Info] 验证 is equal: {}'.format(F.l1_loss(standard_out, script_output)))
        print('[Info] 存储完成: {}'.format(pt_path))
        return pt_path
def main():
    img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "000", "train_040000_000.jpg")
    # img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "001", "train_060000_001.jpg")
    # img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "002", "train_020000_002.jpg")
    # img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "003", "train_100000_003.jpg")
    # img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "004", "train_000000_004.jpg")
    # img_path = os.path.join(DATA_DIR, "document_dataset_mini", "train", "005", "train_080000_005.jpg")
    case_url = "http://quark-cv-data.oss-cn-hangzhou.aliyuncs.com/gaoyan/project/gt_imaage_for_biaozhu3/" \
               "O1CN0100fHnP21yK9SLVNC9_!!6000000007053-0-quark.jpg"
    model_path = os.path.join(DATA_DIR, "models", "model_best_c2_20210915.pth.tar")
    base_net = "resnet50"
    num_classes = 2
    label_list = ["纸质文档", "其他"]
    # show_img_bgr(cv2.imread(img_path))
    me = ImgPredictor(model_path, base_net, num_classes)
    # top5_catid, top5_prob = me.predict_img_path(img_path)
    top5_catid, top5_prob = me.predict_img_url(case_url)
    top5_cat = me.convert_catid_2_label(top5_catid, label_list)
    print('[Info] 预测类别: {}'.format(top5_cat))
    print('[Info] 预测概率: {}'.format(top5_prob))
    # me.save_pt(os.path.join(DATA_DIR, "pt_models"))  # 存储PT模型
if __name__ == '__main__':
    main()
[Info] ------ 预测图像 ------
[Info] config: {'input_size': (3, 336, 336), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0}
[Info] 模型路径: /Users/wang/workspace/pytorch-image-models-my/mydata/models/model_best_c2_20210915.pth.tar
[Info] base_net: resnet50
[Info] num_classes: 2
[Info] 预测图像尺寸: (3587, 1842, 3)
[Info] 模型输入: torch.Size([1, 3, 336, 336])
[Info] 模型结果: torch.Size([2])
[Info] 预测类别: [0, 1]
[Info] 预测概率: [0.9196 0.0804]
[Info] 预测类别: ['纸质文档', '其他']
[Info] 预测概率: [0.9196 0.0804]

                    PyTorch的开源图像分类算法框架pytorch-image-models,功能完善,集成大量数据增强方法和主流的网络框架,同时易用。pytorch-image-models:https://github.com/rwightman/pytorch-image-models训练使用方法:训练脚本:train.py数据集:训练train,验证val,标签文件夹+图像模式训练脚本:python3 train.py ./mydata/clz_dataset/ --dataset clz
				
PyTorch中检索CNN图像:在PyTorch中训练和评估CNN以进行图像检索 这是一个Python工具箱,用于实现对本文所述方法的培训和测试: 无需人工注释即可对CNN图像进行微调, RadenovićF.,Tolias G.,Chum O.,TPAMI 2018 [ ] CNN图像检索从BoW获悉:无监督的微调,并附有困难的示例, RadenovićF.,Tolias G.,Chum O.,ECCV 2016 [ ] 它是什么? 该代码实现: 训练(微调)CNN进行图像检索 学习CNN图像表示的监督美白 在牛津和巴黎数据集上测试CNN图像检索 为了运行此工具箱,您将需要: Python3(在Debian 8.1上使用Python 3.7.0进行了测试) PyTorch深度学习框架(已通过1.0.0版测试) 其余所有(数据+网络)将通过我们的脚本自动下载 scripts/utils/prepare_data.sh 默认情况下,这假定您的ImageNet训练集已下载到此目录中的根文件夹data中,并将以128x128像素分辨率准备缓存的HDF5。 在scripts文件夹中,有多个bash脚本,它们将训练具有不同批处理大小的BigGAN。 此代码假定您无权访问完整的TPU吊舱,并因此通过使用梯度累积(对多个迷你批次的平均
timm是由Ross Wightman创建的深度学习库,是一个关于SOTA的计算机视觉模型、层、实用工具、optimizers, schedulers, data-loaders, augmentations,可以复现ImageNet训练结果的训练/验证代码。 代码网址:https://github.com/rwightman/pytorch-image-models 简略文档:https://rwightman.github.io/pytorch-image-models/ 详细文档:https://fa
计算机视觉模型库–Pytorch Image Models (timm) "timm"是由Ross Wightman创建的深度学习库,是一个关于SOTA的计算机视觉模型、层、实用工具、optimizers, schedulers, data-loaders, augmentations,可以复现ImageNet训练结果的训练/验证代码。 Install pip install timm import timm import torch model = timm.creat
基于VGG16网络模型对图像进行识别分类 随着人工智能热潮的发展,图像识别已经成为了其中非常重要的一部分。图像识别是指计算机对图像进行处理、分析,以识别其中所含目标的类别及其位置(即目标检测和分类)的技术。其中图像分类是图像识别的一个类,是给定一幅测试图像,利用训练好的分类器判定它所属的类别。 该项目分为三部分: 第一部分:系统驱动的安装与环境的搭建 第二部分:利用VGG16网络进行模型训练与预...
pytorch图像分类实践 在学习pytorch的过程中我找到了关于图像分类的很浅显的一个教程上一次做的是pytorch的手写数字图片识别是灰度图片,这次是彩色图片的分类,觉得对于像我这样的刚刚开始入门pytorch的小白来说很有意义,今天写篇关于这个图像分类的博客. 收获的知识 1.torchvison 在深度学习中数据加载及预处理是非常复杂繁琐的,但PyTorch提供了一些可极大简化和加快数...
人脑可以很容易地识别和区分图像中的物体。例如,给定猫和狗的图像,在纳秒之内,我们就能区分它们,我们的大脑也能感知到这种差异。如果一台机器模仿这种行为,它就和我们能得到的人工智能一样接近。随后,计算机视觉领域的目标是模仿人类视觉系统——在这方面,已经有许多里程碑突破了障碍。 此外,如今的机器可以轻松区分不同的图像,检测物体和人脸,甚至生成不存在的人的图像!很迷人,不是吗?当我开始使用计算机视觉时,我的第一次经历是图像分类。机器区分物体的能力带来了更多的研究途径——比如区分人。 迁移学习的出现进一步加速
4.编写验证/测试代码 一、 定义数据类 pytorch中提供了两个类用于训练数据的加载,分别是torch.utils.data.Dataset和 torch.utils.data.DataLoader。 Dataset...
### 回答1: BERT-NER-PyTorch是一个基于PyTorch深度学习框架的BERT命名实体识别(NER)模型。BERT是一种在大规模未标记文本上训练的预训练模型,它可以用于各种自然语言处理任务。 BERT-NER-PyTorch利用已经使用大量标记数据进行预训练的BERT模型的表示能力,进行命名实体识别任务。命名实体识别是指从文本中识别特定实体,如人名、地名、组织、日期等。通过使用BERT-NER-PyTorch,我们可以利用预训练的BERT模型来提高命名实体识别的性能。 BERT-NER-PyTorch的实现基于PyTorch深度学习框架PyTorch是一个用于构建神经网络的开源框架,具有易于使用、动态计算图和高度灵活的特点。通过在PyTorch环境下使用BERT-NER-PyTorch,我们可以灵活地进行模型训练、调整和部署。 使用BERT-NER-PyTorch,我们可以通过以下步骤进行命名实体识别: 1. 预处理:将文本数据转换为适合BERT模型输入的格式,例如分词、添加特殊标记等。 2. 模型构建:使用BERT-NER-PyTorch构建NER模型,该模型包括BERT预训练模型和适当的输出层。 3. 模型训练:使用标记的命名实体识别数据对NER模型进行训练,通过最小化损失函数来优化模型参数。 4. 模型评估:使用验证集或测试集评估训练得到的NER模型的性能,例如计算准确率、召回率和F1分数等指标。 5. 模型应用:使用训练好的NER模型对新的文本数据进行命名实体识别,识别出关键实体并提供相应的标签。 总之,BERT-NER-PyTorch是一个基于PyTorch的BERT命名实体识别模型,通过利用预训练的BERT模型的表示能力,在命名实体识别任务中提供了灵活、高效和准确的解决方案。 ### 回答2: bert-ner-pytorch是一个基于PyTorch框架的BERT命名实体识别模型。BERT是一种基于Transformer架构的预训练模型,在自然语言处理任务中取得了很好的效果。NER代表命名实体识别,是一项重要的自然语言处理任务,旨在从文本中识别和标注出特定类型的命名实体,如人名、地点、组织等。 bert-ner-pytorch利用预训练的BERT模型作为输入,结合神经网络模型进行命名实体识别。它通过将输入文本转化为BERT模型能够接受的格式,并在其上进行微调训练来提高NER的性能。具体来说,该模型首先使用BERT模型对文本进行编码,将文本中的每个单词转化为其对应的向量表示。然后,这些向量通过一层或多层的神经网络模型,以预测每个单词是否属于某个命名实体类别。 利用bert-ner-pytorch模型,我们可以将其应用于各种实际场景中,如信息抽取、问题回答、智能问答系统等。通过对输入文本进行命名实体识别,我们可以更好地理解文本中所包含的实体信息,从而为后续的处理与分析提供更多的潜在价值。 需要注意的是,bert-ner-pytorch模型是一个基础的NER模型,它需要根据具体的任务和数据进行进一步的训练和优化。同时,BERT模型本身也有一些限制,如较高的计算资源要求和模型大小。因此,在实际使用时,我们可能需要结合具体需求,对模型进行调整和优化,以适应不同的场景和数据。
yumu_2004: 因为 下面大一大堆都是cout,是有缓冲区的 而deleter里面的是clog,没有缓冲区。所以在程序最后去输出cout缓冲区里面的内容之前 clog已经先把内容输出了 (迟来的回答) PyTorch笔记 - R-Drop、Conv2d、3x3+1x1+identity算子融合 qq_52423671: 最后那一部分是重参数化吧 LeetCode - 1049 最后一块石头的重量 II (0-1背包) CSDN-Ada助手: 恭喜你,获得了 2023 博客之星评选的入围资格,请看这个帖子 (https://bbs.csdn.net/topics/615582855?utm_source=blogger_star_comment)。 请在这里提供反馈: https://blogdev.blog.csdn.net/article/details/129986459?utm_source=blogger_star_comment。 PSP - AlphaFold2 中 Monomer MSA 特征的源码简读 (1) CSDN-Ada助手: 一定要坚持创作更多高质量博客哦, 小小红包, 以资鼓励, 更多创作活动请看: 上传ChatGPT/计算机论文等资源,瓜分¥5000元现金: https://blog.csdn.net/VIP_Assistant/article/details/130196121?utm_source=csdn_ai_ada_redpacket 新人首创任务挑战赛: https://marketing.csdn.net/p/90a06697f3eae83aabea1e150f5be8a5?utm_source=csdn_ai_ada_redpacket 可持续能源技术真的能改变世界吗?: https://activity.csdn.net/creatActivity?id=10425?utm_source=csdn_ai_ada_redpacket 全部创作活动: https://mp.csdn.net/mp_blog/manage/creative?utm_source=csdn_ai_ada_redpacket