当前位置 主页 > 网站技术 > 代码类 >

    使用PyTorch训练一个图像分类器实例

    栏目:代码类 时间:2020-01-08 18:07

    如下所示:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np
    
    print("torch: %s" % torch.__version__)
    print("tortorchvisionch: %s" % torchvision.__version__)
    print("numpy: %s" % np.__version__)
    
    

    Out:

    torch: 1.0.0
    tortorchvisionch: 0.2.1
    numpy: 1.15.4

    数据从哪儿来?

    通常来说,你可以通过一些python包来把图像、文本、音频和视频数据加载为numpy array。然后将其转换为torch.*Tensor。

    图像。Pillow、OpenCV是用得比较多的

    音频。scipy和librosa

    文本。纯Python或者Cython就可以完成数据加载,可以在NLTK和SpaCy找到数据

    对于计算机视觉而言,我们有torchvision包,它可以用来加载一下常用数据集如Imagenet、CIFAR10、MINIST等等,也有一些常用的为图像准备数据转换例如torchvision.datasets和torch.utils.data.DataLoader。

    这次的教程中,我们使用CIFAR10数据集,他有‘airplane', ‘automobile', ‘bird', ‘cat', ‘deer', ‘dog', ‘frog', ‘horse', ‘ship', ‘truck'这几个类别的图像。图像大小都是3x32x32的。也就是说,图像都是三通道的,每一张图的尺寸都是32x32。

    训练一个图像分类器

    步骤如下:

    使用torchvision加载、归一化训练集和测试集

    定义卷积神经网络

    定义损失函数

    使用训练集训练网络

    使用测试集测试网络

    1. 加载、归一化CIFAR10

    我们可以使用torchvision很轻松的完成

    torchvision的数据集是基于PILImage的,数值是[0, 1],我们需要将其转成范围为[-1, 1]的Tensor

    transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                        download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
                         shuffle=True, num_workers=4)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, 
                         shuffle=True, num_workers=4)
    classes = ('plane', 'car', 'bird', 'cat', 
          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    

    Out:

    Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
    Files already downloaded and verified

    让我们来看看训练集的图片

    # 显示一张图片
    def imshow(img):
      img = img / 2 + 0.5   # 逆归一化
      npimg = img.numpy()
      plt.imshow(np.transpose(npimg, (1, 2, 0)))
      plt.show()
    
    
    # 任意地拿到一些图片
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    
    # 显示图片
    imshow(torchvision.utils.make_grid(images))
    # 显示类标
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
    

    Out:

    truck  dog ship  dog