返回顶部
首页 > 资讯 > 精选 >pytorch中的hook机制是什么
  • 558
分享到

pytorch中的hook机制是什么

2023-06-29 11:06:18 558人浏览 安东尼
摘要

本篇内容介绍了“PyTorch中的hook机制是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!1、hook背景Hook被成为钩子机制,这

本篇内容介绍了“PyTorch中的hook机制是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

    1、hook背景

    Hook被成为钩子机制,这不是pytorch的首创,在windows编程中已经被普遍采用,包括进程内钩子和全局钩子。按照自己的理解,hook的作用是通过系统来维护一个链表,使得用户拦截(获取)通信消息,用于处理事件。

    pytorch中包含forwardbackward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。

    2、源码阅读

    reGISter_forward_hook()函数必须在forward()函数调用之前被使用,因为这个函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是这个函数在forward()之后就没有作用了!!!):

    作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。

    def register_forward_hook(self, hook):        r"""Registers a forward hook on the module.        The hook will be called every time after :func:`forward` has computed an output.        It should have the following signature::            hook(module, input, output) -> None or modified output        The hook can modify the output. It can modify the input inplace but        it will not have effect on forward since this is called after        :func:`forward` is called.        Returns:            :class:`torch.utils.hooks.RemovableHandle`:                a handle that can be used to remove the added hook by calling                ``handle.remove()``        """        handle = hooks.RemovableHandle(self._forward_hooks)        self._forward_hooks[handle.id] = hook        return handle

    3、定义一个用于测试hooker的类

    如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),__init__()中调用initialize函数对所有层进行初始化。

    注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

    class TestForHook(nn.Module):    def __init__(self):        super().__init__()        self.linear_1 = nn.Linear(in_features=2, out_features=2)        self.linear_2 = nn.Linear(in_features=2, out_features=1)        self.relu = nn.ReLU()        self.relu6 = nn.ReLU6()        self.initialize()    def forward(self, x):        linear_1 = self.linear_1(x)        linear_2 = self.linear_2(linear_1)        relu = self.relu(linear_2)        relu_6 = self.relu6(relu)        layers_in = (x, linear_1, linear_2)        layers_out = (linear_1, linear_2, relu)        return relu_6, layers_in, layers_out    def initialize(self):        """ 定义特殊的初始化,用于验证是不是获取了权重"""        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))        return True

    4、定义hook函数

    hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。

    首先定义几个容器用于记录:

    定义用于获取网络各层输入输出tensor的容器:

    # 并定义module_name用于记录相应的module名字module_name = []features_in_hook = []features_out_hook = []hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

    hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字

    def hook(module, fea_in, fea_out):    print("hooker working")    module_name.append(module.__class__)    features_in_hook.append(fea_in)    features_out_hook.append(fea_out)    return None

    5、对需要的层注册hook

    注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):

    注册钩子可以对某些层单独进行:

    net = TestForHook()net_chilren = net.children()for child in net_chilren:    if not isinstance(child, nn.ReLU6):        child.register_forward_hook(hook=hook)

    6、测试forward()返回的特征和hook记录的是否一致

    6.1 测试forward()提供的输入输出特征

    由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

    out, features_in_forward, features_out_forward = net(x)print("*"*5+"forward return features"+"*"*5)print(features_in_forward)print(features_out_forward)print("*"*5+"forward return features"+"*"*5)

    得到下面的输出是理所当然的:

    *****forward return features*****
    (tensor([[0.1000, 0.1000],
            [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>))
    (tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<ThresholdBackward0>))
    *****forward return features*****

    6.2 hook记录的输入特征和输出特征

    hook通过list结构进行记录,所以可以直接print

    测试features_in是不是存储了输入:

    print("*"*5+"hook record features"+"*"*5)print(features_in_hook)print(features_out_hook)print(module_name)print("*"*5+"hook record features"+"*"*5)

    得到和forward一样的结果:

    *****hook record features*****
    [(tensor([[0.1000, 0.1000],
            [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>),)]
    [tensor([[1.2000, 1.2000],
            [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
            [3.4000]], grad_fn=<ThresholdBackward0>)]
    [<class 'torch.nn.modules.linear.Linear'>, 
    <class 'torch.nn.modules.linear.Linear'>,
     <class 'torch.nn.modules.activation.ReLU'>]
    *****hook record features*****

    6.3 把hook记录的和forward做减法

    如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:

    测试forward返回的feautes_in是不是和hook记录的一致:

    print("sub result'")for forward_return, hook_record in zip(features_in_forward, features_in_hook):    print(forward_return-hook_record[0])

    得到的全部都是0,说明hook没问题:

    sub resulttensor([[0., 0.],        [0., 0.]])tensor([[0., 0.],        [0., 0.]], grad_fn=<SubBackward0>)tensor([[0.],        [0.]], grad_fn=<SubBackward0>)

    7、完整代码

    import torchimport torch.nn as nnclass TestForHook(nn.Module):    def __init__(self):        super().__init__()        self.linear_1 = nn.Linear(in_features=2, out_features=2)        self.linear_2 = nn.Linear(in_features=2, out_features=1)        self.relu = nn.ReLU()        self.relu6 = nn.ReLU6()        self.initialize()    def forward(self, x):        linear_1 = self.linear_1(x)        linear_2 = self.linear_2(linear_1)        relu = self.relu(linear_2)        relu_6 = self.relu6(relu)        layers_in = (x, linear_1, linear_2)        layers_out = (linear_1, linear_2, relu)        return relu_6, layers_in, layers_out    def initialize(self):        """ 定义特殊的初始化,用于验证是不是获取了权重"""        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))        return True

    定义用于获取网络各层输入输出tensor容器,并定义module_name用于记录相应的module名字

    module_name = []features_in_hook = []features_out_hook = []

    hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字

    def hook(module, fea_in, fea_out):    print("hooker working")    module_name.append(module.__class__)    features_in_hook.append(fea_in)    features_out_hook.append(fea_out)    return None

    定义全部是1的输入:

    x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

    注册钩子可以对某些层单独进行:

    net = TestForHook()net_chilren = net.children()for child in net_chilren:    if not isinstance(child, nn.ReLU6):        child.register_forward_hook(hook=hook)

    测试网络输出:

    out, features_in_forward, features_out_forward = net(x)
    print("*"*5+"forward return features"+"*"*5)
    print(features_in_forward)
    print(features_out_forward)
    print("*"*5+"forward return features"+"*"*5)

    测试features_in是不是存储了输入:

    print("*"*5+"hook record features"+"*"*5)print(features_in_hook)print(features_out_hook)print(module_name)print("*"*5+"hook record features"+"*"*5)

    测试forward返回的feautes_in是不是和hook记录的一致:

    print("sub result")
    for forward_return, hook_record in zip(features_in_forward, features_in_hook):
        print(forward_return-hook_record[0])

    “pytorch中的hook机制是什么”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注编程网网站,小编将为大家输出更多高质量的实用文章!

    --结束END--

    本文标题: pytorch中的hook机制是什么

    本文链接: https://lsjlt.com/news/324325.html(转载时请注明来源链接)

    有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

    猜你喜欢
    • pytorch中的hook机制是什么
      本篇内容介绍了“pytorch中的hook机制是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!1、hook背景Hook被成为钩子机制,这...
      99+
      2023-06-29
    • pytorch中的hook机制register_forward_hook
      目录1、hook背景2、源码阅读3、定义一个用于测试hooker的类4、定义hook函数5、对需要的层注册hook6、测试forward()返回的特征和hook记录的是否一致6.1测...
      99+
      2024-04-02
    • tp5框架中的hook机制是什么
      这篇文章主要介绍tp5框架中的hook机制是什么,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!1. 官方解释行为(Behavior)是ThinkPHP扩展机制中比较关键的一项扩展,行为既可以独立调用,也可以绑定到某个...
      99+
      2023-06-15
    • 什么是PyTorch中的自动微分机制
      PyTorch中的自动微分机制是指PyTorch自带的自动求导功能,它可以自动计算神经网络中每个参数的梯度,从而实现反向传播和优化算...
      99+
      2024-03-05
      PyTorch
    • PyTorch自动求导机制是什么
      PyTorch的自动求导机制是指PyTorch能够自动计算张量的梯度,即张量的导数。这个机制使得使用PyTorch进行深度学习模型的...
      99+
      2024-03-05
      PyTorch
    • React的Hook是什么
      这篇文章主要介绍了React的Hook是什么,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。State Hook这个例子用来显示一个计数器。当你点击按钮,计数器的值就会增加:i...
      99+
      2023-06-29
    • react中hook的概念是什么
      本文小编为大家详细介绍“react中hook的概念是什么”,内容详细,步骤清晰,细节处理妥当,希望这篇“react中hook的概念是什么”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识...
      99+
      2024-04-02
    • Vue3中的Hook特性是什么
      这篇文章主要讲解了“Vue3中的Hook特性是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Vue3中的Hook特性是什么”吧!Hook 的概念Hook...
      99+
      2024-04-02
    • Android中的HOOK技术是什么
      目录1. 什么是 Hook2. Hook的应用场景3. Hook的技术方式或框架4. Hook的一般步骤和技巧实战1. 什么是 Hook Hook 英文翻译过来就是「钩子」的意思,那...
      99+
      2023-02-17
      Android HOOK技术 Android HOOK框架
    • Flex中Hook机制的示例分析
      小编给大家分享一下Flex中Hook机制的示例分析,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!在前一篇简要介绍了基于Flex的界面组合SDK,其中使用Hook机制实现UI Part生命周期管理、Master-Detail...
      99+
      2023-06-17
    • React Hook是什么
      这篇文章主要为大家展示了“React Hook是什么”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“React Hook是什么”这篇文章吧。 ...
      99+
      2024-04-02
    • React中常用的两个Hook是什么
      这篇文章给大家分享的是有关React中常用的两个Hook是什么的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。先介绍一下什么是hookHook是React 16.8新增的特性,专门...
      99+
      2024-04-02
    • Python中hook的实现原理是什么
      在Python中,hook(钩子)是一种机制,允许开发者在特定事件(例如函数调用、异常发生等)发生时插入自定义的代码进行处理。实现原...
      99+
      2023-09-26
      Python
    • PyTorch中的张量是什么
      在PyTorch中,张量是一种类似于多维数组的数据结构,可以存储和处理多维数据。张量在PyTorch中是用来表示神经网络的输入、输出...
      99+
      2024-03-05
      PyTorch
    • Java中的锁机制是什么
      今天小编给大家分享一下Java中的锁机制是什么的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。Java中的锁机制是保证多线程并...
      99+
      2023-07-05
    • MySQL中的锁机制是什么
      这篇“MySQL中的锁机制是什么”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“MySQL中的锁机制是什么”文章吧。一.概述锁...
      99+
      2023-07-05
    • Java中的SPI机制是什么
      这篇文章主要介绍“Java中的SPI机制是什么”,在日常操作中,相信很多人在Java中的SPI机制是什么问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Java中的SPI机制是什么”的疑惑有所帮助!接下来,请跟...
      99+
      2023-07-05
    • ZooKeeper中的Watch机制是什么
      ZooKeeper中的Watch机制是一种事件监听机制,用于通知客户端关于特定节点的状态变化。当客户端对某个节点注册了Watch事件...
      99+
      2024-03-06
      ZooKeeper
    • PostgreSQL中的锁机制是什么
      PostgreSQL中的锁机制是用来控制并发访问数据库中数据的方式。它可以防止多个会话同时对同一数据进行修改,从而避免数据不一致的问...
      99+
      2024-04-02
    • MySQL中复制机制的原理是什么
      MySQL中复制机制的原理是什么,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。背景介绍复制,就是对数据的完整拷贝,说到为什么要...
      99+
      2024-04-02
    软考高级职称资格查询
    编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
    • 官方手机版

    • 微信公众号

    • 商务合作