当前位置 主页 > 网站技术 > 代码类 >

    TensorFlow 不同大小图片的 TFRecord 存取实例

    栏目:代码类 时间:2020-01-20 15:09

    将不同大小的图片全部存入同一个 TFRecord 文件，然后读取并显示第一张。每张图片连同其高、宽、通道数一起保存，读取时据此把原始字节还原成图像。

    不多写了,直接贴代码。

    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    
    
    IMAGE_PATH = 'test/'
    # Output path for the demo records, placed in the same folder as the images.
    tfrecord_file = '{}test.tfrecord'.format(IMAGE_PATH)
    # NOTE(review): the writer is created (and its file opened for writing) as
    # soon as this module is imported; write_to_tfrecord and main use it.
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    
    
    def _int64_feature(value):
        """Return a ``tf.train.Feature`` whose Int64List holds the single ``value``."""
        int64_list = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=int64_list)
    
    def _bytes_feature(value):
        """Return a ``tf.train.Feature`` whose BytesList holds the single ``value``."""
        bytes_list = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=bytes_list)
    
    def get_image_binary(filename):
        """Load an image file and return its shape and raw pixel bytes.

        Pillow and NumPy are used instead of TensorFlow ops so that no graph has
        to be built just to decode one file.

        Args:
            filename: path of the image to read.

        Returns:
            A ``(shape, data)`` tuple. ``shape`` is an ``np.int32`` array with
            the dimensions of the ``uint8`` pixel array (typically (h, w, c) for
            colour images and (h, w) for single-band ones). ``data`` is that
            array's C-order bytes.
        """
        # The context manager closes the file handle that Image.open keeps open;
        # np.array copies the pixels so they stay valid after the close.
        with Image.open(filename) as img:
            pixels = np.array(img, np.uint8)
        shape = np.array(pixels.shape, np.int32)
        return shape, pixels.tobytes()
    
    def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
        """Serialize one image and its label as a ``tf.train.Example`` record.

        Height, width and channel count are stored next to the raw bytes, so a
        reader can reshape images of different sizes. Call once per sample to
        write several records.

        Args:
            label: integer class label.
            shape: image shape as returned by ``get_image_binary``. A 2-D shape
                (single-band image) is stored with a channel count of 1.
            binary_image: raw ``uint8`` pixel bytes of the image.
            tfrecord_file: an open writer (any object with a ``write`` method,
                e.g. ``tf.python_io.TFRecordWriter``) to receive the record. For
                backward compatibility, any other value (such as a path string)
                sends the record to the module-level ``writer``.
        """
        channels = shape[2] if len(shape) > 2 else 1
        # int() turns NumPy integer scalars into plain ints for the protobuf list.
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(int(label)),
            'h': _int64_feature(int(shape[0])),
            'w': _int64_feature(int(shape[1])),
            'c': _int64_feature(int(channels)),
            'image': _bytes_feature(binary_image),
        }))
        # A path is not reopened here: creating a new TFRecordWriter on it would
        # start the file over and lose the records already written.
        target = tfrecord_file if hasattr(tfrecord_file, 'write') else writer
        target.write(example.SerializeToString())
    
    
    def write_tfrecord(label, image_file, tfrecord_file):
        """Read ``image_file`` and write it, tagged with ``label``, as one record."""
        write_to_tfrecord(label, *get_image_binary(image_file), tfrecord_file)
    
    
    
    def _build_read_ops(path):
        """Build graph ops that yield one decoded ``(image, label)`` pair from ``path``."""
        filename_queue = tf.train.string_input_producer([path])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'h': tf.FixedLenFeature([], tf.int64),
                'w': tf.FixedLenFeature([], tf.int64),
                'c': tf.FixedLenFeature([], tf.int64),
                'image': tf.FixedLenFeature([], tf.string),
            })

        # Each record carries its own dimensions, so the flat byte string can be
        # reshaped back even though the images differ in size.
        h = tf.cast(features['h'], tf.int32)
        w = tf.cast(features['w'], tf.int32)
        c = tf.cast(features['c'], tf.int32)
        image = tf.reshape(tf.decode_raw(features['image'], tf.uint8), [h, w, c])

        label = tf.reshape(tf.cast(features['label'], tf.int32), [1])
        return image, label


    def main():
        """Write the two sample images to ``tfrecord_file``, then show the first record.

        Only a single example is read. ``tf.train.batch`` needs fully defined
        shapes unless ``dynamic_pad=True``, so batching these variable-size images
        would require resizing or padding them first.
        """
        # Example class labels for a.jpg and b.jpg respectively.
        labels = [1, 2]
        image_files = [IMAGE_PATH + 'a.jpg', IMAGE_PATH + 'b.jpg']

        # Close the shared writer even if an image fails to load, so the records
        # written so far are flushed before they are read back.
        try:
            for label, image_file in zip(labels, image_files):
                write_tfrecord(label, image_file, tfrecord_file)
        finally:
            writer.close()

        image_op, label_op = _build_read_ops(tfrecord_file)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                image, label = sess.run([image_op, label_op])
            finally:
                # Stop the queue-runner threads even if sess.run raises.
                coord.request_stop()
                coord.join(threads)

        print(label)

        # plt.imshow rejects (h, w, 1) arrays; drop the channel axis for them.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        plt.figure()
        plt.imshow(image)
        plt.show()
    
    
    if __name__ == '__main__':
        # Run the write-then-read demo only when executed as a script.
        main()