
    TensorFlow basics: reading variable-length data in batches with TFRecordDataset

    Category: Code | Date: 2020-01-20 12:09

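    The previous article, "TensorFlow basics: using tfrecord and tf.data.TFRecordDataset", showed how to read a tfrecord file in batches with tf.data.TFRecordDataset, using the dataset's batch method. That only works when all records have the same length. Variable-length records are common with speech, video, and NLP data. For such records batch cannot be used directly, because every element stacked into a batch must have the same shape. There are two ways around this: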

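    1. Pad every record to a common length before writing it to the tfrecord file. The drawback is storage: if most records are much shorter than the maximum length, much of the file is padding.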

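    2. Use the dataset's padded_batch method. Its padded_shapes argument gives the shape to which each component of a record is padded:
    - a scalar component uses [];
    - a list component uses [mx_length];
    - an n-dimensional array component uses [d1, ..., dn].
    For example, if each record consists of a scalar, a list, and an array, in that order, then padded_shapes=([], [mx_length], [d1, ..., dn]). The method's signature is: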

    padded_batch(
     batch_size,
     padded_shapes,
     padding_values=None  # defaults to each data type's default value; can usually be omitted
    )
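To make the padded_shapes rules concrete, here is a pure-Python sketch of what padded_batch does to one batch. Every component of every record is right-padded with zeros to its padded shape, and the padded records are then grouped component by component. It is an illustration only, not TensorFlow's implementation. It assumes nested Python lists instead of tensors, a single padding value, and fixed padded shapes. The function names zeros, pad_to_shape and padded_batch_sketch are hypothetical.

```python
def zeros(shape, pad_value=0):
    """Build a nested list of the given shape filled with pad_value."""
    if not shape:
        return pad_value
    return [zeros(shape[1:], pad_value) for _ in range(shape[0])]


def pad_to_shape(value, shape, pad_value=0):
    """Right-pad a (nested) list to `shape`; shape [] means a scalar, left as is."""
    if not shape:
        return value
    if len(value) > shape[0]:
        raise ValueError("component is larger than its padded shape")
    # Pad each existing item along the inner dimensions first...
    padded = [pad_to_shape(item, shape[1:], pad_value) for item in value]
    # ...then append all-padding items until the outer dimension is reached.
    padded += [zeros(shape[1:], pad_value) for _ in range(shape[0] - len(value))]
    return padded


def padded_batch_sketch(records, padded_shapes, pad_value=0):
    """Pad each record component-wise, then group components across records."""
    return tuple(
        [pad_to_shape(record[i], shape, pad_value) for record in records]
        for i, shape in enumerate(padded_shapes)
    )


# Each record is (scalar, list, 2-D array), as with padded_shapes=([], [mx_length], [d1, d2]).
records = [
    (3, [1, 2, 3], [[1, 2], [3, 4]]),
    (1, [5], [[6, 7]]),
]
scalars, lists, arrays = padded_batch_sketch(records, ([], [4], [3, 2]))
print(scalars)  # [3, 1]
print(lists)    # [[1, 2, 3, 0], [5, 0, 0, 0]]
print(arrays)   # [[[1, 2], [3, 4], [0, 0]], [[6, 7], [0, 0], [0, 0]]]
```

The real padded_batch also accepts an unknown dimension in padded_shapes (None in a TensorShape, or -1). That dimension is padded to the longest element in each batch. The sketch above only models fixed padded shapes.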

    The example below uses MNIST. Before writing MNIST to tfrecord, the data is modified so that the flattened images no longer all have the same length:

    import random

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

    mnist = read_data_sets("MNIST_data/", one_hot=True)


    def get_tfrecords_example(feature, label):
        tfrecords_features = {}
        feat_shape = feature.shape
        tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
        # store the (variable) shape so it is available when parsing
        tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
        tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
        return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))


    def make_tfrecord(data, outf_nm='mnist-train'):
        feats, labels = data
        outf_nm += '.tfrecord'
        tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
        ndatas = len(labels)
        print(feats[0].dtype, feats[0].shape, ndatas)
        assert len(labels[0]) > 1  # labels must be one-hot vectors
        for inx in range(ndatas):
            ed = random.randint(0, 3)  # randomly drop 0-3 trailing points so lengths differ
            feat = feats[inx]
            # feat[:-ed] would be empty when ed == 0, so slice by explicit length
            exmp = get_tfrecords_example(feat[:len(feat) - ed], labels[inx])
            exmp_serial = exmp.SerializeToString()
            tfrecord_wrt.write(exmp_serial)
        tfrecord_wrt.close()


    nDatas = len(mnist.train.labels)
    inx_lst = list(range(nDatas))  # a list, so it can be shuffled in Python 3
    random.shuffle(inx_lst)
    ntrains = int(0.85 * nDatas)

    # make training set
    data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
            [mnist.train.labels[i] for i in inx_lst[:ntrains]])
    make_tfrecord(data, outf_nm='mnist-train')

    # make validation set
    data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
            [mnist.train.labels[i] for i in inx_lst[ntrains:]])
    make_tfrecord(data, outf_nm='mnist-val')

    # make test set
    data = (mnist.test.images, mnist.test.labels)
    make_tfrecord(data, outf_nm='mnist-test')
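A note on the random truncation step. In Python, a slice such as x[:-n] with n == 0 is x[:0], which is empty rather than the whole sequence, because -0 == 0. Slicing by an explicit length avoids this. The following is a minimal sketch of a truncation helper that is safe for n == 0. The names drop_tail and random_truncate are hypothetical and not part of the original code.

```python
import random


def drop_tail(values, n_drop):
    """Return values without its last n_drop items (n_drop may be 0)."""
    if not 0 <= n_drop <= len(values):
        raise ValueError("n_drop must be between 0 and len(values)")
    # values[:-n_drop] would return an empty slice when n_drop == 0,
    # so slice up to an explicit end index instead.
    return values[:len(values) - n_drop]


def random_truncate(values, max_drop=3, rng=random):
    """Drop a random number (0..max_drop) of trailing items."""
    return drop_tail(values, rng.randint(0, max_drop))


pixels = list(range(10))
print(pixels[:-0])           # [] (the pitfall)
print(drop_tail(pixels, 0))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(drop_tail(pixels, 3))  # [0, 1, 2, 3, 4, 5, 6]
```

The same slicing works unchanged on 1-D NumPy arrays such as the flattened MNIST images.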

    To load batches with a dataset, parse the variable-length field with tf.VarLenFeature(dtype) instead of tf.FixedLenFeature([], dtype). Then convert the parsed value with tf.sparse_tensor_to_dense, as follows:

    import tensorflow as tf

    train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

    def parse_exmp(serial_exmp):
        feats = tf.parse_single_example(serial_exmp, features={
            'feature': tf.VarLenFeature(tf.float32),
            'label': tf.FixedLenFeature([10], tf.float32),
            'shape': tf.FixedLenFeature([], tf.int64)})
        # VarLenFeature yields a SparseTensor; convert it to a dense tensor
        image = tf.sparse_tensor_to_dense(feats['feature'])
        # reshape label to [2, 5] to show how array components are padded
        label = tf.reshape(feats['label'], [2, 5])
        shape = tf.cast(feats['shape'], tf.int32)
        return image, label, shape

    def get_dataset(fname):
        dataset = tf.data.TFRecordDataset(fname)
        return dataset.map(parse_exmp)  # apply padded_batch to the result if padding is needed

    epochs = 16
    batch_size = 50
    # pad image to length 784 and label to [3, 5]; shape is a scalar, so its padded shape is []
    padded_shapes = ([784], [3, 5], [])
    # training dataset
    dataset_train = get_dataset(train_f)
    dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)