3.5 MXNet
3.5 mxnet
mxnet是亚马逊开发的深度学习库,它拥有类似于theano和tensorflow的数据流图,并且可以在常见的硬件平台上运行。mxnet还提供了r、c++、scala等语言的接口。我们以解决经典的手写数字识别的问题为例,介绍mxnet的基本使用方法,代码路径为:
https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-mxnet.ipynb
1. 加载相关库
加载处理经典的手写数字识别问题相关的python库:
import mxnet as mx
import logging
2. 加载数据集
mxnet中也针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程:
mnist = mx.test_utils.get_mnist()
batch_size = 128
train_iter = mx.io.ndarrayiter(mnist['train_data'], mnist['train_label'],
batch_size, shuffle=true)
val_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'],
batch_size)
3. 定义网络结构
使用与keras几乎完全相同的网络结构,只不过省略了dropout层。
data = mx.sym.var('data')
data = mx.sym.flatten(data=data)
#全连接
fc1 = mx.sym.fullyconnected(data=data, num_hidden = 512)
act1 = mx.sym.activation(data=fc1, act_type="relu")
fc2 = mx.sym.fullyconnected(data=act1, num_hidden = 512)
act2 = mx.sym.activation(data=fc2, act_type="relu")
fc3 = mx.sym.fullyconnected(data=act2, num_hidden=10)
# softmax输出
mlp = mx.sym.softmaxoutput(data=fc3, name='softmax')
mlp_model = mx.mod.module(symbol=mlp, context=ctx)
可视化网络结构,如图3-8所示,值得一提的是mxnet自带的可视化工具非常便于使用。
import matplotlib.pyplot as plt
mx.viz.plot_network(mlp).view()
图3-8 mxnet处理mnist的网络结构图
4. 定义损失函数和优化器
损失函数使用交叉熵crossentropyloss,优化器使用sgd。
5. 训练与验证
mxnet的训练和验证过程是分开的,训练阶段加载优化器的配置,可以指定每训练100个批次,打印中间结果。
mlp_model.fit(train_iter,
eval_data=val_iter,
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
batch_end_callback = mx.callback.speedometer(batch_size, 100),
num_epoch=20)
经过20轮的训练后,在测试集上验证准确度。
test_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.accuracy()
mlp_model.score(test_iter, acc)
print(acc)
最终在测试集上准确度达到了97.62%。
info:root:epoch[19] train-accuracy=0.996438
info:root:epoch[19] time cost=4.017 info:root:epoch[19] validation-accuracy=0.976167
evalmetric: {'accuracy': 0.9761669303797469}
mxnet保存的模型文件后缀为parms。
mlp_model.save_checkpoint('models/mxnet.parms',20)
mxnet是亚马逊开发的深度学习库,它拥有类似于theano和tensorflow的数据流图,并且可以在常见的硬件平台上运行。mxnet还提供了r、c++、scala等语言的接口。我们以解决经典的手写数字识别的问题为例,介绍mxnet的基本使用方法,代码路径为:
https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-mxnet.ipynb
1. 加载相关库
加载处理经典的手写数字识别问题相关的python库:
import mxnet as mx
import logging
2. 加载数据集
mxnet中也针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程:
mnist = mx.test_utils.get_mnist()
batch_size = 128
train_iter = mx.io.ndarrayiter(mnist['train_data'], mnist['train_label'],
batch_size, shuffle=true)
val_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'],
batch_size)
3. 定义网络结构
使用与keras几乎完全相同的网络结构,只不过省略了dropout层。
data = mx.sym.var('data')
data = mx.sym.flatten(data=data)
#全连接
fc1 = mx.sym.fullyconnected(data=data, num_hidden = 512)
act1 = mx.sym.activation(data=fc1, act_type="relu")
fc2 = mx.sym.fullyconnected(data=act1, num_hidden = 512)
act2 = mx.sym.activation(data=fc2, act_type="relu")
fc3 = mx.sym.fullyconnected(data=act2, num_hidden=10)
# softmax输出
mlp = mx.sym.softmaxoutput(data=fc3, name='softmax')
mlp_model = mx.mod.module(symbol=mlp, context=ctx)
可视化网络结构,如图3-8所示,值得一提的是mxnet自带的可视化工具非常便于使用。
import matplotlib.pyplot as plt
mx.viz.plot_network(mlp).view()
图3-8 mxnet处理mnist的网络结构图
4. 定义损失函数和优化器
损失函数使用交叉熵crossentropyloss,优化器使用sgd。
5. 训练与验证
mxnet的训练和验证过程是分开的,训练阶段加载优化器的配置,可以指定每训练100个批次,打印中间结果。
mlp_model.fit(train_iter,
eval_data=val_iter,
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
batch_end_callback = mx.callback.speedometer(batch_size, 100),
num_epoch=20)
经过20轮的训练后,在测试集上验证准确度。
test_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.accuracy()
mlp_model.score(test_iter, acc)
print(acc)
最终在测试集上准确度达到了97.62%。
info:root:epoch[19] train-accuracy=0.996438
info:root:epoch[19] time cost=4.017 info:root:epoch[19] validation-accuracy=0.976167
evalmetric: {'accuracy': 0.9761669303797469}
mxnet保存的模型文件后缀为parms。
mlp_model.save_checkpoint('models/mxnet.parms',20)