返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch实现MNIST数据集手写数字识别详情
  • 746
分享到

PyTorch实现MNIST数据集手写数字识别详情

2024-04-02 19:04:59 746人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

目录一、PyTorch是什么?二、程序示例1.引入必要库2.下载数据集3.加载数据集4.搭建CNN模型并实例化5.交叉熵损失函数损失函数及SGD算法优化器6.训练函数7.测试函数8.

前言:

本篇文章基于卷积神经网络CNN,使用PyTorch实现MNIST数据集手写数字识别。

一、PyTorch是什么?

PyTorch 是一个 Torch7 团队开源的 Python 优先的深度学习框架,提供两个高级功能:

  • 强大的 GPU 加速 Tensor 计算(类似 numpy)
  • 构建基于 tape 的自动升级系统上的深度神经网络

你可以重用你喜欢的 python 包,如 numpy、scipy 和 Cython ,在需要时扩展 PyTorch。

二、程序示例

下面案例可供运行参考

1.引入必要库

import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

2.下载数据集

这里设置download=True,将会自动下载数据集,并存储在./data文件夹。

train_data = torchvision.datasets.MNIST(root="./data",train=True,transfORM=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

3.加载数据集

batch_size=32表示每一个batch中包含32张手写数字图片,shuffle=True表示打乱测试集(data和target仍一一对应)

train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)

4.搭建CNN模型并实例化

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.con1(x)))
        x = F.relu(self.pooling(self.con2(x)))
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
#模型实例化        
model = Net()

5.交叉熵损失函数损失函数及SGD算法优化器

lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

6.训练函数

def train(epoch):
    running_loss = 0.0
    for i,(inputs,targets) in enumerate(train_loader,0):
        # inputs,targets = inputs.to(device),targets.to(device)
        opt.zero_grad()
        outputs = model(inputs)
        loss = lossfun(outputs,targets)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if i % 300 == 299:
            print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
            running_loss = 0.0

7.测试函数

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for (inputs,targets) in test_loader:
            # inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print(100*correct/total)

8.运行

if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

三、总结

到此这篇关于PyTorch实现MNIST数据集手写数字识别详情的文章就介绍到这了,更多相关PyTorch MNIST 内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: PyTorch实现MNIST数据集手写数字识别详情

本文链接: https://lsjlt.com/news/120304.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • PyTorch实现MNIST数据集手写数字识别详情
    目录一、PyTorch是什么?二、程序示例1.引入必要库2.下载数据集3.加载数据集4.搭建CNN模型并实例化5.交叉熵损失函数损失函数及SGD算法优化器6.训练函数7.测试函数8....
    99+
    2024-04-02
  • pytorch实现mnist手写彩色数字识别
    目录前言一 前期工作1.设置GPU或者cpu2.导入数据二 数据预处理1.加载数据2.可视化数据3.再次检查数据三 搭建网络四 训练模型1.设置学习率2.模型训练五 模型评估1.Lo...
    99+
    2024-04-02
  • pytorch教程实现mnist手写数字识别代码示例
    目录1.构建网络2.编写训练代码3.编写测试代码4.指导程序train和test5.完整代码 1.构建网络 nn.Moudle是pytorch官方指定的编写Net模块,在init函数...
    99+
    2024-04-02
  • Python实战之MNIST手写数字识别详解
    目录数据集介绍1.数据预处理2.网络搭建3.网络配置关于优化器关于损失函数关于指标4.网络训练与测试5.绘制loss和accuracy随着epochs的变化图6.完整代码数据集介绍 ...
    99+
    2024-04-02
  • TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集
    基于MNIST数据集的逻辑回归模型做十分类任务 没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程。多层神经网络依靠隐含层,则...
    99+
    2024-04-02
  • pytorch实现手写数字图片识别
    本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下 数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2024-04-02
  • 手把手教你实现PyTorch的MNIST数据集
    目录概述 获取数据 网络模型 train 函数 test 函数 main 函数 完整代码:概述 MNIST 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个...
    99+
    2024-04-02
  • TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集
    今天就跟大家聊聊有关TensorFlow中Softmax逻辑回归如何识别手写数字MNIST数据集,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。基于MNIST数据集的逻辑回归模型做十分...
    99+
    2023-06-25
  • Python中如何实现MNIST手写数字识别功能
    这篇文章主要为大家展示了“Python中如何实现MNIST手写数字识别功能”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Python中如何实现MNIST手写数字识别功能”这篇文章吧。数据集介绍M...
    99+
    2023-06-22
  • Python实战小项目之Mnist手写数字识别
    目录程序流程分析图:传播过程:代码展示:创建环境准备数据集下载数据集下载测试集绘制图像搭建神经网络训练模型测试模型保存训练模型运行结果展示:程序流程分析图: 传播过程: 代码展...
    99+
    2024-04-02
  • pytorch如何实现手写数字图片识别
    这篇文章给大家分享的是有关pytorch如何实现手写数字图片识别的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。具体内容如下数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备...
    99+
    2023-06-15
  • caffe的python接口之手写数字识别mnist实例
    目录引言一、数据准备二、导入caffe库,并设定文件路径二、生成配置文件三、生成参数文件solver四、开始训练模型五、完成的python文件引言 深度学习的第一个实例一般都是mni...
    99+
    2024-04-02
  • Java实现BP神经网络MNIST手写数字识别的示例详解
    目录一、神经网络的构建二、系统架构服务器客户端采用MVC架构一、神经网络的构建 (1):构建神经网络层次结构 由训练集数据可知,手写输入的数据维数为784维,而对应的输出结果为分别为...
    99+
    2023-01-31
    Java实现手写数字识别 Java手写数字识别 Java数字识别
  • PyTorch实现手写数字识别的示例代码
    目录加载手写数字的数据数据加载器(分批加载)建立模型模型训练测试集抽取数据,查看预测结果计算模型精度自己手写数字进行预测加载手写数字的数据 组成训练集和测试集,这里已经下载好了,所以...
    99+
    2024-04-02
  • PyTorch简单手写数字识别的实现过程
    目录一、包导入及所需数据的下载关于数据集引入的改动二、进行数据处理变换操作三、数据预览测试和数据装载四、模型搭建和参数优化关于模型搭建的改动总代码:测试总结具体流程: ① 导入相应...
    99+
    2024-04-02
  • 超详细PyTorch实现手写数字识别器的示例代码
    前言 深度学习中有很多玩具数据,mnist就是其中一个,一个人能否入门深度学习往往就是以能否玩转mnist数据来判断的,在前面很多基础介绍后我们就可以来实现一个简单的手写数字识别的网...
    99+
    2024-04-02
  • PyTorch实现手写数字的识别入门小白教程
    目录手写数字识别(小白入门)1.数据预处理2.训练模型3.测试模型,保存4.调用模型5.完整代码手写数字识别(小白入门) 今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博...
    99+
    2024-04-02
  • Python中如何实现MNIST手写体识别
    这篇文章主要介绍Python中如何实现MNIST手写体识别,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!1.实验内容简述1.1 实验环境本实验采用的软硬件实验环境如表所示:在Windows操作系统下,采用基于Tens...
    99+
    2023-06-25
  • PyTorch手写数字数据集进行多分类
    目录一、实现过程0、导包1、准备数据2、设计模型3、构造损失函数和优化器4、训练和测试二、参考文献一、实现过程 本文对经典手写数字数据集进行多分类,损失函数采用交叉熵,激活函数采用R...
    99+
    2024-04-02
  • Pytorch写数字识别LeNet模型
    目录LeNet网络训练结果泛化能力测试LeNet网络 LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下 from PIL import Image imp...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作