    Using range in TensorFlow

    Category: Code  Date: 2020-01-20 21:10

    The key code first:

    i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
    inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

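    How it works:

    The first line creates a queue holding the integers 0 through NUM_EXPOCHES-1 (despite its name, NUM_EXPOCHES here is the number of batches per pass over the data, not the number of epochs). If num_epochs is given, each integer is produced exactly num_epochs times; otherwise the sequence repeats indefinitely. shuffle controls whether the order is randomized; shuffle=False here means the integers are stored in order from 0 to NUM_EXPOCHES-1. When the graph runs, each thread dequeues one element, call it i, and the second line slices the matching segment out of array as one batch. For example, with NUM_EXPOCHES=3 and num_epochs=2, the queue contents look like this: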

    0,1,2,0,1,2

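    The queue holds only 6 elements, so training can produce only 6 batches; after 6 iterations the queue is exhausted and training ends.

    If num_epochs is not specified, the queue contents look like this: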

    0,1,2,0,1,2,0,1,2,0,1,2...

    The queue keeps generating elements, so training can draw an unlimited number of batches and you have to decide yourself when to stop.
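The index sequence described above can be imitated in plain Python, which makes the effect of num_epochs and shuffle easy to check without building a graph. This is only a sketch: `index_stream` and its `seed` parameter are hypothetical names introduced here, not part of TensorFlow, and the shuffling is assumed to happen separately within each pass.

```python
import itertools
import random


def index_stream(range_size, num_epochs=None, shuffle=False, seed=None):
    """Yield the indices a range-style queue would hand out (plain-Python stand-in)."""
    rng = random.Random(seed)
    # num_epochs=None means "repeat forever", mirroring the unbounded queue.
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for _ in epochs:
        indices = list(range(range_size))
        if shuffle:
            rng.shuffle(indices)  # reorder within the current pass only
        for i in indices:
            yield i


print(list(index_stream(3, num_epochs=2)))         # [0, 1, 2, 0, 1, 2]
print(list(itertools.islice(index_stream(3), 8)))  # [0, 1, 2, 0, 1, 2, 0, 1]
```

With num_epochs left unset the generator never ends, so the caller has to bound it (here with itertools.islice), just as the training loop has to decide when to stop.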

    A complete demo follows.

    Contents of the data file test.txt:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    

    Contents of main.py:

    import codecs

    import tensorflow as tf  # TensorFlow 1.x API (queue runners, tf.Session)

    BATCH_SIZE = 6
    NUM_EXPOCHES = 5  # number of batches per pass over the data


    def input_producer():
        # Read the file into a real list of stripped strings; a lazy map
        # object (Python 3) cannot be converted to a tensor by tf.slice.
        with codecs.open("test.txt") as f:
            array = [line.strip() for line in f]
        # Dequeue the next batch index: 0 .. NUM_EXPOCHES-1, one pass only.
        i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
        # Cut the i-th run of BATCH_SIZE elements out of the data.
        inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
        return inputs


    class Inputs(object):
        def __init__(self):
            self.inputs = input_producer()


    def main(*args, **kwargs):
        inputs = Inputs()
        # The epoch counter behind num_epochs is a local variable,
        # so local variables must be initialized too.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session()
        # Initialize before starting the queue runners, which read the epoch counter.
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            index = 0
            while not coord.should_stop() and index < 10:
                datalines = sess.run(inputs.inputs)
                index += 1
                print("step: %d, batch data: %s" % (index, str(datalines)))
        except tf.errors.OutOfRangeError:
            print("Done training:-------Epoch limit reached")
        except KeyboardInterrupt:
            print("keyboard interrupt detected, stop training")
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()


    if __name__ == "__main__":
        main()


    Output:

    step: 1, batch data: ['1' '2' '3' '4' '5' '6']
    step: 2, batch data: ['7' '8' '9' '10' '11' '12']
    step: 3, batch data: ['13' '14' '15' '16' '17' '18']
    step: 4, batch data: ['19' '20' '21' '22' '23' '24']
    step: 5, batch data: ['25' '26' '27' '28' '29' '30']
    Done training:-------Epoch limit reached

    If the num_epochs=1 argument is removed from range_input_producer, the loop runs until its own limit of 10 steps and the output is:

    step: 1, batch data: ['1' '2' '3' '4' '5' '6']
    step: 2, batch data: ['7' '8' '9' '10' '11' '12']
    step: 3, batch data: ['13' '14' '15' '16' '17' '18']
    step: 4, batch data: ['19' '20' '21' '22' '23' '24']
    step: 5, batch data: ['25' '26' '27' '28' '29' '30']
    step: 6, batch data: ['1' '2' '3' '4' '5' '6']
    step: 7, batch data: ['7' '8' '9' '10' '11' '12']
    step: 8, batch data: ['13' '14' '15' '16' '17' '18']
    step: 9, batch data: ['19' '20' '21' '22' '23' '24']
    step: 10, batch data: ['25' '26' '27' '28' '29' '30']
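The batch boundaries in both outputs follow directly from the slice arithmetic, which can be checked in plain Python. The sketch below assumes the data is the 35 stripped lines of test.txt; `batch_at` is a hypothetical helper, not a TensorFlow function, and the wrap-around with `%` stands in for the queue repeating its indices when num_epochs is not set.

```python
BATCH_SIZE = 6
NUM_EXPOCHES = 5  # same constant names as main.py

# Stand-in for the stripped lines of test.txt: '1' .. '35'.
array = [str(n) for n in range(1, 36)]


def batch_at(step):
    """Return the batch produced at a 0-based step when the index wraps around."""
    i = step % NUM_EXPOCHES  # index the queue would yield at this step
    return array[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]


for step in range(10):
    print("step: %d, batch data: %s" % (step + 1, batch_at(step)))

# Elements that no batch ever covers, because 5 * 6 = 30 < 35.
unused = array[NUM_EXPOCHES * BATCH_SIZE:]
print(unused)  # ['31', '32', '33', '34', '35']
```

Because NUM_EXPOCHES * BATCH_SIZE is 30, the last five lines of test.txt never appear in any batch; covering them would require a range size and batch size whose product reaches the data length.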