返回顶部
首页 > 资讯 > 后端开发 > Python >python中的Pytorch建模流程汇总
  • 302
分享到

python中的Pytorch建模流程汇总

2024-04-02 19:04:59 302人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录1导入库2设置初始值3导入并制作数据集4定义神经网络架构5定义训练流程6训练模型本节内容学习帮助大家梳理神经网络训练的架构。 一般我们训练神经网络有以下步骤: 导入库设置训练参数

本节内容学习帮助大家梳理神经网络训练的架构。

一般我们训练神经网络有以下步骤:

  • 导入库
  • 设置训练参数的初始值
  • 导入数据集并制作数据集
  • 定义神经网络架构
  • 定义训练流程
  • 训练模型

推荐文章:

python实现可视化大屏

分享4款 Python 自动数据分析神器

以下,我就将上述步骤使用代码进行注释讲解:

1 导入库

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, DataLoader
import torchvision
import torchvision.transfORMs as transforms

2 设置初始值

# 学习率
lr = 0.15
# 优化算法参数
gamma = 0.8
# 每次小批次训练个数
bs = 128
# 整体数据循环次数
epochs = 10

3 导入并制作数据集

本次我们使用FashionMNIST图像数据集,每个图像是一个28*28的像素数组,共有10个衣物类别,比如连衣裙、运动鞋、包等。

注:初次运行下载需要等待较长时间。

# 导入数据集
mnist = torchvision.datasets.FashionMNIST(
    root = './Datastes'
    , train = True
    , download = True
    , transform = transforms.ToTensor())
    
# 制作数据集
batchdata = DataLoader(mnist
                       , batch_size = bs
                       , shuffle = True
                       , drop_last = False)

我们可以对数据进行检查:

for x, y in batchdata:
    print(x.shape)
    print(y.shape)
    break

# torch.Size([128, 1, 28, 28])
# torch.Size([128])

可以看到一个batch中有128个样本,每个样本的维度是1*28*28。

之后我们确定模型的输入维度与输出维度:

# 输入的维度
input_ = mnist.data[0].numel()
# 784

# 输出的维度
output_ = len(mnist.targets.unique())
# 10

4 定义神经网络架构

先使用一个128个神经元的全连接层,然后用relu激活函数,再将其结果映射到标签的维度,并使用softmax进行激活。

# 定义神经网络架构
class Model(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, 128, bias = True)
        self.output = nn.Linear(128, out_features, bias = True)
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        sigma1 = torch.relu(self.linear1(x))
        sigma2 = F.log_softmax(self.output(sigma1), dim = -1)
        return sigma2

5 定义训练流程

在实际应用中,我们一般会将训练模型部分封装成一个函数,而这个函数可以继续细分为以下几步:

  • 定义损失函数与优化器
  • 完成向前传播
  • 计算损失
  • 反向传播
  • 梯度更新
  • 梯度清零

在此六步核心操作的基础上,我们通常还需要对模型的训练进度、损失值与准确度进行监视。

注释代码如下:

# 封装训练模型的函数
def fit(net, batchdata, lr, gamma, epochs):
# 参数:模型架构、数据、学习率、优化算法参数、遍历数据次数

    # 5.1 定义损失函数
    criterion = nn.NLLLoss()
    # 5.1 定义优化算法
    opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma)
    
    # 监视进度:循环之前,一个样本都没有看过
    samples = 0
    # 监视准确度:循环之前,预测正确的个数为0
    corrects = 0
    
    # 全数据训练几次
    for epoch in range(epochs):
        # 对每个batch进行训练
        for batch_idx, (x, y) in enumerate(batchdata):
            # 保险起见,将标签转为1维,与样本对齐
            y = y.view(x.shape[0])
            
            # 5.2 正向传播
            sigma = net.forward(x)
            # 5.3 计算损失
            loss = criterion(sigma, y)
            # 5.4 反向传播
            loss.backward()
            # 5.5 更新梯度
            opt.step()
            # 5.6 梯度清零
            opt.zero_grad()
            
            # 监视进度:每训练一个batch,模型见过的数据就会增加x.shape[0]
            samples += x.shape[0]
            
            # 求解准确度:全部判断正确的样本量/已经看过的总样本量
            # 得到预测标签
            yhat = torch.max(sigma, -1)[1]
            # 将正确的加起来
            corrects += torch.sum(yhat == y)
            
            # 每200个batch和最后结束时,打印模型的进度
            if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1):
                # 监督模型进度
                print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format(
                    epoch + 1
                    , samples
                    , epochs*len(batchdata.dataset)
                    , 100*samples/(epochs*len(batchdata.dataset))
                    , loss.data.item()
                    , float(100.0*corrects/samples)))

6 训练模型

# 设置随机种子
torch.manual_seed(51)

# 实例化模型
net = Model(input_, output_)

# 训练模型
fit(net, batchdata, lr, gamma, epochs)
# Epoch1:[25600/600000  4%], Loss:0.524430, Accuracy:69.570312
# Epoch1:[51200/600000  9%], Loss:0.363422, Accuracy:74.984375
# ......
# Epoch10:[600000/600000  100%], Loss:0.284664, Accuracy:85.771835

现在我们已经用PyTorch训练了最基础的神经网络,并且可以查看其训练成果。大家可以将代码复制进行运行!

虽然没有用到复杂的模型,但是我们在每次建模时的基本思想都是一致的

到此这篇关于python中的Pytorch建模流程汇总的文章就介绍到这了,更多相关Pytorch建模流程内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: python中的Pytorch建模流程汇总

本文链接: https://lsjlt.com/news/141125.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • python中的Pytorch建模流程汇总
    目录1导入库2设置初始值3导入并制作数据集4定义神经网络架构5定义训练流程6训练模型本节内容学习帮助大家梳理神经网络训练的架构。 一般我们训练神经网络有以下步骤: 导入库设置训练参数...
    99+
    2024-04-02
  • python中的Pytorch建模流程是什么
    小编给大家分享一下python中的Pytorch建模流程是什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一般我们训练神经网络有以下步骤:导入库设置训练参数的初...
    99+
    2023-06-29
  • 常用python编程模板汇总
    在我们编程时,有一些代码是固定的,例如Socket连接的代码,读取文件内容的代码,一般情况下我都是到网上搜一下然后直接粘贴下来改一改,当然如果你能自己记住所有的代码那更厉害,但是自己写毕竟不如粘贴来的快,而...
    99+
    2022-06-04
    模板 常用 python
  • Python中psutil模块使用汇总
    简介:psutil(进程和系统实用程序)是一个跨平台库,用于检索Python中运行进程和系统利用率(CPU、内存、磁盘、网络、传感器)的信息。它主要用于系统监视、分析和限制进程资源以...
    99+
    2024-04-02
  • python中常用的内置模块汇总
    内置模块(一) Python内置的模块有很多,我们也已经接触了不少相关模块,接下来咱们就来做一些汇总和介绍。 内置模块有很多 & 模块中的功能也非常多,我们是没有办法注意全局...
    99+
    2024-04-02
  • python中的json模块常用方法汇总
    目录一、概述二、方法详解1.dump()2.dumps3.load4.loads三、代码实战1.dumps()2.dump()4.loads()一、概述 推荐使用参考网站: json...
    99+
    2024-04-02
  • Python卸载模块的方法汇总
    easy_install 卸载 通过easy_install 安装的模块可以直接通过 easy_install -m PackageName 卸载,然后删除Python27Libsite-packages...
    99+
    2022-06-04
    模块 方法 Python
  • Java流程控制语句最全汇总(中篇)
    目录前文Java switch case语句详解switch 语句格式switchcasedefaultbreak例 1例 2嵌套 switch 语句if 语句和 switch 语句...
    99+
    2023-01-13
    Java流程控制语句 流程控制语句 流程控制语句结构
  • 使用数学软件Matlab建模画图程序汇总
    目录1. 二维数据曲线图1.1 绘制二维曲线的基本函数1.plot()函数2.含多个输入参数的plot函数3.含选项的plot函数4.双纵坐标函数plotyy1.2 绘制图形的辅助操...
    99+
    2024-04-02
  • pytorch教程之网络的构建流程笔记
    目录构建网络定义一个网络loss FunctionBackprop更新权值参考网址 构建网络 我们可以通过torch.nn包来构建网络,现在你已经看过了autograd,nn在aut...
    99+
    2024-04-02
  • 关于Python中异常(Exception)的汇总
    前言 Exception类是常用的异常类,该类包括StandardError,StopIteration, GeneratorExit, Warning等异常类。python中的异常使用继承结构创建,可以在...
    99+
    2022-06-04
    异常 Python Exception
  • MySQL中创建表的三种方法汇总
    目录CREATE TABLECREATE TABLE … LIKECREATE TABLE … SELECT总结SQL 标准使用 CREATE TABLE 语句创建数据表;mysql ...
    99+
    2023-02-18
    MySQL创建表 MySQL创建表的方法 MySQL表创建
  • python中的编码知识整理汇总
    问题 在平时工作中,遇到了这样的错误: UnicodeDecodeError: 'ascii' codec can't decode byte 想必大家也都碰到过,很常见 。于是决定对python的...
    99+
    2022-06-04
    知识 python
  • Python中pyautogui库的使用方法汇总
    目录常用操作鼠标操作键盘操作弹窗操作图像操作在使用Python做脚本的话,有两个库可以使用,一个为PyUserInput库,另一个为pyautogui库。就本人而言,我更喜欢使用py...
    99+
    2024-04-02
  • Python中列表的基本操作汇总
    目录1、列表的创建与遍历2、添加元素2.1、append()方法2.2、extend()方法2.3、insert()方法3、删除元素3.1、del命令3.2、pop()方法3.3、r...
    99+
    2024-04-02
  • Python下opencv库的安装过程及问题汇总
    本文主要内容是python下opencv库的安装过程,涉及我在安装时遇到的问题,并且,将从网上搜集并试用的一些解决方案进行了简单的汇总,记录下来。 由于记录的是我第一次安装opencv库的过程,所以内容涵盖可能不全面...
    99+
    2022-06-02
    Python安装opencv库 Python opencv库安装
  • MGR测试过程中出现的问题汇总
    MGR出现的问题大概总结为以下几点: 1.每次提交事务时尽量控制单次操作事务的数据量,减少大事物在其他节点check的时间和堵塞后面的操作带来的集群复制延迟,如事务回滚影响更大; 2.MGR集群环...
    99+
    2024-04-02
  • Python中执行调用JS的多种方法汇总
    1. 写在前面   做爬虫的人大家都知道,现在国内Web或App普遍防护都做的很好,且越有价值的网站这方面越强 再小再弱的网站现在或多或少都要整点反爬 JS在反爬中应用非常广泛,现在做爬虫工程师基本...
    99+
    2023-08-31
    python javascript
  • Python网络编程中urllib2模块的用法总结
    一、最基础的应用 import urllib2 url = r'http://www.baidu.com' html = urllib2.urlopen(url).read() print html...
    99+
    2022-06-04
    网络编程 模块 Python
  • PyTorch深度学习模型的保存和加载流程详解
    一、模型参数的保存和加载  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作