返回顶部
首页 > 资讯 > 后端开发 > Python >pytorch教程resnet.py的实现文件源码分析
  • 631
分享到

pytorch教程resnet.py的实现文件源码分析

2024-04-02 19:04:59 631人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录调用PyTorch内置的模型的方法解读模型源码Resnet.py包含的库文件该库定义了6种Resnet的网络结构每种网络都有训练好的可以直接用的.pth参数文件Resnet中大多

调用pytorch内置的模型的方法


import torchvision
model = torchvision.models.resnet50(pretrained=True)

这样就导入了resnet50的预训练模型了。如果只需要网络结构,不需要用预训练模型的参数来初始化

那么就是:


model = torchvision.models.resnet50(pretrained=False)

如果要导入densenet模型也是同样的道理

比如导入densenet169,且不需要是预训练的模型:


model = torchvision.models.densenet169(pretrained=False)

由于pretrained参数默认是False,所以等价于:


model = torchvision.models.densenet169()

不过为了代码清晰,最好还是加上参数赋值。

解读模型源码Resnet.py

包含的库文件


import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

该库定义了6种Resnet的网络结构

包括


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',  'resnet101',  'resnet152']

每种网络都有训练好的可以直接用的.pth参数文件


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',  'resnet101',  'resnet152']

Resnet中大多使用3*3的卷积定义如下


def conv3x3(in_planes, out_planes, stride=1):   
"""3x3 convolution with padding"""   
return nn.Conv2d(in_planes, out_planes, kernel_size=3, 
stride=stride, padding=1, bias=False)

该函数继承自nn网络中的2维卷积,这样做主要是为了方便,少写参数参数由原来的6个变成了3个

输出图与输入图长宽保持一致

如何定义不同大小的Resnet网络

Resnet类是一个基类,
所谓的"Resnet18", ‘resnet34', ‘resnet50', ‘resnet101', 'resnet152'只是Resnet类初始化的时候使用了不同的参数,理论上我们可以根据Resnet类定义任意大小的Resnet网络
下面先看看这些不同大小的Resnet网络是如何定义的

定义Resnet18


def resnet18(pretrained=False, **kwargs):  
"""
Constructs a ResNet-18 model.    
Args:    
pretrained (bool):If True, returns a model pre-trained on ImageNet   
"""    
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)    
if pretrained:        
    model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))   
 return model

定义Resnet34


def resnet34(pretrained=False, **kwargs):    
"""Constructs a ResNet-34 model.   
Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    """   
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)  
if pretrained:        
    model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))    
 return model

我们发现Resnet18和Resnet34的定义几乎是一样的,下面我们把Resnet18,Resnet34,Resnet50,Resnet101,Resnet152,不一样的部分写在一块进行对比


model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)    #Resnet18
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)    #Resnet34
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)    #Eesnt50
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)  #Resnet101
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)  #Resnet152

代码看起来非常的简洁工整,

其他resnet18、resnet101等函数和resnet18基本类似,差别主要是在:

1、构建网络结构的时候block的参数不一样,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。

2、调用的block类不一样,比如在resnet50、resnet101、resnet152中调用的是Bottleneck类,而在resnet18和resnet34中调用的是BasicBlock类,这两个类的区别主要是在residual结果中卷积层的数量不同,这个是和网络结构相关的,后面会详细介绍。

3、如果下载预训练模型的话,model_urls字典的键不一样,对应不同的预训练模型。因此接下来分别看看如何构建网络结构和如何导入预训练模型。

Resnet类

构建ResNet网络是通过ResNet这个类进行的。ResNet类是继承PyTorch中网络的基类:torch.nn.Module。

构建Resnet类主要在于重写 init() forward() 方法。

我们构建的所有网络比如:VGGAlexnet等都需要重写这两个方法,这两个方法很重要

看起来Resne类是整个文档的核心

下面我们就要研究一下Resnet基类是如何实现的

Resnet类采用了pytorch定义网络模型的标准结构,包含

iinit()方法: 定义了网络的各个层
forward()方法: 定义了前向传播过程

这两个方法的用法,这个可以查看pytorch的官方文档就可以明白

在Resnet类中,还包含一个自定义的方法make_layer()方法

是用来构建ResNet网络中的4个blocks

_make_layer方法的第一个输入block是BottleneckBasicBlock

第二个输入是该blocks的输出channel

第三个输入是每个blocks中包含多少个residual子结构,因此layers这个列表就是前面resnet50的[3, 4, 6, 3]。

_make_layer方法中比较重要的两行代码是:


layers.append(block(self.inplanes, planes, stride, downsample))

该部分是将每个blocks的第一个residual结构保存在layers列表中。


 for i in range(1, blocks): layers.append(block(self.inplanes, planes))

该部分是将每个blocks的剩下residual 结构保存在layers列表中,这样就完成了一个blocks的构造。这两行代码中都是通过Bottleneck这个类来完成每个residual的构建

接下来介绍Bottleneck类


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNORM2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

下面我们分别看看这两个过程:

网络的forward过程


 def forward(self, x):                                #x代表输入
        x = self.conv1(x)                             #进过卷积层1
        x = self.bn1(x)                                #bn1层
        x = self.relu(x)                                #relu激活
        x = self.maxpool(x)                         #最大池化
        x = self.layer1(x)                            #卷积块1
        x = self.layer2(x)                           #卷积块2
        x = self.layer3(x)                          #卷积块3
        x = self.layer4(x)                          #卷积块4
        x = self.avgpool(x)                     #平均池化
        x = x.view(x.size(0), -1)               #二维变成变成一维向量
        x = self.fc(x)                             #全连接层
        return x

里面的大部分我们都可以理解,只有layer1-layer4是Resnet网络自己定义的,
它也是Resnet残差连接的精髓所在,我们来分析一下layer层是怎么实现的

残差Block连接是如何实现的

从前面的ResNet类可以看出,在构造ResNet网络的时候,最重要的是 BasicBlock这个类,因为ResNet是由residual结构组成的,而 BasicBlock类就是完成residual结构的构建。同样 BasicBlock还是继承了torch.nn.Module类,且重写了__init__()和forward()方法。从forward方法可以看出,bottleneck就是我们熟悉的3个主要的卷积层、BN层和激活层,最后的out += residual就是element-wise add的操作。

这部分在 BasicBlock类中实现,我们看看这层是如何前向传播的


def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

我画个流程图来表示一下

在这里插入图片描述

画的比较丑,不过基本意思在里面了,

根据论文的描述,x是否需要下采样由x与out是否大小一样决定,

假如进过conv2和bn2后的结果我们称之为 P

假设x的大小为wHchannel1

如果P的大小也是wHchannel1

则无需下采样
out = relu(P + X)
out的大小为W * H *(channel1+channel2),

如果P的大小是W/2 * H/2 * channel

则X需要下采样后才能与P相加,
out = relu(P+ X下采样)
out的大小为W/2 * H/2 * (channel1+channel2)

BasicBlock类和Bottleneck类类似,前者主要是用来构建ResNet18和ResNet34网络,因为这两个网络的residual结构只包含两个卷积层,没有Bottleneck类中的bottleneck概念。因此在该类中,第一个卷积层采用的是kernel_size=3的卷积,就是我们之前提到的conv3x3函数。

下面是BasicBlock类的完整代码


class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

以上就是pytorch教程resnet.py的实现文件源码解读的详细内容,更多关于pytorch源码解读的资料请关注编程网其它相关文章!

--结束END--

本文标题: pytorch教程resnet.py的实现文件源码分析

本文链接: https://lsjlt.com/news/134933.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • pytorch教程resnet.py的实现文件源码分析
    目录调用pytorch内置的模型的方法解读模型源码Resnet.py包含的库文件该库定义了6种Resnet的网络结构每种网络都有训练好的可以直接用的.pth参数文件Resnet中大多...
    99+
    2024-04-02
  • Pytorch教程内置模型源码实现
    翻译自 https://pytorch.org/docs/stable/torchvision/models.html 主要讲解了torchvision.models的使用 torc...
    99+
    2024-04-02
  • pytorch实践线性模型3d源码分析
    这篇文章主要介绍“pytorch实践线性模型3d源码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“pytorch实践线性模型3d源码分析”文章能帮助大家解决问题。y = wx +b通过meshg...
    99+
    2023-07-06
  • ZooKeeper框架教程Curator分布式锁实现及源码分析
    目录  如何使用InterProcessMutex  实现思路   代码实现概述  InterProcessMutex源码分析&nb...
    99+
    2024-04-02
  • React实现合成事件的源码分析
    目录事件绑定事件触发结尾今天尝试学习 React 事件的源码实现。 React 版本为 18.2.0 React 中的事件,是对原生事件的封装,叫做合成事件。抽象出一层合成事件,是为...
    99+
    2022-12-08
    React实现合成事件 React合成事件
  • 深入分析GolangServer源码实现过程
    func (srv *Server) Serve(l net.Listener) error { ...... for { rw, err := l.Accept() i...
    99+
    2023-02-02
    Go Server Go Server源码
  • java线程池的实现原理源码分析
    这篇文章主要介绍“java线程池的实现原理源码分析”,在日常操作中,相信很多人在java线程池的实现原理源码分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”java线程池的实现原理源码分析”的疑惑有所帮助!...
    99+
    2023-06-30
  • RocketMQ broker文件清理源码分析
    本篇内容介绍了“RocketMQ broker文件清理源码分析”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!1. broker 清...
    99+
    2023-07-05
  • 亲自教你实现栈及C#中Stack源码分析
    定义 栈又名堆栈,是一种操作受限的线性表,仅能在表尾进行插入和删除操作。 它的特点是先进后出,就好比我们往桶里面放盘子,放的时候都是从下往上一个一个放(入栈),取的时候只能从上往下一...
    99+
    2024-04-02
  • Linux中文件系统truncate.c源码分析
    这篇文章主要讲解了“Linux中文件系统truncate.c源码分析”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Linux中文件系统truncate.c源码分析”吧!Linux-0.11 ...
    99+
    2023-07-05
  • python文件读写操作源码分析
    本篇内容介绍了“python文件读写操作源码分析”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!文件写操作的案例# 打开文件(只写模...
    99+
    2023-07-05
  • SpringBoot拦截器与文件上传实现方法与源码分析
    目录一、拦截器1、创建一个拦截器2、配置拦截器二、拦截器原理三、文件上传四、文件上传流程一、拦截器 拦截器我们之前在springmvc已经做过介绍了 大家可以看下【SpringMVC...
    99+
    2024-04-02
  • 如何分析Go语言的库源码文件
    这期内容当中小编将会给大家带来有关如何分析Go语言的库源码文件,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。go适合做什么go是golang...
    99+
    2024-04-02
  • 源码分析Django的message组件
    目录Django的Message组件(源码分析)1. 配置2. 设置值3. 读取值4. 源码分析4.1第一步: 设置值4.2 第二步: 读取值Django的Message组件(源码分...
    99+
    2023-05-18
    Django message组件 Django message
  • vue3源码分析reactivity实现原理
    目录引言第一部分:简单版reactivity(1).实现reactive和effect(2).实现ref(3).实现computed第二部分:深入分析对于object、array的响...
    99+
    2023-01-28
    vue3源码分析reactivity vue reactivity
  • python源文件中字符编码的示例分析
    这篇文章将为大家详细讲解有关python源文件中字符编码的示例分析,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。Python的优点有哪些1、简单易用,与C/C++、Java、C# 等传统语言相比,Pyth...
    99+
    2023-06-14
  • SpringBoot源码分析之bootstrap.properties文件加载的原理
    目录1.bootstrap的使用2.bootstrap加载原理分析2.1 BootstrapApplicationListener2.2 启动流程梳理2.3 bootstrap.pr...
    99+
    2024-04-02
  • python密码学文件解密实现教程
    目录代码输出在本章中,我们将讨论使用Python解密加密文件.请注意,对于解密过程,我们将遵循相同的过程,但不是指定输出路径,而是关注输入路径或加密的必要文件. 代码 以下是使用Py...
    99+
    2024-04-02
  • python密码学实现文件加密教程
    目录代码输出说明在Python中,可以在传输到通信通道之前加密和解密文件.为此,您必须使用插件 PyCrypto .您可以使用下面给出的命令安装此插件. pip ...
    99+
    2024-04-02
  • redisson实现分布式锁的源码解析
    目录redisson测试代码加锁设计锁续期设计锁的自旋重试解锁设计撤销锁续期解锁成功唤排队线程 redisson redisson 实现分布式锁的机制如下: 依赖版本 implem...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作