3.5 mxnet
    mxnet是亚马逊开发的深度学习库,它拥有类似于theano和tensorflow的数据流图,并且可以在常见的硬件平台上运行。mxnet还提供了r、c++、scala等语言的接口。我们以解决经典的手写数字识别的问题为例,介绍mxnet的基本使用方法,代码路径为:
    https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-mxnet.ipynb
    1. 加载相关库
    加载处理经典的手写数字识别问题相关的python库:
    import mxnet as mx
    import logging
    2. 加载数据集
    mxnet中也针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程:
    mnist = mx.test_utils.get_mnist()
    batch_size = 128
    train_iter = mx.io.ndarrayiter(mnist['train_data'], mnist['train_label'],
    batch_size, shuffle=true)
    val_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'],
    batch_size)
    3. 定义网络结构
    使用与keras几乎完全相同的网络结构,只不过省略了dropout层。
    data = mx.sym.var('data')
    data = mx.sym.flatten(data=data)
    #全连接
    fc1 = mx.sym.fullyconnected(data=data, num_hidden = 512)
    act1 = mx.sym.activation(data=fc1, act_type="relu")
    fc2 = mx.sym.fullyconnected(data=act1, num_hidden = 512)
    act2 = mx.sym.activation(data=fc2, act_type="relu")
    fc3 = mx.sym.fullyconnected(data=act2, num_hidden=10)
    # softmax输出
    mlp = mx.sym.softmaxoutput(data=fc3, name='softmax')
    mlp_model = mx.mod.module(symbol=mlp, context=ctx)
    可视化网络结构,如图3-8所示,值得一提的是mxnet自带的可视化工具非常便于使用。
    import matplotlib.pyplot as plt
    mx.viz.plot_network(mlp).view()
    图3-8 mxnet处理mnist的网络结构图
    4. 定义损失函数和优化器
    损失函数使用交叉熵crossentropyloss,优化器使用sgd。
    5. 训练与验证
    mxnet的训练和验证过程是分开的,训练阶段加载优化器的配置,可以指定每训练100个批次,打印中间结果。
    mlp_model.fit(train_iter,
    eval_data=val_iter,
    optimizer='sgd',
    optimizer_params={'learning_rate':0.1},
    eval_metric='acc',
    batch_end_callback = mx.callback.speedometer(batch_size, 100),
    num_epoch=20)
    经过20轮的训练后,在测试集上验证准确度。
    test_iter = mx.io.ndarrayiter(mnist['test_data'], mnist['test_label'], batch_size)
    acc = mx.metric.accuracy()
    mlp_model.score(test_iter, acc)
    print(acc)
    最终在测试集上准确度达到了97.62%。
    info:root:epoch[19] train-accuracy=0.996438
    info:root:epoch[19] time cost=4.017 info:root:epoch[19] validation-accuracy=0.976167
    evalmetric: {'accuracy': 0.9761669303797469}
    mxnet保存的模型文件后缀为parms。
    mlp_model.save_checkpoint('models/mxnet.parms',20)

章节目录

智能系统与技术丛书·AI安全之对抗样本入门所有内容均来自互联网,一曲书屋只为原作者兜哥的小说进行宣传。欢迎各位书友支持兜哥并收藏智能系统与技术丛书·AI安全之对抗样本入门最新章节