深度学习初学者,如何下载常用公开数据集并使用呢?

1.前言2.官方文档怎样看3.动手写代码4.如何可视化遇到问题:ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1131)

1.前言

刚开始进行深度学习的时候,难免要用到一些公开数据集,现在闲来无事,记录一下如何快速下载一些经典数据集。通过官方文档学习,是一些大牛们挂在嘴边经常推荐的方法,那么我们本篇博客就从官方文档开始学习。

因为我是做CV方向的,所以用TorchVision这个库举例。来自官网:This library is part of the [PyTorch](http://pytorch.org/) project. PyTorch is an open source machine learning framework.

The [torchvision] package consists of popular datasets, model architectures, and common image transformations for computer vision.

包括很多流行数据集,如我们常见的CIFAR,COCO和MINST,大家应该都不陌生。一会儿会以CIFAR举例,记录一下我的过程。

2.官方文档怎样看

首先我们看一下CIFAR这个类的文档:

参数:

root:表示将下载的数据集放在哪个目录

root (string): Root directory of dataset where directory ``cifar-10-batches-py`` exists or will be saved to if download is set to True.

train:是否为训练数据集

train (bool, optional): If True, creates dataset from training set, otherwise creates from test set.

transform:一个将图像进行预处理、返回transform的函数

A function/transform that takes in an PIL image and returns a transformed version.

download:是否下载数据集,

download (bool, optional):If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

3.动手写代码

示例代码

# 导入torchvision包

import torchvision

# 对原始图像进行数据处理的函数

dataset_transform = torchvision.transforms.Compose([

torchvision.transforms.ToTensor()

])

# 生成训练数据集和测试数据集

# 训练数据集 存放在根目录的dataset文件夹下,作为训练数据集,并下载

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)

# 测试数据集 存放在根目录的dataset文件夹下,不作为训练数据集,并下载

test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])

然后我们右键运行,进行下载

可以看到数据集已经开始下载了,但是因为是从toronto.edu下载,速度很慢。教你一个更快的方法:我们终止运行,复制这个链接,用迅雷下载,很快就好了。然后将下载好的.gz文件进行解压,放到我们创建的dataset目录下:

重新run,就可以正常使用数据集了。

4.如何可视化

我用tensorboard进行了可视化,大家感兴趣可以研究一下tensorboard这个库。

import torchvision

from torch.utils.tensorboard import SummaryWriter

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

dataset_transform = torchvision.transforms.Compose([

torchvision.transforms.ToTensor()

])

# 返回类型

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)

test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])

writer = SummaryWriter("p10")

for i in range(10):

img, target = test_set[i]

writer.add_image("test_set", img, i)

writer.close()

在浏览器上就可以看到图像啦:

遇到问题:ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1131)

如果在下载中遇到同样的问题,需要导入ssl:

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

说在最后的话:编写实属不易,若喜欢或者对你有帮助记得点赞 + 关注或者收藏哦~