Python 官方文档:入门教程 => 点击学习
目录参数注册nn.ModuleList和nn.ModuleDict总结参考自官方文档 参数注册 尝试自己写GoogLeNet时碰到的问题,放在字典中的参数无法自动注册,所谓的注册,就
参考自官方文档
尝试自己写GoogLeNet时碰到的问题,放在字典中的参数无法自动注册,所谓的注册,就是当参数注册到这个网络上时,它会随着你在外部调用net.cuda()后自动迁移到GPU上,而没有注册的参数则不会随着网络迁到GPU上,这就可能导致输入在GPU上而参数不在GPU上,从而出现错误,为了说明这个现象。
举一个有点铁憨憨的例子:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.weight = torch.rand((3,4)) # 这里其实可以直接用nn.Linear,但为了举例这里先憨憨一下
def forward(self,x):
return F.linear(x,self.weight)
if __name__ == "__main__":
batch_size = 10
dummy = torch.rand((batch_size,4))
net = Net()
print(net(dummy))
上面的代码可以成功运行,因为所有的数值都是放在CPU上的,但是,一旦我们要把模型移到GPU上时
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.weight = torch.rand((3,4))
def forward(self,x):
return F.linear(x,self.weight)
if __name__ == "__main__":
batch_size = 10
dummy = torch.rand((batch_size,4)).cuda()
net = Net().cuda()
print(net(dummy))
运行后就会出现
...
RuntimeError: Expected object of backend CUDA but got backend CPU for argument #2 'mat2'
这就是因为self.weight没有随着模型一起移到GPU上的原因,此时我们查看模型的参数,会发现并没有self.weight
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.weight = torch.rand((3,4))
def forward(self,x):
return F.linear(x,self.weight)
if __name__ == "__main__":
net = Net()
for parameter in net.parameters():
print(parameter)
上面的代码没有输出,因为net根本没有参数
那么为了让net有参数,我们需要手动地将self.weight注册到网络上
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.weight = nn.Parameter(torch.rand((3,4))) # 被注册的参数必须是nn.Parameter类型
self.reGISter_parameter('weight',self.weight) # 手动注册参数
def forward(self,x):
return F.linear(x,self.weight)
if __name__ == "__main__":
net = Net()
for parameter in net.parameters():
print(parameter)
batch_size = 10
net = net.cuda()
dummy = torch.rand((batch_size,4)).cuda()
print(net(dummy))
此时网络的参数就有了输出,同时会随着一起迁到GPU上,输出就类似这样
Parameter containing:
tensor([...])
tensor([...])
不过后来我实验了以下,好像只写nn.Parameter不写register也可以被默认注册
有时候我们为了图省事,可能会这样写网络
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.linears = [nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)]
def forward(self,x):
for linear in self.linears:
x = linear(x)
x = F.relu(x)
return x
if __name__ == '__main__':
net = Net()
for parameter in net.parameters():
print(parameter)
同样,输出网络的参数啥也没有,这意味着当调用net.cuda时,self.linears里面的参数不会一起走到GPU上去
此时我们可以在__init__方法中手动对self.parameters()迭代然后把每个参数注册,但更好的方法是,PyTorch已经为我们提供了nn.ModuleList,用来代替python内置的list,放在nn.ModuleList中的参数将会自动被正确注册
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.linears = nn.ModuleList([nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)])
def forward(self,x):
for linear in self.linears:
x = linear(x)
x = F.relu(x)
return x
if __name__ == '__main__':
net = Net()
for parameter in net.parameters():
print(parameter)
此时就有输出了
Parameter containing:
tensor(...)
Parameter containing:
tensor(...)
...
nn.ModuleDict也是类似,当我们需要把参数放在一个字典里的时候,能够用的上,这里直接给一个官方的例子看一看就OK
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
需要注意的是,虽然直接放在Python list中的参数不会自动注册,但如果只是暂时放在list里,随后又调用了nn.Sequential把整个list整合起来,参数仍然是会自动注册的
另外一点要注意的是ModuleList和ModuleDict里面只能放Module的子类,也就是nn.Conv,nn.Linear这样的,但不能放nn.Parameter,如果要放nn.Parameter,用nn.ParameterList即可,用法和nn.ModuleList一样
以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。
--结束END--
本文标题: Pytorch参数注册和nn.ModuleListnn.ModuleDict的问题
本文链接: https://lsjlt.com/news/176438.html(转载时请注明来源链接)
有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
2024-03-01
2024-03-01
2024-03-01
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
回答
回答
回答
回答
回答
回答
回答
回答
回答
回答
0