返回顶部
首页 > 资讯 > 后端开发 > Python >python中TensorFlow神经网络模型的保存和读取方法是什么
  • 337
分享到

python中TensorFlow神经网络模型的保存和读取方法是什么

2023-06-25 12:06:49 337人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

本篇内容主要讲解“python中Tensorflow神经网络模型的保存和读取方法是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Python中TensorFlow神经网络模型的保存和读取方法

本篇内容主要讲解“pythonTensorflow神经网络模型的保存和读取方法是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习Python中TensorFlow神经网络模型的保存和读取方法是什么”吧!

TensorFlow提供了一个非常简单的api,即tf.train.Saver类来保存和还原一个神经网络模型。

下面代码给出了保存TensorFlow模型的方法:

import tensorflow as tf# 声明两个变量v1 = tf.Variable(tf.random_nORMal([1, 2]), name="v1")v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")init_op = tf.global_variables_initializer() # 初始化全部变量saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型with tf.Session() as sess:    sess.run(init_op)    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比    print("v2:", sess.run(v2))    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件    print("Model saved in file:", saver_path)

注:Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件,比V1版多出model.ckpt.data-00000-of-00001这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,TensorFlow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30

这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。

python中TensorFlow神经网络模型的保存和读取方法是什么

下面代码给出了加载TensorFlow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf# 使用和保存模型代码中一样的方式来声明变量v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型with tf.Session() as sess:    saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来    print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比    print("v2:", sess.run(v2))    print("Model Restored")

运行结果:

v1: [[ 0.76705766  1.82217288]]v2: [[-0.98012197  1.2369734   0.5797025 ] [ 2.50458145  0.81897354  0.07858191]]Model Restored

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。

如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

import tensorflow as tf# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量# 直接加载持久化的图saver = tf.train.import_meta_graph("save/model.ckpt.meta")with tf.Session() as sess:    saver.restore(sess, "save/model.ckpt")    # 通过张量的名称来获取张量    print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))

运行程序,输出:

[[ 0.76705766  1.82217288]]

有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

到此,相信大家对“python中TensorFlow神经网络模型的保存和读取方法是什么”有了更深的了解,不妨来实际操作一番吧!这里是编程网网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

--结束END--

本文标题: python中TensorFlow神经网络模型的保存和读取方法是什么

本文链接: https://lsjlt.com/news/304803.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • python中TensorFlow神经网络模型的保存和读取方法是什么
    本篇内容主要讲解“python中TensorFlow神经网络模型的保存和读取方法是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“python中TensorFlow神经网络模型的保存和读取方法...
    99+
    2023-06-25
  • python深度学习TensorFlow神经网络模型的保存和读取
    之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了...
    99+
    2024-04-02
  • python神经网络使用Keras进行模型的保存与读取
    目录学习前言Keras中保存与读取的重要函数1、model.save2、load_model全部代码学习前言 开始做项目的话,有些时候会用到别人训练好的模型,这个时候要学会load噢...
    99+
    2024-04-02
  • python字典保存和读取的方法是什么
    在Python中,可以使用pickle模块来保存和读取字典。保存字典到文件:```import picklemy_dict = {'...
    99+
    2023-08-08
    python
  • TensorFlow卷积神经网络MNIST数据集实现方法是什么
    本篇内容主要讲解“TensorFlow卷积神经网络MNIST数据集实现方法是什么”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“TensorFlow卷积神经网络MNIST数据集实现方法是什么”吧!...
    99+
    2023-06-25
  • mfc文件读取和保存的方法是什么
    MFC(Microsoft Foundation Classes)是微软提供的一套面向对象的程序库,用于开发Windows应用程序。...
    99+
    2023-10-10
    mfc
  • TensorFlow神经网络创建多层感知机MNIST数据集的方法是什么
    这篇文章主要讲解了“TensorFlow神经网络创建多层感知机MNIST数据集的方法是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“TensorFlow神经网络创建多层感知机MNIST数...
    99+
    2023-06-25
  • python读取内存的方法是什么
    Python读取内存的方法可以通过使用内置的`memoryview`对象或`ctypes`模块来实现。1. 使用`memoryvie...
    99+
    2023-08-20
    python
  • python抓取网页内容并保存的方法是什么
    在Python中,可以使用requests库来抓取网页内容,并使用文件操作来保存抓取到的内容。下面是一个示例代码: import r...
    99+
    2024-03-04
    python
  • API模型的保存与加载方法是什么
    本篇内容介绍了“API模型的保存与加载方法是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!  1.目的:  将训练好的模型保存下来,已备...
    99+
    2023-06-02
  • Python中yaml文件的读取方法是什么
    这篇文章主要介绍了Python中yaml文件的读取方法是什么的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Python中yaml文件的读取方法是什么文章都会有所收获,下面我们一起来看看吧。yaml 文件的应用场...
    99+
    2023-06-29
  • Java之Spring简单读取和存储对象的方法是什么
    这篇“Java之Spring简单读取和存储对象的方法是什么”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“Java之Sprin...
    99+
    2023-07-05
  • java中session存值和取值的方法是什么
    在Java中,可以使用HttpSession对象来存储和获取会话数据。1. 存储会话数据:```javaHttpSession se...
    99+
    2023-08-14
    java session
  • js中session存值和取值的方法是什么
    在JavaScript中,无法直接使用session来存储和获取值。但是可以使用其他方法来模拟会话存储和获取值的功能。一种常用的方法...
    99+
    2023-08-18
    js session
  • Python中Pandas文件操作和读取CSV参数的方法是什么
    这篇文章主要介绍“Python中Pandas文件操作和读取CSV参数的方法是什么”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“Python中Pandas文件操作和读取CSV参数的方法是什么”文章能帮...
    99+
    2023-07-05
  • python中csv文件读取与写入的方法是什么
    在Python中,我们可以使用`csv`模块来读取和写入CSV文件。下面是使用`csv`模块读取和写入CSV文件的方法:1. 读取CSV文件:```pythonimport csvwith open('file.csv', 'r') a...
    99+
    2023-08-11
    python
  • 什么是 JDBC Blob 数据类型如何存储和读取其中的数据
    JDBC Blob(Binary Large Object)是一种用于存储大型二进制数据的数据类型,比如图片、音频、视频等。在数据库...
    99+
    2023-10-10
    JDBC
  • 什么是 JDBC Blob 数据类型?如何存储和读取其中的数据?
    BLOB 是二进制大对象,可以容纳可变数量的数据,最大长度为 65535 个字符。它们用于存储大量二进制数据,例如图像或其他类型的数据。文件。定义为 TEXT 的字段也保存大量数据。两者之间的区别在于,存储数据的排序和比较在 BLOB 中区...
    99+
    2023-10-22
  • python中numpy数组的csv文件写入与读取方法是什么
    这篇文章主要讲解了“python中numpy数组的csv文件写入与读取方法是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“python中numpy数组的csv文件写入与读取方法是什么”吧...
    99+
    2023-07-05
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作