返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch 如何将CIFAR100数据按类标归类保存
  • 400
分享到

PyTorch 如何将CIFAR100数据按类标归类保存

2024-04-02 19:04:59 400人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

few-shot learning的采样 Few-shot learning 基于任务对模型进行训练,在N-way-K-shot中,一个任务中的meta-training中含有N类,

few-shot learning的采样

Few-shot learning 基于任务对模型进行训练,在N-way-K-shot中,一个任务中的meta-training中含有N类,每一类抽取K个样本构成support set, query set则是在刚才抽取的N类剩余的样本中sample一定数量的样本(可以是均匀采样,也可以是不均匀采样)。

对数据按类标归类

针对上述情况,我们需要使用不同类别放置在不同文件夹的数据集。但有时,数据并没有按类放置,这时就需要对数据进行处理。

下面以CIFAR100为列(不含N-way-k-shot的采样):


import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def Cifar100(root):
    character = [[] for i in range(100)]
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)
    dataset = []
    for (X, Y) in zip(train_set.train_data, train_set.train_labels):  # 将train_set的数据和label读入列表
        dataset.append(list((X, Y)))
    for (X, Y) in zip(test_set.test_data, test_set.test_labels):  # 将test_set的数据和label读入列表
        dataset.append(list((X, Y)))
    for X, Y in dataset:
        character[Y].append(X)  # 32*32*3
    character = np.array(character)
    character = torch.from_numpy(character)
    # 按类打乱
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]
    # shape = self.character.shape
    # self.character = self.character.view(shape[0], shape[1], shape[4], shape[2], shape[3])  # 将数据转成channel在前
    meta_training, meta_validation, meta_testing = \
    character[:64], character[64:80], character[80:]  # meta_training : meta_validation : Meta_testing = 64类:16类:20类
    dataset = []  # 释放内存
    character = []
    os.mkdir(os.path.join(root, 'meta_training'))
    for i, per_class in enumerate(meta_training):
        character_path = os.path.join(root, 'meta_training', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_validation'))
    for i, per_class in enumerate(meta_validation):
        character_path = os.path.join(root, 'meta_validation', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_testing'))
    for i, per_class in enumerate(meta_testing):
        character_path = os.path.join(root, 'meta_testing', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
if __name__ == '__main__':
    root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root)
    print("-----------------")

补充:使用Pytorch对数据集CIFAR-10进行分类

主要是以下几个步骤:

1、下载并预处理数据集

2、定义网络结构

3、定义损失函数和优化

4、训练网络并更新参数

5、测试网络效果


#数据加载和预处理
#使用CIFAR-10数据进行分类实验
import torch as t
import torchvision as tv
import torchvision.transfORMs as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor转成Image,方便可视化
 
#定义对数据的预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),  #归一化
])
 
#训练集
trainset = tv.datasets.CIFAR10(
    root = './data/',
    train = True,
    download = True,
    transform = transform
)
 
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size = 4,
    shuffle = True,
    num_workers = 2,
)
 
#测试集
testset = tv.datasets.CIFAR10(
    root = './data/',
    train = False,
    download = True,
    transform = transform,
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size = 4,
    shuffle = False,
    num_workers = 2,
)
 
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

初次下载需要一些时间,运行结束后,显示如下:


import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()#计时
#定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
        
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        
        x = x.view(x.size()[0],-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
print(net)

显示net结构如下:


#定义优化和损失
loss_func = nn.CrossEntropyLoss()  #交叉熵损失函数
optimizer = t.optim.SGD(net.parameters(),lr = 0.001,momentum = 0.9)
 
#训练网络
for epoch in range(2):
    running_loss = 0
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
       
        outputs = net(inputs)
        loss = loss_func(outputs,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss +=loss.item()
        if i%2000 ==1999:
            print('epoch:',epoch+1,'|i:',i+1,'|loss:%.3f'%(running_loss/2000))
            running_loss = 0.0
end = time.time()
time_using = end - start
print('finish training')
print('time:',time_using)

结果如下:

下一步进行使用测试集进行网络测试:


#测试网络
correct = 0 #定义的预测正确的图片数
total = 0#总共图片个数
with t.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predict = t.max(outputs,1)
        total += labels.size(0)
        correct += (predict == labels).sum()
print('测试集中的准确率为:%d%%'%(100*correct/total))

结果如下:

简单的网络训练确实要比10%的比例高一点:)

在GPU中训练:


#在GPU中训练
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
 
net.to(device)
images = images.to(device)
labels = labels.to(device)
 
output = net(images)
loss = loss_func(output,labels)
 
loss

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。如有错误或未考虑完全的地方,望不吝赐教。

--结束END--

本文标题: PyTorch 如何将CIFAR100数据按类标归类保存

本文链接: https://lsjlt.com/news/125482.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • PyTorch 如何将CIFAR100数据按类标归类保存
    few-shot learning的采样 Few-shot learning 基于任务对模型进行训练,在N-way-K-shot中,一个任务中的meta-training中含有N类,...
    99+
    2024-04-02
  • pytorch中如何保存tensor数据
    在PyTorch中,可以使用torch.save()函数将Tensor数据保存到文件中。以下是保存和加载Tensor数据的示例代码:...
    99+
    2024-04-02
  • pytorch 如何查看数据类型和大小
    问题描述: 查看tensor数据大小时使用了data.shape(),报错: TypeError: 'torch.Size' object is not callable 或 Ty...
    99+
    2024-04-02
  • 如何将数据按指定格式存入zookeeper
    这篇文章主要讲解了“如何将数据按指定格式存入zookeeper”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“如何将数据按指定格式存入zookeeper”吧!环境:  scala版本...
    99+
    2023-06-02
  • C++中如何将数据保存为CSV文件
    目录C++将数据保存为CSV文件如何存储CSV文件C++将数据保存为CSV文件 因为最近涉及到保存模型推理结果的输出文件,所以学一学如何将数据保存为CSV文件,比如保存检测框box的...
    99+
    2022-11-16
    C++ CSV文件 数据保存为CSV文件 C++ 数据保存
  • php如何将数据类型转换为字符串类型
    今天小编给大家分享一下php如何将数据类型转换为字符串类型的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。转换方法:1、使用s...
    99+
    2023-06-29
  • vue中如何将数据转为int类型
    Vue是一款流行的JavaScript框架,被广泛应用于Web前端开发。在Vue的开发过程中,经常会遇到数据类型转换问题,尤其是将字符串转换为整型。本文将介绍如何在Vue中将数据转为int类型。一、使用parseInt函数转换parseIn...
    99+
    2023-05-14
  • 数据库中blob类型如何存取
    在数据库中存取blob类型的数据,可以使用以下方法:1. 通过编程语言的API将blob数据写入数据库。大多数编程语言都提供了API...
    99+
    2023-09-21
    数据库
  • Prometheus数据存储如何指定类型
    在Prometheus中,数据存储的类型由Metric的名称和标签来指定。每个Metric都有一个名称和一组标签,用来唯一标识该Me...
    99+
    2024-03-14
    prometheus
  • php如何将json数据转化为数组类型
    这篇文章主要讲解了“php如何将json数据转化为数组类型”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“php如何将json数据转化为数组类型”吧!在php中,可以利用json_decode...
    99+
    2023-06-22
  • Pytorch如何继承Subset类完成自定义数据拆分
    这篇文章主要介绍“Pytorch如何继承Subset类完成自定义数据拆分”,在日常操作中,相信很多人在Pytorch如何继承Subset类完成自定义数据拆分问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Pyt...
    99+
    2023-06-29
  • ASP 数据类型学习笔记:如何存储数据?
    在 ASP 中,数据类型是非常重要的概念。正确的使用数据类型可以提高程序的运行效率和安全性。本文将介绍 ASP 中常用的数据类型以及如何存储数据。 一、数据类型 ASP 支持多种数据类型,下面是 ASP 中常用的数据类型: 字符串类型(...
    99+
    2023-10-16
    学习笔记 存储 数据类型
  • 如何在 PHP 中存储 NumPy 数据类型?
    在 PHP 中存储 NumPy 数据类型是一个常见的需求,因为 NumPy 提供了很多高效的数学运算和数据处理功能,而 PHP 是一种常用的服务器端编程语言,它可以方便地与 Web 应用程序集成。在本文中,我们将介绍如何在 PHP 中存储 ...
    99+
    2023-10-02
    数据类型 存储 numpy
  • 如何在 ASP 中存储 numy 数据类型?
    ASP 是一种流行的服务器端脚本语言,它的运行环境是 Microsoft Windows。在 ASP 中,我们可以使用多种数据类型来存储数据,包括数字、字符串、布尔值和日期等。但是,如果我们需要在 ASP 中存储 numy 数据类型,该怎么...
    99+
    2023-08-03
    存储 numy 数据类型
  • 基于Android如何实现将数据库保存到SD卡
    有时候为了需要,会将数据库保存到外部存储或者SD卡中(对于这种情况可以通过加密数据来避免数据被破解),比如一个应用支持多个数据,每个数据都需要有一个对应的数据库,并且数据库中的...
    99+
    2022-06-06
    sd sd卡 数据库 数据 Android
  • 如何利用Log Parser将IIS日志保存到数据库
    这篇文章主要讲解了“如何利用Log Parser将IIS日志保存到数据库”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“如何利用Log Parser将IIS日志保存到数据库”吧!一个小时把上一...
    99+
    2023-06-19
  • java中如何使double类型数据保留两位小数
    方式一:保留两位小数DecimalFormat df = new DecimalFormat("#.00"); double d1 = 1.23456 double d2 = 2.0; double d3 = 0.0; S...
    99+
    2019-05-04
    java double 小数 两位
  • java中如何使float类型数据保留两位小数
    方法1:用Math.round计算,这里返回的数字格式的float price=89.89; int itemNum=3; float totalPrice=price*itemNum; float num=(float)(Math.rou...
    99+
    2018-02-14
    java基础 java float 两位 小数
  • php如何将字符串转double类型并保留两位小数
    这篇“php如何将字符串转double类型并保留两位小数”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“php如何将字符串转d...
    99+
    2023-06-30
  • MySQL中BIGINT数据类型如何存储整数值
    目录前言mysql BIGINT例子示例 1示例 2示例 3示例 4结论前言 本文重点介绍 MySQL BIGINT 数据类型,并研究我们如何使用它来存储整数值。我们还将了解它的范围、存储大小和各种属性,包括有符号、无符...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作