
A Simple Method for Adding Attention Mechanisms to ResNet in MMDetection



Summary


# -*- encoding: utf-8 -*-
'''
@File    :   resnet_with_attention.py
@Time    :   2023/03/25 08:55:30
@Author  :   RainfyLee
@Version :   1.0
@Contact :   379814385@qq.com
'''

# here put the import lib
import torch
from mmdet.models.backbones import ResNet
from mmdet.models.builder import BACKBONES
from fightingcv_attention.attention.CoordAttention import CoordAtt
from fightingcv_attention.attention.SEAttention import SEAttention


# Base class: a ResNet backbone with attention applied to its output feature maps
class ResNetWithAttention(ResNet):
    def __init__(self, **kwargs):
        super(ResNetWithAttention, self).__init__(**kwargs)
        # ResNet outputs four feature maps; for now the attention
        # modules are attached to the last three of them.
        if self.depth in (18, 34):
            self.dims = (64, 128, 256, 512)
        elif self.depth in (50, 101, 152):
            self.dims = (256, 512, 1024, 2048)
        else:
            raise ValueError(f"Unsupported ResNet depth: {self.depth}")
        self.attention1 = self.get_attention_module(self.dims[1])
        self.attention2 = self.get_attention_module(self.dims[2])
        self.attention3 = self.get_attention_module(self.dims[3])

    # Subclasses only need to implement this factory method.
    def get_attention_module(self, dim):
        raise NotImplementedError()

    def forward(self, x):
        outs = super().forward(x)
        outs = list(outs)
        outs[1] = self.attention1(outs[1])
        outs[2] = self.attention2(outs[2])
        outs[3] = self.attention3(outs[3])
        outs = tuple(outs)
        return outs


@BACKBONES.register_module()
class ResNetWithCoordAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithCoordAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return CoordAtt(inp=dim, oup=dim, reduction=32)


@BACKBONES.register_module()
class ResNetWithSEAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithSEAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return SEAttention(channel=dim, reduction=16)


if __name__ == "__main__":
    # model = ResNet(depth=18)
    # model = ResNet(depth=34)
    # model = ResNet(depth=50)
    # model = ResNet(depth=101)
    # model = ResNet(depth=152)
    # model = ResNetWithCoordAttention(depth=18)
    model = ResNetWithSEAttention(depth=18)
    x = torch.rand(1, 3, 224, 224)
    outs = model(x)
    # print(outs.shape)
    for i, out in enumerate(outs):
        print(i, out.shape)
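Running the file directly exercises the test block at the bottom. With MMDetection's default out_indices=(0, 1, 2, 3), the four feature maps have strides 4, 8, 16, and 32, so for a 224×224 input to ResNet-18 the printout should look roughly like this (shapes inferred from the standard ResNet-18 channel widths, not copied from the source):

0 torch.Size([1, 64, 56, 56])
1 torch.Size([1, 128, 28, 28])
2 torch.Size([1, 256, 14, 14])
3 torch.Size([1, 512, 7, 7])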

Taking ResNet as the example, I add an attention mechanism to the feature maps output at multiple scales. To do this I wrote a base class; a subclass only needs to implement the attention module itself.

The attention implementations follow this open-source repository:

GitHub - xmu-xiaoma666/External-Attention-pytorch: 🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

Alternatively, install it directly with pip:

pip install fightingcv-attention
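
Before wiring anything into MMDetection, you can sanity-check an attention module on its own. A minimal sketch using the package's SEAttention (the tensor shape is chosen arbitrarily for illustration):

import torch
from fightingcv_attention.attention.SEAttention import SEAttention

se = SEAttention(channel=512, reduction=16)  # channel must match the feature map's channel dim
x = torch.randn(1, 512, 7, 7)                # (batch, channels, height, width)
y = se(x)
print(y.shape)                               # attention preserves the shape: torch.Size([1, 512, 7, 7])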

Once the model's outputs check out, register the new backbones with MMDetection:

The simplest approach is to add the backbone registration decorator (as in the code above) and import this file in train.py and test.py, as sketched below.
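
For example (a sketch; the exact module path is an assumption and depends on where you saved the file):

# At the top of tools/train.py and tools/test.py, after the other imports.
# Importing the module runs the @BACKBONES.register_module() decorators.
import resnet_with_attention  # noqa: F401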

Then, in the config file, change the model's backbone type from ResNet to ResNetWithCoordAttention or ResNetWithSEAttention.
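
A sketch of the corresponding backbone section of a config, assuming you start from a standard ResNet-50 config such as faster_rcnn_r50_fpn_1x_coco.py (all fields besides type are the stock values):

model = dict(
    backbone=dict(
        type='ResNetWithSEAttention',  # was type='ResNet'
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    # ... the rest of the model config is unchanged
)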

Source: https://blog.csdn.net/qq_21904447/article/details/129762735
