当前位置 博文首页 > pytorch 数据加载性能对比分析

    pytorch 数据加载性能对比分析

    作者:ShellCollector 时间:2021-07-17 18:46

    传统方式需要10s,dat方式需要0.6s

    import os
    import time
    import torch
    import random
    from common.coco_dataset import COCODataset
    def gen_data(batch_size,data_path,target_path):
        """Walk the COCO dataset batch by batch and print how long each batch takes.

        ``target_path`` is created if it does not exist. The dataset is wrapped in
        a single-process ``DataLoader`` (no shuffling, incomplete last batch
        dropped), and for every batch the batch size, the seconds spent fetching
        it and the step index are printed.

        Args:
            batch_size: number of samples per batch.
            data_path: dataset location handed to ``COCODataset``.
            target_path: output directory; only created here, since saving is
                currently disabled (see the note in the loop).
        """
        os.makedirs(target_path, exist_ok=True)
        # (352, 352) is forwarded to COCODataset unchanged -- presumably the
        # target image size; confirm against common.coco_dataset.
        dataset = COCODataset(data_path, (352, 352),
                              is_training=False, is_scene=True)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=True,
        )
        tick = time.time()
        for step, samples in enumerate(loader):
            images = samples["image"]
            labels = samples["label"]
            image_paths = samples["img_path"]
            print("time", images.size(0), time.time() - tick)
            # Restart the clock so printing is not counted in the next batch.
            tick = time.time()
            # Dumping is switched off; enable the next line to write one
            # .dat file per batch for the cat_100 benchmark.
            # torch.save(samples, target_path + '/' + str(step) + '.dat')
            print(step)
    def cat_100(target_path,batch_size=100):
        """Time loading pre-dumped ``.dat`` batches and concatenating their images.

        Every ``*.dat`` file in ``target_path`` is loaded with ``torch.load`` in
        random order. The ``"image"`` and ``"label"`` entries are moved to the
        GPU right after loading (or kept on the CPU when CUDA is unavailable).
        After every ``batch_size`` files the collected images are concatenated
        along dimension 0 and the group size plus elapsed seconds are printed.
        Files left over after the last full group are loaded but never
        concatenated or reported.

        Args:
            target_path: directory containing the ``.dat`` files; a trailing
                path separator is optional.
            batch_size: number of files merged per timed group.

        Returns:
            None. Results are only printed.
        """
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use the real file names instead of assuming 0.dat .. N-1.dat, so stray
        # files in the directory cannot produce a non-existent path.
        names = [name for name in os.listdir(target_path) if name.endswith(".dat")]
        random.shuffle(names)
        images = []
        labels = []
        image_paths = []
        start = time.time()
        for count, name in enumerate(names, start=1):
            samples = torch.load(os.path.join(target_path, name))
            images.append(samples["image"].to(device))
            labels.append(samples["label"].to(device))
            image_paths.append(samples["img_path"])
            if count % batch_size:
                continue  # group not full yet
            merged = torch.cat(images, 0)
            print("time", merged.size(0), time.time() - start)
            images = []
            labels = []
            image_paths = []
            start = time.time()
    if __name__ == '__main__':
        # Select GPU index 3 for this process; the variable has to be set
        # before CUDA is first initialised to have any effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = '3'
        batch_size = 320
        # target_path = 'd:/test_1000/'
        # Raw string: '\i' is not a valid escape sequence, and a plain literal
        # triggers a warning on current Python versions. The value is unchanged.
        target_path = r'd:\img_2/'
        data_path = r'D:\dataset\origin_all_datas\_2train'
        gen_data(batch_size, data_path, target_path)
        # Alternative entry points used for the comparison:
        # get_data(target_path, batch_size)
        # cat_100(target_path, batch_size)
    

    下面这种读取方式也比较快:batch_size 为 320 时,每个 batch 约需 450ms

    def cat_100(target_path,batch_size=100):
        """Time loading ``.dat`` batches, moving them to the device only when merging.

        Every ``*.dat`` file in ``target_path`` is loaded with ``torch.load`` in
        random order and kept as loaded. Once ``batch_size`` files have been
        collected, their images are moved to the GPU (or kept on the CPU when
        CUDA is unavailable), concatenated along dimension 0, and the group
        size plus elapsed seconds are printed. Files left over after the last
        full group are loaded but never concatenated or reported.

        Args:
            target_path: directory containing the ``.dat`` files; a trailing
                path separator is optional.
            batch_size: number of files merged per timed group.

        Returns:
            None. Results are only printed.
        """
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use the real file names instead of assuming 0.dat .. N-1.dat, so stray
        # files in the directory cannot produce a non-existent path.
        names = [name for name in os.listdir(target_path) if name.endswith(".dat")]
        random.shuffle(names)
        images = []
        labels = []
        image_paths = []
        start = time.time()
        for count, name in enumerate(names, start=1):
            samples = torch.load(os.path.join(target_path, name))
            images.append(samples["image"])
            labels.append(samples["label"])
            image_paths.append(samples["img_path"])
            if count % batch_size:
                continue  # group not full yet
            # Transfer happens here, once per group, rather than once per file.
            merged = torch.cat([image.to(device) for image in images], 0)
            print("time", merged.size(0), time.time() - start)
            images = []
            labels = []
            image_paths = []
            start = time.time()
    

    补充:pytorch数据加载和处理问题解决方案

    最近跟着pytorch中文文档学习遇到一些小问题,已经解决,在此对这些错误进行记录:

    在读取数据集时报错:

    AttributeError: 'Series' object has no attribute 'as_matrix'

    在显示图片时报错:

    ValueError: Masked arrays must be 1-D

    显示单张图片时figure一闪而过

    在显示多张散点图的时候报错:

    TypeError: show_landmarks() got an unexpected keyword argument 'image'

    解决方案

    主要问题在这一行: 最终目的是将Series转为Matrix,即调用np.mat即可完成。

    修改前

    # Pre-fix line: this is the call that raises the AttributeError quoted earlier
    # ('Series' object has no attribute 'as_matrix').
    landmarks =landmarks_frame.iloc[n, 1:].as_matrix()

    修改后

    # Build a NumPy matrix from row n, columns 1 onward, instead of calling as_matrix().
    # NOTE(review): np.mat is believed to be gone from the NumPy 2.x namespace, with
    # np.asmatrix as the replacement -- confirm against the installed version.
    landmarks =np.mat(landmarks_frame.iloc[n, 1:])

    散点的x和y坐标均应为向量或列表,故对landmarks的切片结果调用tolist()方法即可

    修改前

    # Pre-fix call: the column slices of landmarks are handed to scatter as they are.
    plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')

    修改后

    # Convert each column slice to a plain Python list before handing it to scatter.
    plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')

    前面使用plt.ion()打开交互模式,则后面在plt.show()之前一定要加上plt.ioff()。这里直接加到函数里面,避免每次plt.show()之前都用plt.ioff()

    修改前

    def show_landmarks(imgs,landmarks):
        """Show an image with its landmarks drawn on top as red dots."""
        plt.imshow(imgs)
        xs = landmarks[:, 0].tolist()
        ys = landmarks[:, 1].tolist()
        # Small red '.' markers, one per landmark.
        plt.scatter(xs, ys, s=10, marker='.', c='r')
        # Pause for one second so the figure window gets drawn.
        plt.pause(1)

    修改后

    def show_landmarks(imgs,landmarks):
        """Show an image with its landmarks drawn on top as red dots.

        Interactive mode is switched off at the end, so callers do not have to
        call ``plt.ioff()`` themselves before ``plt.show()``.
        """
        plt.imshow(imgs)
        xs = landmarks[:, 0].tolist()
        ys = landmarks[:, 1].tolist()
        # Small red '.' markers, one per landmark.
        plt.scatter(xs, ys, s=10, marker='.', c='r')
        # Pause for one second so the figure window gets drawn.
        plt.pause(1)
        plt.ioff()

    网上说对于字典类型的sample,可以通过 **sample 的方式把每个键对应的值作为关键字参数传入,但这样会报错:sample的键是image,而show_landmarks的形参名是imgs,两者对不上。于是把输入写得详细一点,逐个传入参数,就成功了。

    修改前

    # Pre-fix call: unpacks the dict so that each key becomes a keyword argument name.
    show_landmarks(**sample)

    修改后

    # Pass the two values positionally, so the dict's key names no longer have to
    # match the function's parameter names.
    show_landmarks(sample['image'],sample['landmarks'])

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持站长博客。如有错误或未考虑完全的地方,望不吝赐教。

    jsjbwy
    下一篇:没有了