A Simple Way to Add an Attention Mechanism to ResNet in MMDetection
# -*- encoding: utf-8 -*-
'''
@File    : resnet_with_attention.py
@Time    : 2023/03/25 08:55:30
@Author  : RainfyLee
@Version : 1.0
@Contact : 379814385@qq.com
'''

# here put the import lib
import torch
from mmdet.models.backbones import ResNet
from mmdet.models.builder import BACKBONES
from fightingcv_attention.attention.CoordAttention import CoordAtt
from fightingcv_attention.attention.SEAttention import SEAttention


# Base class: ResNet with attention on its multi-scale outputs.
class ResNetWithAttention(ResNet):
    def __init__(self, **kwargs):
        super(ResNetWithAttention, self).__init__(**kwargs)
        # For now the attention modules are attached to the last three of
        # the four feature maps that ResNet outputs.
        if self.depth in (18, 34):
            self.dims = (64, 128, 256, 512)
        elif self.depth in (50, 101, 152):
            self.dims = (256, 512, 1024, 2048)
        else:
            raise ValueError(f'unsupported depth: {self.depth}')
        self.attention1 = self.get_attention_module(self.dims[1])
        self.attention2 = self.get_attention_module(self.dims[2])
        self.attention3 = self.get_attention_module(self.dims[3])

    # Subclasses only need to implement this method.
    def get_attention_module(self, dim):
        raise NotImplementedError()

    def forward(self, x):
        outs = super().forward(x)
        outs = list(outs)
        outs[1] = self.attention1(outs[1])
        outs[2] = self.attention2(outs[2])
        outs[3] = self.attention3(outs[3])
        return tuple(outs)


@BACKBONES.register_module()
class ResNetWithCoordAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithCoordAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return CoordAtt(inp=dim, oup=dim, reduction=32)


@BACKBONES.register_module()
class ResNetWithSEAttention(ResNetWithAttention):
    def __init__(self, **kwargs):
        super(ResNetWithSEAttention, self).__init__(**kwargs)

    def get_attention_module(self, dim):
        return SEAttention(channel=dim, reduction=16)


if __name__ == "__main__":
    # model = ResNet(depth=18)
    # model = ResNet(depth=34)
    # model = ResNet(depth=50)
    # model = ResNet(depth=101)
    # model = ResNet(depth=152)
    # model = ResNetWithCoordAttention(depth=18)
    model = ResNetWithSEAttention(depth=18)
    x = torch.rand(1, 3, 224, 224)
    outs = model(x)
    for i, out in enumerate(outs):
        print(i, out.shape)
Taking ResNet as an example, I add an attention mechanism to the feature-map outputs at multiple scales. To do this I write a base class; each subclass only needs to implement the attention module.
The attention modules reuse an existing open-source repository's implementation.
You can also install the package directly via pip:
pip install fightingcv-attention
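As a quick sanity check that the install works, you can run an attention module on a random tensor (a minimal sketch; the tensor sizes here are arbitrary):

import torch
from fightingcv_attention.attention.SEAttention import SEAttention

# SE attention reweights channels, so the output shape matches the input.
x = torch.rand(1, 512, 7, 7)
se = SEAttention(channel=512, reduction=16)
print(se(x).shape)  # torch.Size([1, 512, 7, 7])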
Once the model's outputs check out, register the backbone with MMDetection.
The simple way is to add the backbone registration decorator (as in the code above) and then import this file in train.py and test.py, as sketched below.
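A minimal sketch of that import, assuming the file above is saved as resnet_with_attention.py somewhere on your PYTHONPATH (adjust the module path to your project layout):

# In tools/train.py (and likewise tools/test.py), near the other imports:
import resnet_with_attention  # noqa: F401
# The import alone is enough: it executes the @BACKBONES.register_module()
# decorators, which add the new backbones to MMDetection's registry.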
Then, in the config, change the backbone's type from ResNet to ResNetWithCoordAttention or ResNetWithSEAttention.
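For illustration, the backbone block of a config could look like this (a sketch: only type differs from a stock ResNet-18 backbone config; the other fields are typical values and pass through **kwargs to ResNet unchanged):

model = dict(
    backbone=dict(
        type='ResNetWithSEAttention',  # was: type='ResNet'
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        style='pytorch',
    ),
    # ... neck, head, and the rest of the config stay unchanged
)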
Source: https://blog.csdn.net/qq_21904447/article/details/129762735