栏目分类:
子分类:
返回
终身学习网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
终身学习网 > IT > 软件开发 > 后端开发 > Python

keras实现mnist识别

Python 更新时间:发布时间: 百科书网 趣学号
1.下载数据集
# 包引入
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import matplotlib.pyplot as plt
# 1.下载数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

2.案例图像显示
# 案例图像显示
plt.figure(figsize=(20,10))  # 绘制(20,10)图像
for i in range(20):
    plt.subplot(5,10,i+1)  # 5*10子图
    plt.xticks([])  # x,y去标签
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(train_labels[i])
plt.show()

3.预处理
# 归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
# 调整格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
4.CNN建模
# CNN建模
model = Sequential()
# 第一层
# 卷积层1: 卷积核3*3, 共32个
model.add(Conv2D(32,(3,3), activation='relu', input_shape=(28,28,1)))
# 池化层1: 采样2*2
model.add(MaxPooling2D((2,2)))

# 第二层
# 卷积层2: 卷积核3*3, 共64个
model.add(Conv2D(64,(3,3), activation='relu'))
# 池化层1: 采样2*2
model.add(MaxPooling2D((2, 2)))

# Flatten层: 连接卷积层与全连接层
model.add(Flatten())
# 全连接层
model.add(Dense(64, activation='relu'))
# 输出层
model.add(Dense(10))

# 打印网络结构
model.summary()
5.编译模型
# 编译模型
# 优化器、损失函数、矩阵
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

6.训练模型
# 训练模型
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))

7.预测
# 预测
pre_outputs = pre = model.predict(test_images)

活动地址:CSDN21天学习挑战赛

转载请注明:文章转载自 www.051e.com
本文地址:http://www.051e.com/it/1033814.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 ©2023-2025 051e.com

ICP备案号:京ICP备12030808号