威望0
积分7946
贡献0
在线时间763 小时
UID1
注册时间2021-4-14
最后登录2024-11-21
管理员
- UID
- 1
- 威望
- 0
- 积分
- 7946
- 贡献
- 0
- 注册时间
- 2021-4-14
- 最后登录
- 2024-11-21
- 在线时间
- 763 小时
|
[mw_shl_code=python,true]import numpy as np
# 从keras的datasets导入数据集
from keras.datasets import mnist
# 全连接层,卷积层,池化层,扁平化,随机关闭神经元
from keras.layers import Dense,Dropout,MaxPool2D,Flatten,Convolution2D
#标签格式转化
from keras.utils import np_utils
# 导入顺序结构
from keras.models import Sequential
#导入Adma优化函数
from tensorflow.keras.optimizers import Adam
# 载入tf自带数据,得到训练集的数据和测试集的数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
'''x_train:60000X28X28;x_test:10000X28X28;y_train:1X60000;y_test:1X10000'''
'''
-1是自动匹配数据的个数60000,长宽深度分别28,28,1,再归一化
'''
x_train = x_train.reshape(-1,28,28,1) / 255.0
x_test = x_test.reshape(-1,28,28,1) / 255.0
# 转换为 one hot 格式
'''这里使用的numpy下的untils中的to_categorical方法把标签数据给分类
因为有10个数字,所以设置num_classes为10,也就是10个类'''
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)
# 创建模型
model = Sequential()
#定义卷积核
model.add(Convolution2D(
input_shape=(28,28,1),#输入平面大小
filters=32,#卷积核大小
kernel_size=5,#卷积窗口大小为5
strides=1,#步长为1
padding='same',#边缘补0是same,不补零是valid
activation='relu'#激活函数
))
#定义第一个池化层
model.add(MaxPool2D(
pool_size=2,#池化窗口大小
strides=2,#步长为2
padding='same'
))
#第二个卷积层
model.add(Convolution2D(64,5,strides=1,padding='same',activation='relu'))
#第二个池化层
model.add(MaxPool2D(2,2,'same'))
#将输出扁平化
model.add(Flatten())
#第一个全连接层
model.add(Dense(1024,activation='relu'))
#Drop,百分之50关闭神经元
model.add(Dropout(0.5))
#第二个全连接层
model.add(Dense(10,activation='softmax'))
# 定义优化器
adam = Adam(learning_rate=1e-4)
# 定义优化器以及loss function即损失函数,训练过程中计算准确率
model.compile(
optimizer=adam, # 使用的优化函数
loss='categorical_crossentropy',
metrics=['accuracy'] # 计算准确率
)
# 训练模型,使用训练集
'''batch_size=64表示每次会训练64张图片,把60000张图片训练完为1个周期
epochs是迭代周期,所以这里设置要训练完10个周期'''
model.fit(x_train, y_train, batch_size=64, epochs=10)
# 评估模型,使用测试集
loss, accuracy = model.evaluate(x_test, y_test)
# 打印loss和accuracy的值
print('loss:', loss)
print('accuracy:', accuracy)
[/mw_shl_code] |
|