当前位置 主页 > 网站技术 > 代码类 >

    tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用

    栏目:代码类 时间:2020-01-20 12:08

    1.创建tfrecord

    tfrecord支持写入三种类型的数据:bytes(字符串)、int64、float32,以列表的形式分别通过tf.train.BytesList、tf.train.Int64List、tf.train.FloatList写入tf.train.Feature,如下所示:

    tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()])) # feature is usually a multi-dimensional array; store its raw bytes (tobytes() is the non-deprecated name of tostring())
    tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) # the raw bytes lose the array's shape, so write the shape as well
    tf.train.Feature(float_list=tf.train.FloatList(value=[label])) # a scalar label is wrapped in a list; a sequence label (e.g. one-hot) can be passed directly

    通过上述操作,以dict的形式把要写入的数据汇总,并构建tf.train.Features,然后构建tf.train.Example,如下:

    def get_tfrecords_example(feature, label):
        """Build a ``tf.train.Example`` holding one sample.

        Args:
            feature: array-like object with ``tobytes()`` and ``shape``; its raw
                bytes are stored under ``'feature'`` and its shape under
                ``'shape'`` (raw bytes carry no shape information).
            label: iterable of floats stored under ``'label'``.

        Returns:
            A ``tf.train.Example`` with the features ``'feature'``, ``'shape'``
            and ``'label'``.
        """
        tfrecords_features = {
            # tobytes() yields the same bytes as the deprecated tostring().
            # NOTE: the dtype is not recorded; a reader must know it to decode.
            'feature': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[feature.tobytes()])),
            'shape': tf.train.Feature(
                int64_list=tf.train.Int64List(value=list(feature.shape))),
            'label': tf.train.Feature(
                float_list=tf.train.FloatList(value=label)),
        }
        return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

    把创建好的tf.train.Example序列化后,即可通过tf.python_io.TFRecordWriter写入tfrecord文件,如下:

    # Create a tfrecord writer for 'xxx.tfrecord'; leaving the `with` block
    # closes the writer, even if building or writing the example raises.
    with tf.python_io.TFRecordWriter('xxx.tfrecord') as tfrecord_wrt:
        exmp = get_tfrecords_example(feats[inx], labels[inx])  # pack one sample into an Example
        exmp_serial = exmp.SerializeToString()  # serialize the Example
        tfrecord_wrt.write(exmp_serial)  # append it to the tfrecord file

    代码汇总:

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
     
    # Load MNIST from/into "MNIST_data/"; one_hot=True requests one-hot label
    # vectors, which get_tfrecords_example stores as a float list.
    # NOTE(review): read_data_sets comes from tensorflow.contrib (TF 1.x only) — confirm the TF version.
    mnist = read_data_sets("MNIST_data/", one_hot=True)
    # Pack one sample into a tf.train.Example
    def get_tfrecords_example(feature, label):
        """Build a ``tf.train.Example`` holding one sample.

        Args:
            feature: array-like object with ``tobytes()`` and ``shape``; its raw
                bytes are stored under ``'feature'`` and its shape under
                ``'shape'`` (raw bytes carry no shape information).
            label: iterable of floats stored under ``'label'`` (with the
                ``one_hot=True`` loader used in this script, presumably a
                one-hot vector — confirm for other callers).

        Returns:
            A ``tf.train.Example`` with the features ``'feature'``, ``'shape'``
            and ``'label'``.
        """
        tfrecords_features = {
            # tobytes() yields the same bytes as the deprecated tostring().
            # NOTE: the dtype is not recorded; a reader must know it to decode.
            'feature': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[feature.tobytes()])),
            'shape': tf.train.Feature(
                int64_list=tf.train.Int64List(value=list(feature.shape))),
            'label': tf.train.Feature(
                float_list=tf.train.FloatList(value=label)),
        }
        return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
    # Write all samples into one tfrecord file
    def make_tfrecord(data, outf_nm='mnist-train'):
        """Serialize every sample of ``data`` into ``<outf_nm>.tfrecord``.

        Args:
            data: a ``(feats, labels)`` pair of sequences, assumed to have equal
                length; each ``(feat, label)`` pair is converted with
                ``get_tfrecords_example``.
            outf_nm: output file name without the ``.tfrecord`` suffix.
        """
        feats, labels = data
        outf_nm += '.tfrecord'
        # The `with` block closes the writer even if an example fails to
        # build or serialize, so the file handle is never leaked.
        with tf.python_io.TFRecordWriter(outf_nm) as tfrecord_wrt:
            for feat, label in zip(feats, labels):
                exmp = get_tfrecords_example(feat, label)
                tfrecord_wrt.write(exmp.SerializeToString())
     
    import random
    # Shuffle the MNIST training indices once and split them 85/15 into
    # training and validation sets; the official test split is written as is.
    nDatas = len(mnist.train.labels)
    # random.shuffle mutates its argument in place, and a Python 3 range()
    # is immutable, so the indices must be materialized as a list first.
    inx_lst = list(range(nDatas))
    random.shuffle(inx_lst)
    ntrains = int(0.85*nDatas)
    train_inx, val_inx = inx_lst[:ntrains], inx_lst[ntrains:]

    # make training set
    data = ([mnist.train.images[i] for i in train_inx],
            [mnist.train.labels[i] for i in train_inx])
    make_tfrecord(data, outf_nm='mnist-train')

    # make validation set
    data = ([mnist.train.images[i] for i in val_inx],
            [mnist.train.labels[i] for i in val_inx])
    make_tfrecord(data, outf_nm='mnist-val')

    # make test set
    data = (mnist.test.images, mnist.test.labels)
    make_tfrecord(data, outf_nm='mnist-test')