当前位置 博文首页 > Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

    Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

    作者:pan_jinquan 时间:2021-07-19 18:38

    【源码GitHub地址】:点击进入

    1. 问题描述

    之前写了一篇关于《pytorch Dataset, DataLoader产生自定义的训练数据》的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的。

    比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的;

    但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后再迭代返回,就会出现类似如下的错误:

    File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp> return [default_collate(samples) for samples in transposed]

    File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate

    raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>

    2. 一般的解决方法

    一般的解决方法也很简单粗暴,就是在传递数据给Dataset前,就做数据清理,把不存在的图片,损坏的数据都提前清理掉。

    是的,这个是最简单粗暴的。

    3. 另一种解决方法:自定义返回数据的规则:collate_fn()校对函数

    我们希望不管传递什么处理给Dataset,Dataset都进行处理,如果不存在或者异常,就返回None,而在DataLoader时,对于不存为None的数据,都去除掉。

    这样就保证在迭代过程中,DataLoader获得batch数据都是正确的。

    比如读取batch_size=5的图片数据,如果其中有1个(或者多个)图片是不存在,那么返回的batch应该把不存在的数据过滤掉,即返回5-1=4大小的batch的数据。

    是的,我要实现的就是这个功能:返回的batch数据会自定清理掉不合法的数据。

    3.1 Pytorch数据处理函数:Dataset和 DataLoader

    Pytorch有两个数据处理函数:Dataset和 DataLoader

    from torch.utils.data import Dataset, DataLoader

    其中Dataset用于定义数据的读取和预处理操作,而DataLoader用于加载并产生批训练数据。

    torch.utils.data.DataLoader参数说明:

    DataLoader(object)可用参数:

    1、dataset(Dataset) 传入的数据集

    2、batch_size(int, optional) 每个batch有多少个样本

    3、shuffle(bool, optional) 在每个epoch开始的时候,对数据进行重新排序

    4、sampler(Sampler, optional) 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

    5、batch_sampler(Sampler, optional) 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

    6、num_workers (int, optional) 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

    7、collate_fn (callable, optional) 将一个list的sample组成一个mini-batch的函数

    8、pin_memory (bool, optional) 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

    9、drop_last (bool, optional) 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

    10、timeout(numeric, optional) 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

    11、worker_init_fn (callable, optional) 每个worker初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

    我们要用到的是collate_fn()回调函数

    3.2 自定义collate_fn()函数:

    torch.utils.data.DataLoader的collate_fn()用于设置batch数据拼接方式,默认是default_collate函数,但当batch中含有None等数据时,默认的default_collate校队方法会出现错误。因此,我们需要自定义collate_fn()函数:

    方法也很简单:只需在原来的default_collate函数中添加下面几句代码:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了。

     # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
     if isinstance(batch, list):
     batch = [(image, image_id) for (image, image_id) in batch if image is not None]
     if batch==[]:
     return (None,None)

    dataset_collate.py:

    # -*-coding: utf-8 -*-
    """
     @Project: pytorch-learning-tutorials
     @File : dataset_collate.py
     @Author : panjq
     @E-mail : pan_jinquan@163.com
     @Date : 2019-06-07 17:09:13
    """
     
    r""""Contains definitions of the methods used by the _DataLoaderIter workers to
    collate samples fetched from dataset into Tensor(s).
    These **needs** to be in global scope since Py2 doesn't support serializing
    static methods.
    """
    import torch
    import re
    from torch._six import container_abcs, string_classes, int_classes 
    _use_shared_memory = False
    r"""Whether to use shared memory in default_collate"""
     
    np_str_obj_array_pattern = re.compile(r'[SaUO]')
     
    error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
     
    numpy_type_map = {
     'float64': torch.DoubleTensor,
     'float32': torch.FloatTensor,
     'float16': torch.HalfTensor,
     'int64': torch.LongTensor,
     'int32': torch.IntTensor,
     'int16': torch.ShortTensor,
     'int8': torch.CharTensor,
     'uint8': torch.ByteTensor,
    }
     
    def collate_fn(batch):
     '''
     collate_fn (callable, optional): merges a list of samples to form a mini-batch.
     该函数参考touch的default_collate函数,也是DataLoader的默认的校对方法,当batch中含有None等数据时,
     默认的default_collate校队方法会出现错误
     一种的解决方法是:
     判断batch中image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
     :param batch:
     :return:
     '''
     r"""Puts each data field into a tensor with outer dimension batch size"""
     # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
     if isinstance(batch, list):
     batch = [(image, image_id) for (image, image_id) in batch if image is not None]
     if batch==[]:
     return (None,None)
     
     elem_type = type(batch[0])
     if isinstance(batch[0], torch.Tensor):
     out = None
     if _use_shared_memory:
      # If we're in a background process, concatenate directly into a
      # shared memory tensor to avoid an extra copy
      numel = sum([x.numel() for x in batch])
      storage = batch[0].storage()._new_shared(numel)
      out = batch[0].new(storage)
     return torch.stack(batch, 0, out=out)
     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
      and elem_type.__name__ != 'string_':
     elem = batch[0]
     if elem_type.__name__ == 'ndarray':
      # array of string classes and object
      if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
      raise TypeError(error_msg_fmt.format(elem.dtype))
     
      return collate_fn([torch.from_numpy(b) for b in batch])
     if elem.shape == (): # scalars
      py_type = float if elem.dtype.name.startswith('float') else int
      return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
     elif isinstance(batch[0], float):
     return torch.tensor(batch, dtype=torch.float64)
     elif isinstance(batch[0], int_classes):
     return torch.tensor(batch)
     elif isinstance(batch[0], string_classes):
     return batch
     elif isinstance(batch[0], container_abcs.Mapping):
     return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
     elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
     return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
     elif isinstance(batch[0], container_abcs.Sequence):
     transposed = zip(*batch)#ok
     return [collate_fn(samples) for samples in transposed]
     
     raise TypeError((error_msg_fmt.format(type(batch[0]))))
    

    测试方法:

    # -*-coding: utf-8 -*-
    """
     @Project: pytorch-learning-tutorials
     @File : dataset.py
     @Author : panjq
     @E-mail : pan_jinquan@163.com
     @Date : 2019-03-07 18:45:06
    """
    import torch
    from torch.autograd import Variable
    from torchvision import transforms
    from torch.utils.data import Dataset, DataLoader
    import numpy as np
    from utils import dataset_collate
    import os
    import cv2
    from PIL import Image
    def read_image(path,mode='RGB'):
     '''
     :param path:
     :param mode: RGB or L
     :return:
     '''
     return Image.open(path).convert(mode)
     
    class TorchDataset(Dataset):
     def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
     '''
     :param filename: 数据文件TXT:格式:imge_name.jpg label1_id labe2_id
     :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
     :param resize_height 为None时,不进行缩放
     :param resize_width 为None时,不进行缩放,
        PS:当参数resize_height或resize_width其中一个为None时,可实现等比例缩放
     :param repeat: 所有样本数据重复次数,默认循环一次,当repeat为None时,表示无限循环<sys.maxsize
     :param transform:预处理
     '''
     self.image_dir = image_dir
     self.image_id_list=image_id_list
     self.len = len(image_id_list)
     self.repeat = repeat
     self.resize_height = resize_height
     self.resize_width = resize_width
     self.transform= transform
     
     def __getitem__(self, i):
     index = i % self.len
     # print("i={},index={}".format(i, index))
     image_id = self.image_id_list[index]
     image_path = os.path.join(self.image_dir, image_id)
     img = self.load_data(image_path)
     
     if img is None:
      return None,image_id
     img = self.data_preproccess(img)
     return img,image_id
     
     def __len__(self):
     if self.repeat == None:
      data_len = 10000000
     else:
      data_len = len(self.image_id_list) * self.repeat
     return data_len
     
     def load_data(self, path):
     '''
     加载数据
     :param path:
     :param resize_height:
     :param resize_width:
     :param normalization: 是否归一化
     :return:
     '''
     try:
      image = read_image(path)
     except Exception as e:
      image=None
      print(e)
     # image = image_processing.read_image(path)#用opencv读取图像
     return image
     
     def data_preproccess(self, data):
     '''
     数据预处理
     :param data:
     :return:
     '''
     if self.transform is not None:
      data = self.transform(data)
     return data
     
    if __name__=='__main__':
     
     resize_height = 224
     resize_width = 224
     image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"]
     image_dir="../dataset/test_images/images"
     # 相关预处理的初始化
     '''class torchvision.transforms.ToTensor把shape=(H,W,C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray数据
     # 转换成shape=(C,H,W)的像素数据,并且被归一化到[0.0, 1.0]的torch.FloatTensor类型。
     '''
     train_transform = transforms.Compose([
     transforms.Resize(size=(resize_height, resize_width)),
     # transforms.RandomHorizontalFlip(),#随机翻转图像
     transforms.RandomCrop(size=(resize_height, resize_width), padding=4), # 随机裁剪
     transforms.ToTensor(), # 吧shape=(H,W,C)->换成shape=(C,H,W),并且归一化到[0.0, 1.0]的torch.FloatTensor类型
     # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#给定均值(R,G,B) 方差(R,G,B),将会把Tensor正则化
     ])
     
     epoch_num=2 #总样本循环次数
     batch_size=5 #训练时的一组数据的大小
     train_data_nums=10
     max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数
     
     train_data = TorchDataset(image_id_list=image_id_list,
        image_dir=image_dir,
        resize_height=resize_height,
        resize_width=resize_width,
        repeat=1,
        transform=train_transform)
     # 使用默认的default_collate会报错
     # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
     # 使用自定义的collate_fn
     train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=dataset_collate.collate_fn)
     
     
     # [1]使用epoch方法迭代,TorchDataset的参数repeat=1
     for epoch in range(epoch_num):
     for step,(batch_image, batch_label) in enumerate(train_loader):
      if batch_image is None and batch_label is None:
      print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
      continue
      image=batch_image[0,:]
      image=image.numpy()#image=np.array(image)
      image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c]
      cv2.imshow("image",image)
      cv2.waitKey(2000)
      print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label))
      # batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    

    输出结果说明:

    batch_size=5,输入图片列表image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"] ,其中"ddd.jpg","111.jpg"是不存在的,resize_width=224,正常情况下返回的数据应该是torch.Size([5, 3, 224, 224]),但由于"ddd.jpg","111.jpg"不存在,被过滤掉了,所以第一个batch的维度变为torch.Size([3, 3, 224, 224])

    [Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

    [Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

    batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

    batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

    [Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

    [Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

    batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

    batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持站长博客。如有错误或未考虑完全的地方,望不吝赐教。

    jsjbwy
    下一篇:没有了