当前位置 博文首页 > deeplearningfeng的博客:tensorflow mnist入门

    deeplearningfeng的博客:tensorflow mnist入门

    作者:[db:作者] 时间:2021-09-19 22:36

    人工神经网络

    import tensorflow as tf

    import numpy as np
    import input_data//如果出现找不到input_data module 错误,可以将python -tensorflow-mnist路径下的input_data.py复制到当前的.py文件路径下
    mnist = input_data.read_data_sets("home/mnist", one_hot=True)//home/mnist是存放下载数据集的路径,根据个人情况加以更改

    x = tf.placeholder("float", [None, 784])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))

    y = tf.nn.softmax(tf.matmul(x,W) + b)


    y_ = tf.placeholder("float", [None,10])
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)

    for i in range(1000):
    ? batch_xs, batch_ys = mnist.train.next_batch(100)
    ? sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})cs