返回顶部
首页 > 资讯 > 精选 >PyTorch如何实现一个简单的CNN图像分类器
  • 775
分享到

PyTorch如何实现一个简单的CNN图像分类器

2023-06-15 07:06:35 775人浏览 八月长安
摘要

这篇文章给大家分享的是有关PyTorch如何实现一个简单的CNN图像分类器的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。一. 加载数据Pytorch的数据加载一般是用torch.utils.data.Datase

这篇文章给大家分享的是有关PyTorch如何实现一个简单的CNN图像分类器的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。

一. 加载数据

Pytorch的数据加载一般是用torch.utils.data.Dataset与torch.utils.data.Dataloader两个类联合进行。我们需要继承Dataset来定义自己的数据集类,然后在训练时用Dataloader加载自定义的数据集类。

1. 继承Dataset类并重写关键方法

pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常用的结构,而后者是当数据集难以(或不可能)进行随机读取时使用。在这里我们实现Map-style dataset。
继承torch.utils.data.Dataset后,需要重写的方法有:__len__与__getitem__方法,其中__len__方法需要返回所有数据的数量,而__getitem__则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个方法以外,一般还需要实现__init__方法来初始化一些变量。话不多说,直接上代码。

'''包括了各种数据集的读取处理,以及图像相关处理方法'''from torch.utils.data import Datasetimport torchimport osimport cv2from Config import mycfgimport randomimport numpy as npclass ImageClassifyDataset(Dataset):    def __init__(self, imagedir, labelfile, classify_num, train=True):    '''    这里进行一些初始化操作。    '''        self.imagedir = imagedir        self.labelfile = labelfile        self.classify_num = classify_num        self.img_list = []        # 读取标签        with open(self.labelfile, 'r') as fp:            lines = fp.readlines()            for line in lines:                filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/'))                label = line.split(";")[1].strip('\n')                self.img_list.append((filepath, label))        if not train:            self.img_list = random.sample(self.img_list, 50)    def __len__(self):        return len(self.img_list)            def __getitem__(self, item):    '''    这个函数是关键,通过item(索引)来取数据集中的数据,    一般来说在这里才将图像数据加载入内存,之前存的是图像的保存路径    '''        _int_label = int(self.img_list[item][1])# label直接用0,1,2,3,4...表示不同类别        label = torch.tensor(_int_label,dtype=torch.long)        img = self.ProcessImgResize(self.img_list[item][0])        return img, label    def ProcessImgResize(self, filename):    '''    对图像进行一些预处理    '''        _img = cv2.imread(filename)        _img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC)        _img = _img.transpose((2, 0, 1))        _img = _img / 255        _img = torch.from_numpy(_img)        _img = _img.to(torch.float32)        return _img

有一些的数据集类一般还会传入一个transfORMs函数来构造一个图像预处理序列,传入transforms函数的一个好处是作为参数传入的话可以对一些非本地数据集中的数据进行操作(比如直接通过torchvision获取的一些预存数据集CIFAR10等等),除此之外就是torchvision.transforms里面有一些预定义的图像操作函数,可以直接像拼积木一样拼成一个图像处理序列,很方便。我这里因为是用我自己下载到本地的数据集,而且比较简单就直接用自己的函数来操作了。

2. 使用Dataloader加载数据

实例化自定义的数据集类ImageClassifyDataset后,将其传给DataLoader作为参数,得到一个可遍历的数据加载器。可以通过参数batch_size控制批处理大小,shuffle控制是否乱序读取,num_workers控制用于读取数据的线程数量。

from torch.utils.data import DataLoaderfrom MyDataset import ImageClassifyDatasetdataset = ImageClassifyDataset(imagedir, labelfile, 10)dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5)for index, data in enumerate(dataloader):print(index)# batch索引print(data)# 一个batch的{img,label}

二. 模型设计

在这里只讨论深度学习模型的设计,pytorch中的网络结构是一层一层叠出来的,pytorch中预定义了许多可以通过参数控制的网络层结构,比如Linear、CNN、RNN、Transformer等等具体可以查阅官方文档中的torch.nn部分。
设计自己的模型结构需要继承torch.nn.Module这个类,然后实现其中的forward方法,一般在__init__中设定好网络模型的一些组件,然后在forward方法中依据输入输出顺序拼装组件。

'''包括了各种模型、自定义的loss计算方法、optimizer'''import torch.nn as nnclass Simple_CNN(nn.Module):    def __init__(self, class_num):        super(Simple_CNN, self).__init__()        self.class_num = class_num        self.conv1 = nn.Sequential(            nn.Conv2d(# input: 3,400,600                in_channels=3,                out_channels=8,                kernel_size=5,                stride=1,                padding=2            ),            nn.Conv2d(                in_channels=8,                out_channels=16,                kernel_size=5,                stride=1,                padding=2            ),            nn.AvgPool2d(2),  # 16,400,600 --> 16,200,300            nn.BatchNorm2d(16),            nn.LeakyReLU(),            nn.Conv2d(                in_channels=16,                out_channels=16,                kernel_size=5,                stride=1,                padding=2            ),            nn.Conv2d(                in_channels=16,                out_channels=8,                kernel_size=5,                stride=1,                padding=2            ),            nn.AvgPool2d(2),  # 8,200,300 --> 8,100,150            nn.BatchNorm2d(8),            nn.LeakyReLU(),            nn.Conv2d(                in_channels=8,                out_channels=8,                kernel_size=3,                stride=1,                padding=1            ),            nn.Conv2d(                in_channels=8,                out_channels=1,                kernel_size=3,                stride=1,                padding=1            ),            nn.AvgPool2d(2),  # 1,100,150 --> 1,50,75            nn.BatchNorm2d(1),            nn.LeakyReLU()        )        self.line = nn.Sequential(            nn.Linear(                in_features=50 * 75,                out_features=self.class_num            ),            nn.Softmax()        )    def forward(self, x):        x = self.conv1(x)        x = x.view(-1, 50 * 75)        y = self.line(x)        return y

上面我定义的模型中包括卷积组件conv1和全连接组件line,卷积组件中包括了一些卷积层,一般是按照{卷积层、池化层、激活函数}的顺序拼接,其中我还在激活函数之前添加了一个BatchNorm2d层对上层的输出进行正则化以免传入激活函数的值过小(梯度消失)或过大(梯度爆炸)。
在拼接组件时,由于我全连接层的输入是一个一维向量,所以需要将卷积组件中最后的50 × 75 50\times 7550×75大小的矩阵展平成一维的再传入全连接层(x.view(-1,50*75))

三. 训练

实例化模型后,网络模型的训练需要定义损失函数与优化器,损失函数定义了网络输出与标签的差距,依据不同的任务需要定义不同的合适的损失函数,而优化器则定义了神经网络中的参数如何基于损失来更新,目前神经网络最常用的优化器就是SGD(随机梯度下降算法) 及其变种。
在我这个简单的分类器模型中,直接用的多分类任务最常用的损失函数CrossEntropyLoss()以及优化器SGD。

self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM)self.criterion = nn.CrossEntropyLoss()# 交叉熵,标签应该是0,1,2,3...的形式而不是独热的self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)

训练过程其实很简单,使用dataloader依照batch读出数据后,将input放入网络模型中计算得到网络的输出,然后基于标签通过损失函数计算Loss,并将Loss反向传播回神经网络(在此之前需要清理上一次循环时的梯度),最后通过优化器更新权重。训练部分代码如下:

for each_epoch in range(mycfg.MAX_EPOCH):            running_loss = 0.0            self.cnnmodel.train()            for index, data in enumerate(self.dataloader):                inputs, labels = data                outputs = self.cnnmodel(inputs)                loss = self.criterion(outputs, labels)                self.optimizer.zero_grad()# 清理上一次循环的梯度                loss.backward()# 反向传播                self.optimizer.step()# 更新参数                running_loss += loss.item()                if index % 200 == 199:                    print("[{}] loss: {:.4f}".format(each_epoch, running_loss/200))                    running_loss = 0.0            # 保存每一轮的模型            model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3))            torch.save(self.cnnmodel,model_name)# 保存全部模型

四. 测试

测试和训练的步骤差不多,也就是读取模型后通过dataloader获取数据然后将其输入网络获得输出,但是不需要进行反向传播的等操作了。比较值得注意的可能就是准确率计算方面有一些小技巧。

acc = 0.0count = 0self.cnnmodel = torch.load('mymodel.pth')self.cnnmodel.eval()for index, data in enumerate(dataloader_eval):inputs, labels = data   # 5,3,400,600  5,10count += len(labels)outputs = cnnmodel(inputs)_,predict = torch.max(outputs, 1)acc += (labels == predict).sum().item()print("[{}] accurancy: {:.4f}".format(each_epoch, acc / count))

我这里采用的是保存全部模型并加载全部模型的方法,这种方法的好处是在使用模型时可以完全将其看作一个黑盒,但是在模型比较大时这种方法会很费事。此时可以采用只保存参数不保存网络结构的方法,在每一次使用模型时需要读取参数赋值给已经实例化的模型:

torch.save(cnnmodel.state_dict(), "my_resnet.pth")cnnmodel = Simple_CNN()cnnmodel.load_state_dict(torch.load("my_resnet.pth"))

感谢各位的阅读!关于“PyTorch如何实现一个简单的CNN图像分类器”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,让大家可以学到更多知识,如果觉得文章不错,可以把它分享出去让更多的人看到吧!

--结束END--

本文标题: PyTorch如何实现一个简单的CNN图像分类器

本文链接: https://lsjlt.com/news/278509.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • PyTorch如何实现一个简单的CNN图像分类器
    这篇文章给大家分享的是有关PyTorch如何实现一个简单的CNN图像分类器的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。一. 加载数据Pytorch的数据加载一般是用torch.utils.data.Datase...
    99+
    2023-06-15
  • 基于PyTorch实现一个简单的CNN图像分类器
    目录一. 加载数据1. 继承Dataset类并重写关键方法2. 使用Dataloader加载数据二. 模型设计三. 训练四. 测试结语 pytorch中文网:https://www....
    99+
    2024-04-02
  • Pytorch中如何实现病虫害图像分类
    本篇文章给大家分享的是有关Pytorch中如何实现病虫害图像分类,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、pytorch框架1.1、概念PyTorch是一个开源的Pyt...
    99+
    2023-06-22
  • tensorflow2.0如何实现cnn的图像识别
    目录tensorflow2.0实现cnn图像识别cnn+tensorflow实现识别图片总结tensorflow2.0实现cnn图像识别 import tensorflow as ...
    99+
    2022-12-17
    tensorflow2.0 cnn图像识别 tensorflow2.0 cnn图像识别
  • 如何实现一个最简单的vbs类
    这篇文章主要介绍如何实现一个最简单的vbs类,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!class CFoo     sub PrintHell...
    99+
    2023-06-08
  • Pytorch搭建简单的卷积神经网络(CNN)实现MNIST数据集分类任务
    目录关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!!第一步:基本库的导入第二步:引用MNIST数据集,这里采用的是torchvision自带的MNIST数据集...
    99+
    2023-03-23
    Pytorch卷积神经网络 Pytorch MNIST数据集分类
  • tensorflow+k-means聚类简单实现猫狗图像分类的方法
    目录一、前言二、k-means聚类三、图像分类一、前言 本文使用的是 kaggle 猫狗大战的数据集:https://www.kaggle.com/c/dogs-vs-cats/da...
    99+
    2024-04-02
  • Python实现一个简单的QQ截图
    目录前言一、需求分析二、截图三、矩形选择四、按钮设置总结前言   毕设有一部分要用到类似QQ截图的功能,这里记录制作过程。因为后期要添加人工智能的功能,所以用py...
    99+
    2024-04-02
  • Java如何实现一个简单计算器
    这篇文章主要介绍了Java如何实现一个简单计算器,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。先来看看界面效果:源码如下:package test1; i...
    99+
    2023-06-22
  • 使用Python怎么实现一个图像分类功能
    今天就跟大家聊聊有关使用Python怎么实现一个图像分类功能,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。Python的优点有哪些1、简单易用,与C/C++、Java、C# 等传统语...
    99+
    2023-06-14
  • Keras如何实现图像分类任务
    在Keras中实现图像分类任务通常需要遵循以下步骤: 准备数据集:首先需要准备包含图像和对应标签的数据集。可以使用Keras中的...
    99+
    2024-04-02
  • springboot如何实现一个简单的aop实例
    小编给大家分享一下springboot如何实现一个简单的aop实例,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!简介AOP(Aspect-Oriented Programming:面向切面编程)aop能将一些繁琐、重复、无...
    99+
    2023-06-25
  • 如何实现一个简单的区块链
    这篇文章将为大家详细讲解有关如何实现一个简单的区块链,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。区块链的基础概念很简单:一个分布式数据库,...
    99+
    2024-04-02
  • 如何使用C++编写一个简单的图像识别程序?
    如何使用C++编写一个简单的图像识别程序?在现代科技的发展中,图像识别技术扮演了越来越重要的角色。无论是人脸识别、物体检测还是自动驾驶,图像识别都发挥着关键作用。本文将介绍如何使用C++编写一个简单的图像识别程序,帮助读者了解图像识别的基本...
    99+
    2023-11-03
    简单程序 图像识别 C++编程
  • PaddlePaddle中的图像分类任务如何实现
    在PaddlePaddle中实现图像分类任务通常使用卷积神经网络(CNN)。以下是一个简单的图像分类示例: 导入必要的库和模块: ...
    99+
    2024-04-02
  • Python元类编程实现一个简单的ORM
    目录概述效果步骤结束语完整代码概述 什么是ORM    ORM全称“Object Relational Mapping”,即对象-关系映射,就是把关系数据库的...
    99+
    2023-03-06
    Python元类编程ORM Python ORM
  • Python如何实现简单图像缩放与旋转
    这篇文章主要介绍Python如何实现简单图像缩放与旋转,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!1. 图像缩放1.2. 使用命令import cv2# 缩放def resize(im...
    99+
    2023-06-26
  • Python K-means实现简单图像聚类的示例代码
    这里直接给出第一个版本的直接实现: import os import numpy as np from sklearn.cluster import KMeans import ...
    99+
    2024-04-02
  • Spring实现一个简单的SpringIOC容器
    接触Spring快半年了,前段时间刚用Spring4+S2H4做完了自己的毕设,但是很明显感觉对Spring尤其是IOC容器的实现原理理解的不到位,说白了,就是仅仅停留在会用的阶段,有一颗想读源码的心于是买了一本计文柯的《Spring技术内...
    99+
    2023-05-31
    spring ioc容器 sprin
  • Qt如何实现一个简单的word文档编辑器
    本文小编为大家详细介绍“Qt如何实现一个简单的word文档编辑器”,内容详细,步骤清晰,细节处理妥当,希望这篇“Qt如何实现一个简单的word文档编辑器”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。1.先看效果图...
    99+
    2023-07-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作