MNIST

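The MNIST dataset

1. The downloaded dataset is split into three parts: a training set of 55,000 examples (mnist.train), a validation set of 5,000 examples (mnist.validation), and a test set of 10,000 examples (mnist.test).

2. Every MNIST example has two parts, an image and a label: mnist.train.images and mnist.train.labels.

3. mnist.train.images has shape (55000, 28*28), i.e. (55000, 784): each 28x28 image is flattened into one row, and every pixel value lies between 0 and 1.

4. mnist.train.labels has shape (55000, 10) and is one-hot encoded.

5. DataSet.next_batch(batch_size) returns a tuple holding one batch of batch_size images and their labels.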

from tensorflow.examples.tutorials.mnist import input_data
# Download the dataset into the "MNIST-data" directory
mnist = input_data.read_data_sets("MNIST-data", one_hot=True)
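
To make the one-hot labels and the batch tuple concrete, here is a minimal pure-Python sketch. It is a simplified stand-in, not the library's implementation: `one_hot`, `next_batch`, `toy_images`, and `toy_labels` are hypothetical names, the data is random, and the batch is drawn by plain random sampling.

```python
import random


def one_hot(label, num_classes=10):
    """Return a list of length num_classes with 1.0 at index `label`."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def next_batch(images, labels, batch_size, rng):
    """Pick batch_size examples at random; return (images, labels) as a tuple."""
    indices = rng.sample(range(len(images)), batch_size)
    batch_images = [images[i] for i in indices]
    batch_labels = [labels[i] for i in indices]
    return batch_images, batch_labels


# Toy stand-in for mnist.train: 6 flattened 28*28 "images" with values in [0, 1).
rng = random.Random(0)
toy_images = [[rng.random() for _ in range(28 * 28)] for _ in range(6)]
toy_labels = [one_hot(digit) for digit in [3, 1, 4, 1, 5, 9]]

batch_x, batch_y = next_batch(toy_images, toy_labels, batch_size=4, rng=rng)
print(len(batch_x), len(batch_x[0]))  # 4 784
print(len(batch_y), len(batch_y[0]))  # 4 10
```

Each row of `batch_y` contains a single 1.0, at the index of the digit it encodes.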

Network architecture

The network is built with tf.layers.

Architecture: conv -> pool -> conv -> pool -> fc -> dropout -> fc

        # Build the network
        input_ = tf.reshape(self.images_placeholder, [-1, 28, 28, 1])
        net = tf.layers.conv2d(input_, 32, 5, padding="same", activation=tf.nn.relu, name="conv1")
        net = tf.layers.max_pooling2d(net, 2, 2, name="pool1")

        net = tf.layers.conv2d(net, 64, 5, padding="same", activation=tf.nn.relu, name="conv2")
        net = tf.layers.max_pooling2d(net, 2, 2, name="pool2")

        net = tf.layers.flatten(net, name="flatten")
        net = tf.layers.dense(net, 1024, activation=tf.nn.relu, name="fc1")
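        # Note: tf.layers.dropout defaults to training=False, which makes it a no-op;
        # pass training=True (or a boolean placeholder) so dropout is active during training.
        net = tf.layers.dropout(net, rate=0.4, name="dropout")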
        logits = tf.layers.dense(net, 10, name="fc2")
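
The tensor shapes that flow through this stack can be checked by hand. The sketch below traces them in plain Python, assuming stride-1 convolutions with padding="same" and pooling with the default "valid" padding, as in the code above. The helpers `conv_same` and `max_pool` are hypothetical and only do shape arithmetic.

```python
def conv_same(shape, filters):
    """Shape after a stride-1 convolution with padding="same" (spatial size unchanged)."""
    height, width, _ = shape
    return (height, width, filters)


def max_pool(shape, pool_size, stride):
    """Shape after max pooling with "valid" padding."""
    height, width, channels = shape
    out_h = (height - pool_size) // stride + 1
    out_w = (width - pool_size) // stride + 1
    return (out_h, out_w, channels)


shape = (28, 28, 1)            # one reshaped input image
shape = conv_same(shape, 32)   # conv1 -> (28, 28, 32)
shape = max_pool(shape, 2, 2)  # pool1 -> (14, 14, 32)
shape = conv_same(shape, 64)   # conv2 -> (14, 14, 64)
shape = max_pool(shape, 2, 2)  # pool2 -> (7, 7, 64)

flat = shape[0] * shape[1] * shape[2]  # flatten -> 3136 features
fc1_params = flat * 1024 + 1024        # fc1 weights plus biases
print(shape, flat, fc1_params)         # (7, 7, 64) 3136 3212288
```

The first fully connected layer sees 3136 inputs and accounts for roughly 3.2 million parameters on its own.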

Training

This is a multi-class classification problem, so the cross-entropy loss is used.

        # Define the loss function
        loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=self.labels_placeholder,
            logits=logits)
        self.mean_loss = tf.reduce_mean(loss)
        tf.summary.scalar("loss", self.mean_loss)

        # Choose the optimizer (Adam)
        self.global_step = tf.train.create_global_step()
        self.train_op = tf.train.AdamOptimizer().minimize(self.mean_loss, global_step=self.global_step)
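
For intuition about what the loss computes, here is a small pure-Python sketch of softmax cross-entropy for one-hot labels, followed by the batch mean. It illustrates the formula only and is not TensorFlow's implementation. The function name and the toy three-class values are made up.

```python
import math


def softmax_cross_entropy(labels, logits):
    """Cross-entropy between a one-hot label vector and softmax(logits)."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(y * lp for y, lp in zip(labels, log_probs))


# Two toy examples with 3 classes each (MNIST would use 10).
batch_labels = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
batch_logits = [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]

losses = [softmax_cross_entropy(y, z) for y, z in zip(batch_labels, batch_logits)]
mean_loss = sum(losses) / len(losses)  # plays the role of tf.reduce_mean(loss)
print(losses, mean_loss)
```

With all-equal logits the prediction is uniform, so the second example's loss is log(3). For a 10-class network with uninformative logits the same reasoning gives log(10), about 2.30, which is a useful sanity check for the loss at the start of training.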

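Model evaluation

The model is evaluated with tf.metrics.accuracy.

        # Accuracy: the evaluation metric.
        # Note: tf.metrics.accuracy returns an (accuracy, update_op) pair and creates
        # local variables, so run tf.local_variables_initializer() before evaluating.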
        self.predict_label = tf.argmax(logits, axis=1)
        self.accuracy = tf.metrics.accuracy(
            labels=tf.argmax(self.labels_placeholder, axis=1),
            predictions=self.predict_label)
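
tf.metrics.accuracy is a streaming metric: it accumulates a running total of correct predictions and a count of examples across batches. The following pure-Python sketch mimics that bookkeeping under simplified assumptions. `StreamingAccuracy` and `argmax` are hypothetical helpers, and the toy labels and logits are made up.

```python
def argmax(values):
    """Index of the largest element (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


class StreamingAccuracy:
    """Running accuracy over batches, kept as a total/count pair."""

    def __init__(self):
        self.total = 0  # correct predictions seen so far
        self.count = 0  # examples seen so far

    def update(self, one_hot_labels, logits):
        """Fold one batch into the running totals and return the accuracy so far."""
        for label_vec, logit_vec in zip(one_hot_labels, logits):
            self.total += int(argmax(label_vec) == argmax(logit_vec))
            self.count += 1
        return self.total / self.count


metric = StreamingAccuracy()
acc1 = metric.update([[1, 0, 0], [0, 1, 0]], [[2.0, 0.1, 0.3], [0.2, 0.1, 0.9]])
acc2 = metric.update([[0, 0, 1], [0, 0, 1]], [[0.1, 0.2, 0.7], [0.0, 0.3, 1.5]])
print(acc1, acc2)  # 0.5 0.75
```

The second value covers all four examples seen so far rather than only the second batch, which is why the running totals need to be reset between independent evaluation runs.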
