返回顶部
首页 > 资讯 > 精选 >如何在pytorch中解决state_dict()的拷贝问题
  • 952
分享到

如何在pytorch中解决state_dict()的拷贝问题

2023-06-06 17:06:59 952人浏览 泡泡鱼
摘要

如何在PyTorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。model.state_dict()是浅拷贝,返回的参

如何在PyTorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torchimport torch.nn as nnimport torchvisionimport numpy as npfrom torchsummary import summary# Define modelclass TheModelClass(nn.Module):  def __init__(self):    super(TheModelClass, self).__init__()    self.conv1 = nn.Conv2d(3, 6, 5)    self.pool = nn.MaxPool2d(2, 2)    self.conv2 = nn.Conv2d(6, 16, 5)    self.fc1 = nn.Linear(16 * 5 * 5, 120)    self.fc2 = nn.Linear(120, 84)    self.fc3 = nn.Linear(84, 10)  def forward(self, x):    x = self.pool(F.relu(self.conv1(x)))    x = self.pool(F.relu(self.conv2(x)))    x = x.view(-1, 16 * 5 * 5)    x = F.relu(self.fc1(x))    x = F.relu(self.fc2(x))    x = self.fc3(x)    return x# Initialize modelmodel = TheModelClass()# Initialize optimizeroptimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dictprint("Model's state_dict:")for param_tensor in model.state_dict():  print(param_tensor,"\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dictprint("Optimizer's state_dict:")for var_name in optimizer.state_dict():  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:conv1.weight  torch.Size([6, 3, 5, 5])conv1.bias  torch.Size([6])conv2.weight  torch.Size([16, 6, 5, 5])conv2.bias  torch.Size([16])fc1.weight  torch.Size([120, 400])fc1.bias  torch.Size([120])fc2.weight  torch.Size([84, 120])fc2.bias  torch.Size([84])fc3.weight  torch.Size([10, 84])fc3.bias  torch.Size([10])Optimizer's state_dict:state  {}param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):  def __init__(self):

把类定义改为

class BaseNet(nn.Module):  def __init__(self):    super(BaseNet, self).__init__()

看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注编程网精选频道,感谢您对编程网的支持。

--结束END--

本文标题: 如何在pytorch中解决state_dict()的拷贝问题

本文链接: https://lsjlt.com/news/247970.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 如何在pytorch中解决state_dict()的拷贝问题
    如何在pytorch中解决state_dict()的拷贝问题?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。model.state_dict()是浅拷贝,返回的参...
    99+
    2023-06-06
  • 如何解析Python深拷贝与浅拷贝问题
    这篇文章将为大家详细讲解有关如何解析Python深拷贝与浅拷贝问题,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。在平时工作中,经常涉及到数据的传递,在数据传递使用过程中,可能会发生数据被修改...
    99+
    2023-06-16
  • 如何解决Linux系统之间拷贝文件的问题
    小编给大家分享一下如何解决Linux系统之间拷贝文件的问题,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!  第一种方法  首先,无论本地还是远程,需要移动或拷贝的文件较多且都不太大时,用cp命令和mv命令效率较低,可以先使...
    99+
    2023-06-13
  • 如何理解JavaScript中的浅拷贝与深拷贝
    本篇文章给大家分享的是有关如何理解JavaScript中的浅拷贝与深拷贝,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。 浅拷贝在使用JavaScript对数组进行操作...
    99+
    2023-06-16
  • 如何解析Python中的赋值、浅拷贝和深拷贝
    这篇文章给大家介绍如何解析Python中的赋值、浅拷贝和深拷贝,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。先明确几点不可变类型:该数据类型对象所指定内存中的值不可以被改变。(1)、在改变某个对象的值时,由于其内存中的...
    99+
    2023-06-22
  • 探讨Java中的深浅拷贝问题
    目录一、前言二、浅拷贝三、深拷贝一、前言 拷贝这个词想必大家都很熟悉,在工作中经常需要拷贝一份文件作为副本。拷贝的好处也很明显,相较于新建来说,可以节省很大的工作量。在Java中,同...
    99+
    2024-04-02
  • 如何解决Pytorch中Batch Normalization layer的问题
    小编给大家分享一下如何解决Pytorch中Batch Normalization layer的问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!1. 注意mome...
    99+
    2023-06-15
  • Golang中的深拷贝与浅拷贝如何使用
    本篇内容主要讲解“Golang中的深拷贝与浅拷贝如何使用”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Golang中的深拷贝与浅拷贝如何使用”吧!一、概念1、深拷贝(Deep Copy)拷贝的是...
    99+
    2023-07-05
  • 在pytorch中复制模型时出现问题如何解决
    在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。直接使用model2=model1会出现当更新model2时,model1的权重也...
    99+
    2023-06-06
  • JavaScript 深拷贝的循环引用问题详解
    如果说道实现深拷贝最简单的方法,我们第一个想到的就是 JSON.stringify() 方法,因为JSON.stringify()后返回的是字符串,所以我们会再使用JSON.pars...
    99+
    2022-12-27
    JavaScript 深拷贝 JavaScript 深拷贝循环引用 JS循环引用
  • 如何分析web前端中的深拷贝和浅拷贝
    小编今天带大家了解如何分析web前端中的深拷贝和浅拷贝,文中知识点介绍的非常详细。觉得有帮助的朋友可以跟着小编一起浏览文章的内容,希望能够帮助更多想解决这个问题的朋友找到问题的答案,下面跟着小编一起深入学习“如何分析web前端中的深拷贝和浅...
    99+
    2023-06-05
  • Jetson NX配置pytorch的问题如何解决
    这篇文章主要介绍“Jetson NX配置pytorch的问题如何解决”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“Jetson NX配置pytorch的问题如何解决”文章能帮助大...
    99+
    2023-07-05
  • JavaScript中的深拷贝如何实现
    今天小编给大家分享一下JavaScript中的深拷贝如何实现的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。深拷贝的最终实现这...
    99+
    2023-07-04
  • linux下scp远程拷贝包含空格的目录或者文件的问题如何解决
    本篇内容介绍了“linux下scp远程拷贝包含空格的目录或者文件的问题如何解决”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!描述: 今天需要...
    99+
    2023-06-13
  • 解决pytorch中的kl divergence计算问题
    偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和p...
    99+
    2024-04-02
  • MapStruct对象映射转换解决Bean属性拷贝性能问题
    目录简介适用场景工作时机使用案例1、添加依赖2、定义两个类3、单元测试核心总结简介 MapStruct 是一个代码生成器(可以生成对象映射转换的代码),它基于约定优于配置的方法,极大...
    99+
    2024-04-02
  • 如何解决pytorch显存一直变大的问题
    本篇内容介绍了“如何解决pytorch显存一直变大的问题”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!在代码中添加以下两行可以解决:torc...
    99+
    2023-06-14
  • docker容器内拷贝文件失败如何解决
    拷贝文件失败的原因可能有多种,以下是一些常见的解决方法:1. 检查文件路径:确认文件路径是否正确,包括容器内的路径和宿主机的路径,尤...
    99+
    2023-10-19
    docker
  • 在python中如何解决死锁的问题
    这篇文章将为大家详细讲解有关在python中如何解决死锁的问题,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。1.添加超时时间:fromthreading import Thread,&...
    99+
    2023-06-14
  • PyTorch中怎么解决过拟合的问题
    PyTorch中解决过拟合问题的方法有很多种,以下是一些常用的方法: 正则化:在损失函数中添加正则项,如L1正则化或L2正则化,...
    99+
    2024-03-05
    PyTorch
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作