当前位置 主页 > 网站技术 > 代码类 >

    浅谈tensorflow中Dataset图片的批量读取及维度的操作详解

    栏目:代码类 时间:2020-01-20 18:07

    读取单张图片，得到三维张量 (h, w, c):

    import tensorflow as tf
     
    import glob
    import os
     
     
    def _parse_function(filename):
      """Load one image file and center-crop or zero-pad it to 200x200.

      Args:
        filename: path of the image (Python str or scalar string tensor).

      Returns:
        The cropped/padded image tensor. For a single-frame image such as a
        JPEG this is 3-D (height, width, channels); no batch axis is added.
      """
      decoded = tf.image.decode_image(tf.read_file(filename))  # e.g. (375, 500, 3)
      return tf.image.resize_image_with_crop_or_pad(decoded, 200, 200)
     
     
     
     
    # Build the graph for one sample image; without this `img` is undefined
    # and sess.run(img) raises NameError.
    img = _parse_function('../pascal/VOCdevkit/VOC2012/JPEGImages/2007_000068.jpg')

    with tf.Session() as sess:
      # Single 3-D image: expected (200, 200, channels), e.g. (200, 200, 3) for RGB.
      print(sess.run(img).shape)

    读取图片并扩展出 batch 维，得到四维张量 (b, h, w, c):

    import tensorflow as tf
     
    import glob
    import os
     
    # Dataset: batch-reading images (this variant returns a 4-D batch of one).

    def _parse_function(filename):
      """Read one image file and return it as a batch of one, cropped/padded to 200x200.

      Args:
        filename: path of the image (Python str or scalar string tensor).

      Returns:
        A uint8 tensor of shape (1, 200, 200, channels).
      """
      image_string = tf.read_file(filename)
      image_decoded = tf.image.decode_image(image_string)  # e.g. (375, 500, 3)
      # decode_image may return 4-D frames for GIFs, so in TF 1.x its static
      # rank is unknown. resize_image_with_crop_or_pad treats an unknown-rank
      # input as a single 3-D image and adds its own batch axis; together with
      # the expand_dims below that produces a 5-D tensor at run time, which
      # fails with "Input shape axis 0 must equal 4, got shape [5]".
      # Pinning the static rank to 3 makes the expanded tensor a known 4-D
      # batch that the resize op handles directly.
      image_decoded.set_shape([None, None, None])

      image_decoded = tf.expand_dims(image_decoded, axis=0)

      image_resized = tf.image.resize_image_with_crop_or_pad(image_decoded, 200, 200)
      return image_resized
     
     
     
    img = _parse_function(
      '../pascal/VOCdevkit/VOC2012/JPEGImages/2007_000068.jpg')

    # Note: an input whose static rank is already 4 works as-is, e.g.
    #   tf.image.resize_image_with_crop_or_pad(tf.truncated_normal((1, 220, 300, 3)) * 10, 200, 200)

    with tf.Session() as sess:
      # Expected shape: (1, 200, 200, channels).
      # In TF 1.x this raises
      #   InvalidArgumentError: Input shape axis 0 must equal 4, got shape [5]
      # if _parse_function expands the raw decode_image output: decode_image
      # has unknown static rank, so resize_image_with_crop_or_pad assumes a
      # 3-D image and adds a second batch axis, giving a 5-D tensor at run
      # time. Giving the decoded image a static rank of 3 (set_shape) before
      # expand_dims avoids this.
      print(sess.run(img).shape)

    Dataset 的操作:

    import tensorflow as tf
     
    import glob
    import os
     
    # Batch-reading images with tf.data.Dataset:
    #   1. Build a list of image file names and wrap it in a Dataset with
    #      from_tensor_slices().
    #   2. Map a parsing function over the Dataset that reads, decodes and
    #      resizes one image. Each call handles a single image, so the decoded
    #      tensor (from decode_image; read_file only returns the raw bytes) is
    #      3-D [height, width, channels] and needs no batch axis. batch() later
    #      stacks these, so get_next() yields [batch, height, width, channels].
    #   3. Configure shuffle, batch and repeat.
    #   4. iterator = dataset.make_one_shot_iterator() creates the iterator.
    #   5. iterator.get_next() returns one batch of images per run.

    def _parse_function(filename):
      """Read, decode and crop/pad one image to 200x200 for use in Dataset.map.

      Args:
        filename: scalar string tensor holding one image path (one element of
          the file-name Dataset).

      Returns:
        A uint8 tensor of shape (200, 200, 3) for single-frame images
        (BMP/JPEG/PNG); decode_image returns 4-D frames for GIFs instead.
      """
      image_string = tf.read_file(filename)
      # channels=3 forces every image to RGB. Without it decode_image keeps the
      # file's native channel count (e.g. 1 for a grayscale JPEG), and
      # Dataset.batch() fails when images in one batch differ in channels.
      image_decoded = tf.image.decode_image(image_string, channels=3)  # e.g. (375, 500, 3)
      # resize_image_with_crop_or_pad accepts a 3-D [height, width, channels]
      # or 4-D [batch, height, width, channels] tensor; here it is one image.
      image_resized = tf.image.resize_image_with_crop_or_pad(image_decoded, 200, 200)
      return image_resized
     
    image_dir = '../pascal/VOCdevkit/VOC2012/JPEGImages'
    filenames = glob.glob(os.path.join(image_dir, '*.jpg'))
    # An empty list would become an empty float32 tensor, and map() would then
    # fail with an obscure dtype error in read_file; fail early and clearly.
    if not filenames:
      raise FileNotFoundError('no .jpg files found in {}'.format(image_dir))

    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    dataset = dataset.map(_parse_function)  # one image per element

    # Shuffle within a 10-element buffer, group images into batches of 2, and
    # cycle through the file list 10 times.
    dataset = dataset.shuffle(10).batch(2).repeat(10)
    iterator = dataset.make_one_shot_iterator()

    img = iterator.get_next()  # [batch, height, width, channels]

    with tf.Session() as sess:
      for _ in range(10):
        # batch(2) -> (2, 200, 200, channels); the last batch of an epoch holds
        # a single image when the number of files is odd.
        print(sess.run(img).shape)