返回顶部
首页 > 资讯 > 后端开发 > Python >手把手教你使用TensorFlow2实现RNN
  • 722
分享到

手把手教你使用TensorFlow2实现RNN

2024-04-02 19:04:59 722人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

目录概述权重共享计算过程:案例数据集RNN 层获取数据完整代码概述 RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即

概述

RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即前面的输入和后面的输入有一定的联系.

在这里插入图片描述

权重共享

传统神经网络:

在这里插入图片描述

RNN:

在这里插入图片描述

RNN 的权重共享和 CNN 的权重共享类似, 不同时刻共享一个权重, 大大减少了参数数量.

计算过程:

在这里插入图片描述

计算状态 (State)

在这里插入图片描述

计算输出:

在这里插入图片描述

案例

数据集

IBIM 数据集包含了来自互联网的 50000 条关于电影的评论, 分为正面评价和负面评价.

RNN 层


class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64] (b 表示 batch_size)
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_Words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob

获取数据


def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db

完整代码


import Tensorflow as tf


class RNN(tf.keras.Model):

    def __init__(self, units):
        super(RNN, self).__init__()

        # 初始化 [b, 64]
        self.state0 = [tf.zeros([batch_size, units])]
        self.state1 = [tf.zeros([batch_size, units])]

        # [b, 80] => [b, 80, 100]
        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len, input_length=max_review_len)

        self.rnn_cell0 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)
        self.rnn_cell1 = tf.keras.layers.SimpleRNNCell(units=units, dropout=0.2)

        # [b, 80, 100] => [b, 64] => [b, 1]
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        """

        :param inputs: [b, 80]
        :param training:
        :return:
        """

        state0 = self.state0
        state1 = self.state1

        x = self.embedding(inputs)

        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # [b, 64] -> [b, 1]
        x = self.out_layer(out1)

        prob = tf.sigmoid(x)

        return prob


# 超参数
total_words = 10000  # 文字数量
max_review_len = 80  # 句子长度
embedding_len = 100  # 词维度
batch_size = 1024  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.BinaryCrossentropy(from_logits=True)  # 损失
model = RNN(64)

# 调试输出summary
model.build(input_shape=[None, 64])
print(model.summary())

# 组合
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def get_data():
    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=total_words)

    # 更改句子长度
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, maxlen=max_review_len)
    X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, maxlen=max_review_len)

    # 调试输出
    print(X_train.shape, y_train.shape)  # (25000, 80) (25000,)
    print(X_test.shape, y_test.shape)  # (25000, 80) (25000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_db = train_db.shuffle(10000).batch(batch_size, drop_remainder=True)

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    test_db = test_db.batch(batch_size, drop_remainder=True)

    return train_db, test_db


if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    model.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

输出结果:

Model: "rnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 1000000
_________________________________________________________________
simple_rnn_cell (SimpleRNNCe multiple 10560
_________________________________________________________________
simple_rnn_cell_1 (SimpleRNN multiple 8256
_________________________________________________________________
dense (Dense) multiple 65
=================================================================
Total params: 1,018,881
Trainable params: 1,018,881
Non-trainable params: 0
_________________________________________________________________
None

(25000, 80) (25000,)
(25000, 80) (25000,)
Epoch 1/20
2021-07-10 17:59:45.150639: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (reGIStered 2)
24/24 [==============================] - 12s 294ms/step - loss: 0.7113 - accuracy: 0.5033 - val_loss: 0.6968 - val_accuracy: 0.4994
Epoch 2/20
24/24 [==============================] - 7s 292ms/step - loss: 0.6951 - accuracy: 0.5005 - val_loss: 0.6939 - val_accuracy: 0.4994
Epoch 3/20
24/24 [==============================] - 7s 297ms/step - loss: 0.6937 - accuracy: 0.5000 - val_loss: 0.6935 - val_accuracy: 0.4994
Epoch 4/20
24/24 [==============================] - 8s 316ms/step - loss: 0.6934 - accuracy: 0.5001 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 5/20
24/24 [==============================] - 7s 301ms/step - loss: 0.6934 - accuracy: 0.4996 - val_loss: 0.6933 - val_accuracy: 0.4994
Epoch 6/20
24/24 [==============================] - 8s 334ms/step - loss: 0.6932 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 7/20
24/24 [==============================] - 10s 398ms/step - loss: 0.6931 - accuracy: 0.5006 - val_loss: 0.6932 - val_accuracy: 0.4994
Epoch 8/20
24/24 [==============================] - 9s 382ms/step - loss: 0.6930 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.4994
Epoch 9/20
24/24 [==============================] - 8s 322ms/step - loss: 0.6924 - accuracy: 0.4995 - val_loss: 0.6913 - val_accuracy: 0.5240
Epoch 10/20
24/24 [==============================] - 8s 321ms/step - loss: 0.6812 - accuracy: 0.5501 - val_loss: 0.6655 - val_accuracy: 0.5767
Epoch 11/20
24/24 [==============================] - 8s 318ms/step - loss: 0.6381 - accuracy: 0.6896 - val_loss: 0.6235 - val_accuracy: 0.7399
Epoch 12/20
24/24 [==============================] - 8s 323ms/step - loss: 0.6088 - accuracy: 0.7655 - val_loss: 0.6110 - val_accuracy: 0.7533
Epoch 13/20
24/24 [==============================] - 8s 321ms/step - loss: 0.5949 - accuracy: 0.7956 - val_loss: 0.6111 - val_accuracy: 0.7878
Epoch 14/20
24/24 [==============================] - 8s 324ms/step - loss: 0.5859 - accuracy: 0.8142 - val_loss: 0.5993 - val_accuracy: 0.7904
Epoch 15/20
24/24 [==============================] - 8s 330ms/step - loss: 0.5791 - accuracy: 0.8318 - val_loss: 0.5961 - val_accuracy: 0.7907
Epoch 16/20
24/24 [==============================] - 8s 340ms/step - loss: 0.5739 - accuracy: 0.8421 - val_loss: 0.5942 - val_accuracy: 0.7961
Epoch 17/20
24/24 [==============================] - 9s 378ms/step - loss: 0.5701 - accuracy: 0.8497 - val_loss: 0.5933 - val_accuracy: 0.8014
Epoch 18/20
24/24 [==============================] - 9s 361ms/step - loss: 0.5665 - accuracy: 0.8589 - val_loss: 0.5958 - val_accuracy: 0.8082
Epoch 19/20
24/24 [==============================] - 8s 353ms/step - loss: 0.5630 - accuracy: 0.8681 - val_loss: 0.5931 - val_accuracy: 0.7966
Epoch 20/20
24/24 [==============================] - 8s 314ms/step - loss: 0.5614 - accuracy: 0.8702 - val_loss: 0.5925 - val_accuracy: 0.7959

Process finished with exit code 0

到此这篇关于手把手教你使用TensorFlow2实现RNN的文章就介绍到这了,更多相关TensorFlow2实现RNN内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: 手把手教你使用TensorFlow2实现RNN

本文链接: https://lsjlt.com/news/130536.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 手把手教你使用TensorFlow2实现RNN
    目录概述权重共享计算过程:案例数据集RNN 层获取数据完整代码概述 RNN (Recurrent Netural Network) 是用于处理序列数据的神经网络. 所谓序列数据, 即...
    99+
    2024-04-02
  • 手把手教你使用redis实现排行榜功能
    目录一、需求背景二、实现思路  1、利用数据库2、利用Redis总结一、需求背景 最近项目需要做排行榜功能,实现员工邀请用户注册排行榜,要求是实时更新,查询要快。员工所属支行、二级行、省行,界面要根据条件显示排...
    99+
    2023-04-14
    redis 实现排行榜 redis做排行榜 redis 排行榜
  • 用Python手把手教你实现2048小游戏
    目录一、开发环境二、环境搭建三、原理介绍四、效果图一、开发环境 Python版本:3.6.4 相关模块: pygame模块; 以及一些Python自带的模块。 二、环境搭建 安装Python并添加到环境变量,pip安...
    99+
    2022-06-02
    Python 2048小游戏 python游戏
  • 手把手教你在vue中使用three.js
    目录在vue中使用three.js1.首先安装three.js、引入2.在页面内写入three.js 总结在vue中使用three.js 下面我会介绍three.js的基础...
    99+
    2023-03-01
    vue使用threejs threejs教程 three.js案例
  • 手把手教你在Python里使用ChatGPT
    目录前言知识点实现代码后话前言 近来chatGPT挺火的,也试玩了一下,确实挺有意思。这里记录一下在Python中如何去使用chatGPT。 本篇文章的实现100%基于 chatGP...
    99+
    2022-12-19
    python使用chatgpt 如何使用chatgpt chatgpt怎么玩
  • 【Linux】手把手教你实现udp服务器
    网络套接字~ 文章目录 前言一、udp服务器的实现总结 前言 上一篇文章中我们讲到了很多的网络名词以及相关知识,下面我们就直接进入udp服务器的实现。 一、udp服务器的实现 首先我们需要创建五个文件(文件名可以自己命...
    99+
    2023-08-31
    c++ 后端 linux udp 服务器 网络协议 运维
  • 手把手教你vue实现动态路由
    目录1、什么是动态路由?2、动态路由的好处3、动态路由如何实现总结1、什么是动态路由? 动态路由,动态即不是写死的,是可变的。我们可以根据自己不同的需求加载不同的路由,做到不同的实现...
    99+
    2024-04-02
  • 手把手教你Vue3实现路由跳转
    目录一、安装 vue-router二、新建 vue 页面2.1 login.vue2.2 register.vue三、新建路由文件3.1 新建 index.js3.2 新建 rout...
    99+
    2024-04-02
  • 手把手教你用C语言实现三子棋
    目录1.设计简单菜单2.创建棋盘3.下棋过程的实现 3.1玩家下棋 3.2电脑下棋3.3判断输赢4.游戏源码总结1.设计简单菜单 相信大家在玩游戏时会发现,进入游...
    99+
    2024-04-02
  • 手把手教你用Matplotlib实现数据可视化
    目录介绍简单图形绘制快速上手自定义X/Y轴图表实现汇总正弦曲线图柱状图散点图饼图量场图等高线图图形样式折线图散点图饼图组合图形样式图形位置figure对象subplots对象规范绘图...
    99+
    2024-04-02
  • 手把手教你用js实现瀑布流布局
    它可以有效的降低页面的复杂度,节省很多的空间;并且,瀑布流的参差不齐的排列方式,可以通过界面展示给用户多条数据,并且让用户可以有向下浏览的冲动,提供了很好的用户体验!例如淘宝的页面就采用了这种布局方式,给大家看看淘宝的瀑布流布局的效果图(手...
    99+
    2023-05-14
    JavaScript
  • 手把手教你用Javascript实现观察者模式
    目录什么是观察者模式?场景模拟代码实现重构代码总结 什么是观察者模式? 观察者模式一种设计模式。 观察者模式定义了对象间的一种 一对多 的依赖关系,当一个对象的状态...
    99+
    2024-04-02
  • 手把手教你使用Java实现在线生成pdf文档
    目录一、介绍二、案例实现2.1添加 iText 依赖包2.2简单实现2.3复杂实现2.4变量替换方式三、总结一、介绍 在实际的业务开发的时候,研发人员往往会碰到很多这样的一些场景,需...
    99+
    2024-04-02
  • 手把手教你如何使用 GitHub 平台
    作为全球最大的开源代码托管平台,GitHub 提供了无数的项目与资源,也方便了开发者们相互之间的交流与学习。本文将为您介绍如何使用 GitHub 平台。1. 注册 GitHub 账户在使用 GitHub 之前,您首先需要注册一个 GitHu...
    99+
    2023-10-22
  • 一文看懂:手把手教你使用ChatGPT
    自去年年底推出以来,这款新的人工智慧语言处理工具一直处于潮流的顶端,那么你该如何使用它呢?人工智慧聊天机器人ChatGPT 的大火掀起了国内外对相关概念公司的投资热潮。那么,ChatGPT 为什么这么火,为什么说它可能颠覆传统的搜寻引擎,普...
    99+
    2023-02-09
    ChatGPT
  • 手把手教会你使用redux的入门教程
    目录Redux详解Redux介绍Redux有什么作用如何在React中使用Redux如何使用React-reduxRedux详解 Redux介绍 Redux 是 JavaScript...
    99+
    2024-04-02
  • 手把手教你docker部署(使用docker-compose)教程
    目录一、docker一些基础命令二、docker部署(使用docker-compose)2.1 安装docker(服务器:CentOS 7或更高版本)2.2 安装docker-com...
    99+
    2023-01-28
    docker部署使用实例 dockercompose docker 部署
  • 手把手教你实现PyTorch的MNIST数据集
    目录概述 获取数据 网络模型 train 函数 test 函数 main 函数 完整代码:概述 MNIST 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个...
    99+
    2024-04-02
  • 手把手教你实现Android编译期注解
    详细阐述了实现一个Android编译期注解sdk的步骤以及注意事项,并简要分析了运行时注解以及字节码技术在生成代码上与编译期注解的不同与优劣 一、编译期注解在开发中的重要性 从早期令...
    99+
    2024-04-02
  • 手把手教你实现一个 Python 计时器
    为了更好地掌握 Python 计时器的应用,我们后面还补充了有关Python类、上下文管理器和装饰器的背景知识。因篇幅限制,其中利用上下文管理器和装饰器优化 Python 计时器,将在后续文章学习,不在本篇文章范围内。Python 计时器首...
    99+
    2023-05-14
    Python 编程语言 计时器
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作