返回顶部
首页 > 资讯 > 后端开发 > Python >pytorch中可视化之hook钩子
  • 236
分享到

pytorch中可视化之hook钩子

pytorchhook钩子pytorchhook 2023-03-23 11:03:25 236人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

目录一、hook1.1 什么是hook,什么情况下使用?1.2 hook在变量中的使用1.3 hook在模型中的使用:一、hook 在PyTorch中,提供了一个专用的接口使得网络在

一、hook

PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图,这个接口的名称非常形象,叫做hook。
可以想象这样的场景,数据通过网络向前传播,网络某一层我们预先设置了一个钩子,数据传播过后钩子上会留下数据在这一层的样子,读取钩子的信息就是这一层的特征图。
具体实现如下:

1.1 什么是hook,什么情况下使用?

首先,明确一下,为什么需要用hook,假设有这么一个函数

需要通过梯度下降法求最小值,其实现方法如下:

import torch
x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
z.backward()
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

结果如下:

x.grad: True tensor(0.)
y.grad: True None
z.grad: True None

注意:在使用训练PyTorch训练模型时,只有叶节点(即直接指定数值的变量,而不是由其他变量计算得到的,比如网络输入)的梯度会保留,其余中间节点梯度在反向传播完成后就会自动释放以节省显存。 因此y.requires_grad的返回值为True,y.grad却为None。

可以看到上面的requires_grad方法都显示True,但是grad没有返回值。当然pytorch也提供某种方法保留非叶子节点的梯度信息。
使用 retain_grad() 方法可以保留非叶子节点的梯度,使用 retain_grad 保留的grad会占用显存,具体操作如下:

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.retain_grad()
z.retain_grad()
z.backward()
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

out:

x.grad: True tensor(0.)
y.grad: True tensor(-4.)
z.grad: True tensor(1.)

** 重申一次** 使用retain_grad方法会占用显存,如果不想要占用显存,就使用到了hook方法。

对于中间节点的变量a,可以使用a.reGISter_hook(hook_fn)对其grad进行操作。 而hook_fn是一个自定义的函数,其声明为hook_fn(grad) -> Tensor or None

1.2 hook在变量中的使用

1.2.1 hook的打印功能

# 自定义hook方法,其传入参数为grad,打印出使用钩子的节点梯度
def hook_fn(grad):
    print(grad)

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.register_hook(hook_fn)
z.register_hook(hook_fn)
print("backward前")

z.backward()
print("backward后\n")
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

out:

backward前
tensor(1.)
tensor(-4.)
backward后

x.grad: True tensor(0.)
y.grad: True None
z.grad: True None

可以看到绑定hook后,backward打印的时候打印了y和z的梯度,调用grad的时候没有保留grad值,已经释放掉内存。注意,打印出来的结果是反向传播,所以先打印z的梯度,再打印y的梯度。

1.2.2 使用hook改变grad的功能

对标记的节点,梯度加2

def hook_fn(grad):
    grad += 2
    print(grad)
    return grad

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.register_hook(hook_fn)
z.register_hook(hook_fn)
print("backward前")

z.backward()
print("backward后\n")
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",x.requires_grad,y.grad)
print("z.grad:",x.requires_grad,z.grad)

out:

backward前
tensor(3.)
tensor(-10.)
backward后

x.grad: True tensor(2.)
y.grad: True None
z.grad: True None

可以看到梯度教上面的已经发生的改变。

1.3 hook在模型中的使用:

PyTorch中使用register_forward_hook和register_backward_hook获取Module输入和输出的feature_map和grad。使用结构如下: hook_fn(module, input, output) -> Tensor or None
模型中使用hook一点要带有这三个参数module, grad_input, grad_output

1.3.1 register_forward_hook的使用

import torch.nn as nn

def hook_forward_fn(model,put,out):
    print("model:",model)
    print("input:",put)
    print("output:",out)
    
# 定义一个model
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv = nn.Conv2d(3, 1, 1)
        self.bn = nn.BatchNORM2d(1)
        #self.conv.register_forward_hook(hook_forward_fn)
        #self.bn.register_forward_hook(hook_forward_fn)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return torch.relu(x)
    
net = Net()
# 对模型中的具体某一层使用hook
net.conv.register_forward_hook(hook_forward_fn)
net.bn.register_forward_hook(hook_forward_fn)


x = torch.rand(1, 3, 2, 2, requires_grad=True)
y = net(x).mean()

注意:该方法不需要使用。backWord就能输出结果,是记录前向传播的钩子。
结果如下:

model: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))
input: (tensor([[[[0.4570, 0.6791],
          [0.0197, 0.5040]],

         [[0.8883, 0.1808],
          [0.6289, 0.9386]],

         [[0.8772, 0.5290],
          [0.0014, 0.3728]]]], requires_grad=True),)
output: tensor([[[[-0.4909, -0.1122],
          [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>)
model: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
input: (tensor([[[[-0.4909, -0.1122],
          [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>),)
output: tensor([[[[-0.2060,  1.6790],
          [-0.8987, -0.5743]]]], grad_fn=<NativeBatchNormBackward0>)

1.3.2 register_backward_hook的使用

使用上面相同的Net模型

def hook_backward_fn(module, grad_input, grad_output):
    print(f"module: {module}")
    print(f"grad_output: {grad_output}")
    print(f"grad_input: {grad_input}")
    print("*"*20)
    
net = Net()
net.conv.register_backward_hook(hook_backward_fn)
net.bn.register_backward_hook(hook_backward_fn)
x = x = torch.rand(1, 3, 2, 2, requires_grad=True)
y = net(x).mean()
y.backward()

out:

module: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
grad_output: (tensor([[[[0.2500, 0.2500],
          [0.0000, 0.0000]]]]),)
grad_input: (tensor([[[[ 0.6586, -0.3360],
          [-0.3009, -0.0218]]]]), tensor([0.4575]), tensor([0.5000]))
********************
module: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))
grad_output: (tensor([[[[ 0.6586, -0.3360],
          [-0.3009, -0.0218]]]]),)
grad_input: (tensor([[[[-0.2974,  0.1517],
          [ 0.1359,  0.0098]],

         [[ 0.0270, -0.0138],
          [-0.0123, -0.0009]],

         [[ 0.2918, -0.1489],
          [-0.1333, -0.0096]]]]), tensor([[[[0.4331]],

         [[0.1386]],

         [[0.4292]]]]), tensor([-1.4156e-07]))
********************

其结果是逆向输出各节点层的梯度信息。

1.3.3 hook中使用展示卷积层

随便画一张图,图片张这个样子:

在这里插入图片描述

使用读取图片发现是个4通道的图像,我们转成单通道并可视化

import matplotlib.pyplot as plt
import matplotlib.image as mping
img=mping.imread("./test1.png")
print(img.shape)
img = torch.tensor(img[:,:,0]).view(1,1,228,226)
plt.imshow(img[0][0])

在这里插入图片描述

接下来创建一个只有卷积层的模型

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(1,1,7),
                                  nn.ReLU()
                                 )

    def forward(self, x):
        x=self.conv(x)
        return x

使用我们的钩子hook对卷积层的输出进行可视化

def hook_forward_fn(model,put,out):
    print("inputshape:",put[0].shape) # 打印出输入图片的维度
    print("outputshape:",out[0][0].shape) # 经过卷积之后的维度
    # 可视化,因为卷积之后带有grad梯度信息,所以需要使用detach().numpy()方法,否则会报错
    plt.imshow(out[0][0].detach().numpy()) 

具体完整实现以及可视化代码如下:

import matplotlib.pyplot as plt
import matplotlib.image as mping
import numpy as np

img=mping.imread("./test1.png")
img = torch.tensor(img[:,:,0]).view(1,1,228,226)


def hook_forward_fn(model,put,out):
    print("inputshape:",put[0].shape)
    print("outputshape:",out[0][0].shape)
    plt.imshow(out[0][0].detach().numpy())
  
    

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(1,1,7),
                                  nn.ReLU()
                                 )

    def forward(self, x):
        x=self.conv(x)
        return x
    
model = Net()
model.conv.register_forward_hook(hook_forward_fn)
y=model(img)

在这里插入图片描述

 到此这篇关于pytorch中可视化之hook钩子的文章就介绍到这了,更多相关pytorch hook钩子内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

--结束END--

本文标题: pytorch中可视化之hook钩子

本文链接: https://lsjlt.com/news/200999.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • pytorch中可视化之hook钩子
    目录一、hook1.1 什么是hook,什么情况下使用?1.2 hook在变量中的使用1.3 hook在模型中的使用:一、hook 在PyTorch中,提供了一个专用的接口使得网络在...
    99+
    2023-03-23
    pytorch hook钩子 pytorch hook
  • pytorch可视化之hook钩子怎么使用
    这篇文章主要介绍了pytorch可视化之hook钩子怎么使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch可视化之hook钩子怎么使用文章都会有所收获,下面我们一起来看看吧。一、hook在PyTo...
    99+
    2023-07-05
  • python学习之路--hook(钩子原
    ** 什么是钩子 ** 之前有转一篇关于回调函数的文章http://blog.csdn.net/Mybigkid/article/details/67644490 钩子函数、注册函数、回调函数,...
    99+
    2023-01-31
    钩子 之路 python
  • Pytorch可视化之Visdom怎么用
    这篇文章主要为大家展示了“Pytorch可视化之Visdom怎么用”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Pytorch可视化之Visdom怎么用”这篇文章吧。一、Visdom简介Visd...
    99+
    2023-06-20
  • Pytorch可视化之Visdom使用实例
    目录一、Visdom简介二、安装和运行三、可视化例子1、输出Hello World!2、显示图像3、绘制散点图4、绘制线条4.1 绘制一条直线4.2 绘制两条直线4.3 绘制正弦曲线...
    99+
    2024-04-02
  • php中关于hook(钩子)的简单理解
    假设你有一套登录注册业务。一开始很简单,老板说只需要常规的注册登录就行。 但是到了后面,接口被刷,老板然你在注册登录前加个验证码 然后没过多久,老板又说,当用户注册时,我们给用户的邮箱或者手机发一条欢迎短信或者邮件吧 还没过上多久,老...
    99+
    2023-09-18
    java 开发语言
  • 详解JavaScript中的before-after-hook钩子函数
    目录before-after-hook1.单独的钩子2.Hook collectionbefore-after-hook 最近看别人的代码,接触到一个插件,before-after-...
    99+
    2022-12-15
    JavaScript before-after-hook钩子函数 JavaScript before-after-hook JavaScript 钩子函数
  • Python中Hook钩子函数的作用是什么
    本篇文章为大家展示了Python中Hook钩子函数的作用是什么,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。1. 什么是Hook经常会听到钩子函数(hook function)这个概念,最近在看目标...
    99+
    2023-06-15
  • php中关于hook钩子函数底层理解
    假设你有一套登录注册业务。一开始很简单,老板说只需要常规的注册登录就行。 但是到了后面,接口被刷,老板然你在注册登录前加个验证码然后没过多久,老板又说,当用户注册时,我们给用户的邮箱...
    99+
    2023-01-13
    php hook钩子 php hook函数 php钩子函数
  • PyTorch中可视化工具的使用
    目录一、网络结构的可视化1.1 通过HiddenLayer可视化网络 1.2 通过PyTorchViz可视化网络 二、训练过程可视化 2.1 通过ten...
    99+
    2023-05-15
    PyTorch 可视化工具
  • visdom可视化pytorch训练过程
      在深度学习模型训练的过程中,常常需要实时监听并可视化一些数据,如损失值loss,正确率acc等。在Tensorflow中,最常使用的工具非Tensorboard莫属;在Pytorch中,也有类似的TensorboardX,但据说其在...
    99+
    2023-01-31
    过程 visdom pytorch
  • PyTorch 可视化工具TensorBoard和Visdom
    目录一、TensorBoard二、Visdom一、TensorBoard TensorBoard 一般都是作为 TensorFlow 的可视化工具,与 TensorFlow 深度集成...
    99+
    2024-04-02
  • 如何在PyTorch中进行模型的可视化
    在PyTorch中进行模型的可视化通常使用第三方库如torchviz或tensorboard。以下是如何使用这两个库进行模型可视化的...
    99+
    2024-03-14
    PyTorch
  • Pytorch如何把Tensor转化成图像可视化
    目录Pytorch把Tensor转化成图像可视化pytorch标准化的Tensor转图像问题总结Pytorch把Tensor转化成图像可视化 在调试程序的时候经常想把tensor可视...
    99+
    2022-12-14
    Pytorch Tensor Tensor转化成图像可视化 Pytorch可视化
  • Pytorch可视化的几种实现方法
    目录一,利用 tensorboardX 可视化网络结构二,利用 vistom 可视化三,利用pytorchviz可视化网络结构一,利用 tensorboardX 可视化网络结构 参...
    99+
    2024-04-02
  • 数据可视化之 tick_params(
    参考:https://blog.csdn.net/helunqu2017/article/details/78736554/ 初学数据可视化,遇到了tick_params() 里面传参数问题,找了一些资料,觉得这个简单明了,非常好用,推荐...
    99+
    2023-01-30
    数据 tick_params
  • 数据可视化之pyecharts
    pyechats是一个用于数据可视化的包。 Echats是百度开源的一个数据可视化js库,主要用于数据可视化,pyecharts 是一个用于生成Echarts图标的类库,实际上就是Echarts和Python的对接。 pyecharts...
    99+
    2023-01-30
    数据 pyecharts
  • PyTorch可视化工具TensorBoard和Visdom怎么用
    今天小编给大家分享一下PyTorch可视化工具TensorBoard和Visdom怎么用的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了...
    99+
    2023-06-26
  • pytorch 权重weight 与 梯度grad 可视化操作
    pytorch 权重weight 与 梯度grad 可视化 查看特定layer的权重以及相应的梯度信息 打印模型 观察到model下面有module的key,module下面有fe...
    99+
    2024-04-02
  • 浅谈一下基于Pytorch的可视化工具
    目录准备网络网络结构的可视化---PytorchViz训练过程可视化---TensorboardXVisdom可视化深度学习网络通常具有很深的层次结构,而且层与层之间通常会有并联、串...
    99+
    2023-05-14
    Pytorch Pytorch可视化工具
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作