如何在PyTorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。model.state_dict()是浅拷贝,返回的参
如何在PyTorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
import torchimport torch.nn as nnimport torchvisionimport numpy as npfrom torchsummary import summary# Define modelclass TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x# Initialize modelmodel = TheModelClass()# Initialize optimizeroptimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dictprint("Model's state_dict:")for param_tensor in model.state_dict(): print(param_tensor,"\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dictprint("Optimizer's state_dict:")for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict:conv1.weight torch.Size([6, 3, 5, 5])conv1.bias torch.Size([6])conv2.weight torch.Size([16, 6, 5, 5])conv2.bias torch.Size([16])fc1.weight torch.Size([120, 400])fc1.bias torch.Size([120])fc2.weight torch.Size([84, 120])fc2.bias torch.Size([84])fc3.weight torch.Size([10, 84])fc3.bias torch.Size([10])Optimizer's state_dict:state {}param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
net=BaseNet()
保存net时报错 object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
class BaseNet(object): def __init__(self):
class BaseNet(nn.Module): def __init__(self): super(BaseNet, self).__init__()
看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注编程网精选频道,感谢您对编程网的支持。
--结束END--
本文标题: 如何在pytorch中解决state_dict()的拷贝问题
本文链接: https://lsjlt.com/news/247970.html(转载时请注明来源链接)
有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
2024-05-24
回答
回答
回答
回答
回答
回答
回答
回答
回答
回答
0