返回顶部
首页 > 资讯 > 后端开发 > Python >Tensorflow 如何从checkpoint文件中加载变量名和变量值
  • 168
分享到

Tensorflow 如何从checkpoint文件中加载变量名和变量值

2024-04-02 19:04:59 168人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

假设你已经经过上千次的迭代,并且得到了以下模型: 则从这些checkpoint文件中加载变量名和变量值代码如下: model_dir = './ckpt-182802' imp

假设你已经经过上千次的迭代,并且得到了以下模型:

在这里插入图片描述

则从这些checkpoint文件中加载变量名和变量值代码如下:


model_dir = './ckpt-182802'
import Tensorflow as tf
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
     print("tensor_name: ", key)
     print(reader.get_tensor(key)) # Remove this is you want to print only variable names

Mnist

下面将给出一个基于卷积神经网络的手写数字识别样例:


# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.Python.framework import graph_util
log_dir = './tensorboard'
mnist = input_data.read_data_sets(train_dir="./mnist_data",one_hot=True)
if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
tf.gfile.MakeDirs(log_dir)

#定义输入数据mnist图片大小28*28*1=784,None表示batch_size
x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")
#定义标签数据,mnist共10类
y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")
#将数据调整为二维数据,w*H*c---> 28*28*1,-1表示N张
image = tf.reshape(x,shape=[-1,28,28,1])

#第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}
w1 = tf.Variable(initial_value=tf.random_nORMal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))
b1= tf.Variable(initial_value=tf.zeros(shape=[32]))
conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")
pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
#shape={None,14,14,32}
#第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}
w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))
b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))
conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")
pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")
#shape={None,7,7,64}
#FC1
w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))
b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))
#关键,进行reshape
input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")
fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")
#shape={None,1024}
#FC2
w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))
b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))
fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit")
#shape={None,10}
#定义交叉熵损失
# 使用softmax将NN计算输出值表示为概率
y = tf.nn.softmax(fc2,name="out")

# 定义交叉熵损失函数
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)
loss = tf.reduce_mean(cross_entropy)
tf.summary.Scalar('Cross_Entropy',loss)
#定义solver
train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)
for var in tf.trainable_variables():
	print var
#train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)

#定义正确值,判断二者下标index是否相等
correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#定义如何计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")
tf.summary.scalar('Training_ACC',accuracy)
#定义初始化op
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
#训练NN
with tf.Session() as session:
    session.run(fetches=init)
    writer = tf.summary.FileWriter(log_dir,session.graph) #定义记录日志的位置
    for i in range(0,500):
        xs, ys = mnist.train.next_batch(100)
        session.run(fetches=train,feed_dict={x:xs,y_:ys})
        if i%10 == 0:
            train_accuracy,summary = session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys})
            writer.add_summary(summary,i)
            print(i,"accuracy=",train_accuracy)
    '''
    #训练完成后,将网络中的权值转化为常量,形成常量graph,注意:需要x与label
    constant_graph = graph_util.convert_variables_to_constants(sess=session,
                                                            input_graph_def=session.graph_def,
                                                            output_node_names=['out','y_','input'])
    #将带权值的graph序列化,写成pb文件存储起来
    with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    '''
    saver.save(session,'./ckpt')

补充:查看tensorflow产生的checkpoint文件内容的方法

tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,但tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看。


import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('modelckpt', "fc_nn_model")
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

其中‘modelckpt'是存放.ckpt文件的文件夹,"fc_nn_model"是文件名,如下图所示。

在这里插入图片描述 

var_to_shape_map是一个字典,其中的键值是变量名,对应的值是该变量的形状,如


{‘LSTM_input/bias_LSTM/Adam_1': [128]}

想要查看某变量值时,需要调用get_tensor函数,即输入以下代码:


reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: Tensorflow 如何从checkpoint文件中加载变量名和变量值

本文链接: https://lsjlt.com/news/126644.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Tensorflow 如何从checkpoint文件中加载变量名和变量值
    假设你已经经过上千次的迭代,并且得到了以下模型: 则从这些checkpoint文件中加载变量名和变量值代码如下: model_dir = './ckpt-182802' imp...
    99+
    2024-04-02
  • tensorflow如何保存变量到文件
    在TensorFlow中,可以使用tf.train.Saver()来保存变量到文件中。以下是一个保存变量的示例代码: import ...
    99+
    2024-04-03
    tensorflow
  • 如何用.env文件为NodeJS加载环境变量
    这篇文章主要讲解了“如何用.env文件为NodeJS加载环境变量”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“如何用.env文件为NodeJS加载环境变量”...
    99+
    2024-04-02
  • 如何在PHP中使用常量和变量可变变量
    如何在PHP中使用常量和变量可变变量?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。关于可变变量:以声明的变量前,再加上变量符;运用代码举例说明,如下:<php$china...
    99+
    2023-06-15
  • Python中如何使用变量创建文件名
    本篇内容介绍了“Python中如何使用变量创建文件名”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!使用格式化的字符串文字来使用变量创建文件名...
    99+
    2023-07-05
  • Python中如何使用中文变量名
    这期内容当中小编将会给大家带来有关Python中如何使用中文变量名,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。Python3.x 已经支持全面 Unicode 编码,比如支持使用中文作为变量名。>...
    99+
    2023-06-15
  • 如何解析Python 变量命名规则和定义变量
    这篇文章给大家介绍如何解析Python 变量命名规则和定义变量,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。一、定义变量语法规则:变量名 = 值定义变量的语法规则中间的‘=',并不...
    99+
    2023-06-22
  • vue中如何将变量赋值
    在vue中给变量赋值的方法:1.新建common.vue文件,并定义变量;2.创建vue.js项目;3.使用import方法导入变量;4.执行代码赋值使用变量;具体步骤如下:首先,新建一个common.vue文件,并在文件中定义一个全局变量...
    99+
    2024-04-02
  • pytho中n变量如何赋值
    这期内容当中小编将会给大家带来有关pytho中n变量如何赋值,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。python的五大特点是什么python的五大特点:1.简单易学,开发程序时,专注的是解决问题,而...
    99+
    2023-06-14
  • java中如何给变量赋值
    在Java中给变量赋值有两种方式:1. 直接赋值:通过使用等号(=)将一个值赋给变量。例如:`int num = 10;` (将10...
    99+
    2023-08-17
    java
  • php不使用中间变量如何互换两变量的值
    这篇文章主要介绍“php不使用中间变量如何互换两变量的值”,在日常操作中,相信很多人在php不使用中间变量如何互换两变量的值问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”php不使用中间变量如何互换两变量的值...
    99+
    2023-07-05
  • PHP变量值如何按秒自动增加
    本文小编为大家详细介绍“PHP变量值如何按秒自动增加”,内容详细,步骤清晰,细节处理妥当,希望这篇“PHP变量值如何按秒自动增加”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。在PHP中,变量值按秒自动增加是一项非...
    99+
    2023-07-05
  • php如何为变量增加一个键值
    本篇内容主要讲解“php如何为变量增加一个键值”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“php如何为变量增加一个键值”吧!我们可以通过使用PHP内置的函数array_push()来实现在数组...
    99+
    2023-07-05
  • 在 Python 中使用变量创建文件名
    使用格式化的字符串文字来使用变量创建文件名,例如 f'{variable}.txt'。 格式化的字符串文字使我们能够通过在字符串前面加上 f 来在字符串中包含表达式和变量。 file_name = '...
    99+
    2023-09-22
    python
  • PHP中变量如何进行传值
    这篇文章给大家分享的是有关PHP中变量如何进行传值的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。变量的传值方式:1,变量的传值方式,是指“一个变量,传给另-个变量”的内部细节形式- -单对单;2,变量的传值方式,...
    99+
    2023-06-20
  • log4j如何根据变量动态生成文件名
    目录根据变量动态生成文件名简单的log4j设置简单实例log4j动态文件名一、按照用户ID来生成log二、在batch程序中,通过一个设定来实现每个batch三、在batch程序中,...
    99+
    2024-04-02
  • 如何在Shell中计算变量数值
    如何在Shell中计算变量数值?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。算术运算符如果要执行算术运算符,就离不开各种运算符号,和其他编程语言类似,shell也有很多算术运算...
    99+
    2023-06-09
  • dotenv怎么从.env文件中读取环境变量
    这篇文章主要讲解了“dotenv怎么从.env文件中读取环境变量”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“dotenv怎么从.env文件中读取环境变量”吧!引言dotenv从.env文件...
    99+
    2023-07-04
  • PHP如何将文件名存储到变量并计算文件中的行数
    这篇文章主要为大家展示了“PHP如何将文件名存储到变量并计算文件中的行数”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“PHP如何将文件名存储到变量并计算文件中的行数”这篇文章吧。首先我们创建一个...
    99+
    2023-06-20
  • python中如何对多变量连续赋值
    看到一段代码,如下 self.batch_size = batch_size = 128 初一看很诧异,仔细想想其实很合理的。 在python可能会需要同时声明多个变量,并对多个变量赋予相同的初始值,可以采用如...
    99+
    2022-06-02
    python 多变量 连续赋值
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作