返回顶部
首页 > 资讯 > 后端开发 > JAVA >深度学习——LSTM解决分类问题
  • 572
分享到

深度学习——LSTM解决分类问题

深度学习lstm分类 2023-09-08 08:09:08 572人浏览 泡泡鱼
摘要

RNN基本介绍 概述 循环神经网络(Recurrent Neural Network,RNN)是一种深度学习模型,主要用于处理序列数据,如文本、语音、时间序列等具有时序关系的数据。 核心思想 RNN的

RNN基本介绍

概述

循环神经网络(Recurrent Neural Network,RNN)是一种深度学习模型,主要用于处理序列数据,如文本、语音、时间序列等具有时序关系的数据。

核心思想

RNN的关键思想是引入了循环结构,允许信息在网络内部进行传递。与传统的前馈神经网络(Feedforward Neural Network)不同,RNN在处理序列数据时会保留并利用先前的信息来影响后续的输出。

基本结构

RNN的基本结构是一个被称为“循环单元”(recurrent unit)的模块,它接收输入和先前的隐藏状态,并生成输出和新的隐藏状态。循环单元中的权重参数在时间步之间是共享的,这意味着它可以对序列中的不同位置应用相同的操作。

计算过程

RNN在每个时间步的计算过程如下:
1.接收当前时间步的输入和先前时间步的隐藏状态。
2.使用这些输入和隐藏状态计算当前时间步的输出。
3.更新隐藏状态,以便在下一个时间步使用。

优点

由于RNN具有循环结构,它可以在处理序列数据时保持记忆,并捕捉到序列中的长期依赖关系。这使得RNN在许多任务中表现出色,例如语言建模、机器翻译、语音识别、情感分析等。

缺点

然而,传统的RNN在处理长期依赖时存在梯度消失或梯度爆炸的问题,导致难以捕捉到远距离的依赖关系。

LSTM基本介绍

概述

LSTM(Long Short-Term Memory,长短期记忆网络)是一种循环神经网络(RNN)的改进型结构,用于解决传统RNN中的长期依赖问题。相比于传统的RNN,LSTM引入了门控机制,能够更好地捕捉和处理序列数据中的长期依赖关系。

核心思想

LSTM的核心思想是引入了三个门控单元:输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate)。这些门控单元允许LSTM网络选择性地保留或丢弃信息,并且在传递信息时能够有效地控制梯度的流动。

基本结构

以下是LSTM中各个门控单元的功能:
1.输入门(Input Gate):决定当前时间步的输入信息中哪些部分需要被记忆。它使用sigmoid函数来产生一个0到1之间的值,描述了每个输入的重要性。
2.遗忘门(Forget Gate):决定之前的隐藏状态中哪些信息需要被遗忘。通过使用sigmoid函数,遗忘门可以控制先前的隐藏状态在当前时间步的重要性。
3.输出门(Output Gate):根据当前时间步的输入和之前的隐藏状态,决定应该输出多少信息到下一个时间步。输出门使用sigmoid函数来控制隐藏状态中的信息量,并使用tanh函数来生成当前时间步的输出。

优点

通过使用这些门控单元,LSTM网络能够在处理序列数据时灵活地控制信息的流动和记忆的保留。这使得LSTM能够更好地处理长期依赖关系,并在各种序列建模任务中表现出色,例如机器翻译、语音识别、文本生成等。

代码与详细注释

import torchfrom torch import nnimport torchvision.datasets as dsetsimport torchvision.transfORMs as transformsimport matplotlib.pyplot as plt# 可复现# torch.manual_seed(1)    # reproducible# Hyper ParametersEPOCH = 1               # train the training data n times, to save time, we just train 1 epoch# 批大小BATCH_SIZE = 64TIME_STEP = 28          # rnn time step / image heightINPUT_SIZE = 28         # rnn input size / image widthLR = 0.01               # learning rateDOWNLOAD_MNIST = True   # set to True if haven't download the data# Mnist digital datasettrain_data = dsets.MNIST(    root='./mnist/',    train=True,                         # this is training data    transform=transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to            # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]    download=DOWNLOAD_MNIST,            # download it if you don't have it)# plot one exampleprint(train_data.train_data.size())     # (60000, 28, 28)print(train_data.train_labels.size())   # (60000)plt.imshow(train_data.train_data[0].numpy(), cmap='gray')plt.title('%i' % train_data.train_labels[0])plt.show()# Data Loader for easy mini-batch return in trainingtrain_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# convert test data into Variable, pick 2000 samples to speed up testingtest_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255.   # shape (2000, 28, 28) value in range(0,1)test_y = test_data.test_labels.numpy()[:2000]    # covert to numpy arrayclass RNN(nn.Module):    def __init__(self):        super(RNN, self).__init__()        self.rnn = nn.LSTM(         # if use nn.RNN(), it hardly learns            input_size=INPUT_SIZE,            hidden_size=64,         # rnn hidden unit            num_layers=1,           # number of rnn layer            batch_first=True,       # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)        )        self.out = nn.Linear(64, 10)    def forward(self, x):        # 输入向量的形状        # x shape (batch, time_step, input_size)        # r_out shape (batch, time_step, output_size)        # h_n shape (n_layers, batch, hidden_size)        # h_c shape (n_layers, batch, hidden_size)        r_out, (h_n, h_c) = self.rnn(x, None)   # None represents zero initial hidden state        # choose r_out at the last time step        # 选择输出最后一步的r_out        out = self.out(r_out[:, -1, :])        return outrnn = RNN()print(rnn)optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all cnn parametersloss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted# training and testingfor epoch in range(EPOCH):    for step, (b_x, b_y) in enumerate(train_loader):        # gives batch data        b_x = b_x.view(-1, 28, 28)              # reshape x to (batch, time_step, input_size)        output = rnn(b_x)   # rnn output        loss = loss_func(output, b_y)                   # cross entropy loss        optimizer.zero_grad()                           # clear gradients for this training step        loss.backward()     # backpropagation, compute gradients        optimizer.step()    # apply gradients        # 每训练50步之后,测试一下准确度        if step % 50 == 0:            test_output = rnn(test_x)                   # (samples, time_step, input_size)            pred_y = torch.max(test_output, 1)[1].data.numpy()            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)# print 10 predictions from test datatest_output = rnn(test_x[:10].view(-1, 28, 28))pred_y = torch.max(test_output, 1)[1].data.numpy()print(pred_y, 'prediction number')print(test_y[:10], 'real number')

运行结果

在这里插入图片描述
在这里插入图片描述

来源地址:https://blog.csdn.net/Elon15/article/details/131751184

--结束END--

本文标题: 深度学习——LSTM解决分类问题

本文链接: https://lsjlt.com/news/399564.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 深度学习——LSTM解决分类问题
    RNN基本介绍 概述 循环神经网络(Recurrent Neural Network,RNN)是一种深度学习模型,主要用于处理序列数据,如文本、语音、时间序列等具有时序关系的数据。 核心思想 RNN的...
    99+
    2023-09-08
    深度学习 lstm 分类
  • 机器深度学习二分类电影的情感问题
    二分类问题可能是应用最广泛的机器学习问题。今天我们将学习根据电影评论的文字内容将其划分为正面或负面。 一、数据集来源 我们使用的是IMDB数据集,它包含来自互联网电影数据库(IMDB...
    99+
    2024-04-02
  • Pytorch深度学习gather一些使用问题解决方案
    目录问题场景描述问题的思考gather的说明问题的解决问题场景描述 我在复现Faster-RCNN模型的过程中遇到这样一个问题: 有一个张量,它的形状是 (128, 21, 4) ...
    99+
    2024-04-02
  • 理解深度学习之深度学习简介
    机器学习 在吴恩达老师的课程中,有过对机器学习的定义: ML:<P T E> P即performance,T即Task,E即Experience,机器学习是对一个Task...
    99+
    2024-04-02
  • PyTorch深度学习LSTM从input输入到Linear输出
    目录LSTM介绍LSTM参数InputsOutputsbatch_first案例LSTM介绍 关于LSTM的具体原理,可以参考: https://www.jb51.net/artic...
    99+
    2024-04-02
  • pycharm安装深度学习pytorch的d2l包失败问题解决
    目录1、首先查看现在pycharm所在的环境2、打开Anaconda Prompt3、激活现在的虚拟环境4、安装d2l包5、原因分析和心得体会,可以不看。总结pycharm里边安装不...
    99+
    2024-04-02
  • 如何解决mysql深度分页问题
    目录mysql深度分页问题1.基本分页:耗时0.019秒2.深度分页:耗时10.236秒3.深度ID分页:耗时0.052秒4.两步走深度分页:耗时0.049秒+0.017秒5.一步走深度分页:耗时0.05秒6.集成Bea...
    99+
    2023-01-09
    mysql深度分页 深度分页 mysql分页
  • Python Pytorch深度学习之图像分类器
    目录一、简介二、数据集三、训练一个图像分类器1、导入package吧2、归一化处理+贴标签吧3、先来康康训练集中的照片吧4、定义一个神经网络吧5、定义一个损失函数和优化器吧6、训练网...
    99+
    2024-04-02
  • 深度学习小工程练习之tensorflow垃圾分类详解
    介绍 这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建 软件架构 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务 模型的训...
    99+
    2024-04-02
  • Python深度学习之FastText实现文本分类详解
    FastText是一个三层的神经网络,输入层、隐含层和输出层。 FastText的优点: 使用浅层的神经网络实现了word2vec以及文本分类功能,效果与深层网络差不多,节约资源,...
    99+
    2024-04-02
  • 深度学习Tensorflow2.8 使用 BERT 进行文本分类
    目录前言1. python 库准备2. BERT 是什么?3. 获取并处理 IMDB 数据4. 初识 TensorFlow Hub 中的 BERT 处理器和模型5. 搭建模型6. 训...
    99+
    2023-01-06
    Tensorflow BERT文本分类 Tensorflow 深度学习
  • Tensorflow深度学习使用CNN分类英文文本
    目录前言源码与数据源码数据train.py 源码及分析data_helpers.py 源码及分析text_cnn.py 源码及分析前言 Github源码地址 本文同时也是学习唐宇迪...
    99+
    2024-04-02
  • opencv深入浅出了解机器学习和深度学习
    目录机器学习kNN算法图解kNN算法用kNN算法实现手写数字识别SVM算法图解SVM算法使用SVM算法识别手写数据k均值聚类算法图解k均值聚类算法使用k均值聚类算法量化图像颜色深度学...
    99+
    2024-04-02
  • Python深度学习pytorch实现图像分类数据集
    目录读取数据集读取小批量整合所有组件目前广泛使用的图像分类数据集之一是MNIST数据集。如今,MNIST数据集更像是一个健全的检查,而不是一个基准。 为了提高难度,我们将在接下来的章...
    99+
    2024-04-02
  • Pytorch深度学习之实现病虫害图像分类
    目录一、pytorch框架1.1、概念1.2、机器学习与深度学习的区别1.3、在python中导入pytorch成功截图二、数据集三、代码复现3.1、导入第三方库3.2、CNN代码3...
    99+
    2024-04-02
  • 深度学习详解之初试机器学习
    机器学习可应用在各个方面,本篇将在系统性进入机器学习方向前,初步认识机器学习,利用线性回归预测波士顿房价; 原理简介 利用线性回归最简单的形式预测房价,只需要把它当做是一次线性函数y...
    99+
    2024-04-02
  • MySql深分页问题解决
    目录1. 问题描述2. 问题分析3. 验证测试3.1 创建两个表3.2 创建两个函数3.3 编写存储过程3.4 编写存储过程3.5 创建索引3.6 验证测试4. 解决方案4.1 使用索引覆盖+子查询优化4.2 起始位置重...
    99+
    2023-02-03
    MySql深分页
  • reactjs学习解决unknownatrule@tailwindcss问题
    目录解决unknown at rule @tailwind cssReact配置Tailwindcss问题 步骤测试总结解决unknown at rule @tailwin...
    99+
    2023-02-12
    reactjs学习 unknown at rule @tailwind css
  • Python-OpenCV深度学习的示例分析
    这篇文章将为大家详细讲解有关Python-OpenCV深度学习的示例分析,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。1. 计算机视觉中的深度学习简介深度学习推动了计算机视觉领域的深刻变革,我们首先解释深...
    99+
    2023-06-22
  • Python Pytorch深度学习之自动微分
    目录一、简介二、TENSOR三、梯度四、Example——雅克比向量积总结一、简介 antograd包是Pytorch中所有神经网络的核心。autograd为Tensor上的所有操作...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作