当前位置 博文首页 > hallobike的博客:使用tf.data数据转换来训练MNIST数据集

    hallobike的博客:使用tf.data数据转换来训练MNIST数据集

    作者:[db:作者] 时间:2021-09-14 22:00

    以MNIST数据集为例来训练模型

    # -*- coding: UTF-8 -*-
    """
    Author: LGD
    FileName: fashion_mnist_tfdataset
    DateTime: 2020/11/26 09:04 
    SoftWare: PyCharm
    """
    import tensorflow as tf
    
    print('Tensorflow version: {}'.format(tf.__version__))
    
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    
    # 数据归一化
    train_images = train_images / 255
    test_images = test_images / 255
    
    # 建立train_images的Dataset
    ds_train_img = tf.data.Dataset.from_tensor_slices(train_images)
    print(ds_train_img)
    ds_train_label = tf.data.Dataset.from_tensor_slices(train_labels)
    print(ds_train_label)
    
    # 使用zip将数据合并到一起
    ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_label))
    print(ds_train)
    
    # 对数据做变换,取出10000组数据乱序,循环,分批次,每批次数据量为64
    ds_train = ds_train.shuffle(10000).repeat().batch(64)
    
    # 建立模型
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    # 编译模型
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    # # 训练
    # steps_per_epochs = train_images.shape[0] // 64  # 每次迭代64张图片,每个epoch迭代的步数
    # model.fit(
    #     ds_train,
    #     epochs=5,
    #     steps_per_epoch=steps_per_epochs
    # )
    
    # 建立test_images的Dataset
    ds_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    ds_test = ds_test.batch(64)
    
    # 训练
    steps_per_epochs = train_images.shape[0] // 64  # 每次迭代64张图片,每个epoch迭代的步数
    model.fit(
        ds_train,
        epochs=5,
        steps_per_epoch=steps_per_epochs,
        validation_data=ds_test,
        validation_steps=10000//64  # 由于有循环,必须要有step它才知道什么时候打印一下验证准确率。
    )
    

    获取MNIST数据集,可以直接是在代码加载里下载,也可以关注下列公众号加读者微信,分享百度网盘链接。
    在这里插入图片描述

    cs
    下一篇:没有了