Connections 仅仅作为Connection的集合对象,提供一些集合操作。
Node实现如下:
ConstNode对象,为了实现一个输出恒为1的节点(计算偏置项时需要)
Layer对象,负责初始化一层。此外,作为Node的集合对象,提供对Node集合的操作。
Connection对象,主要职责是记录连接的权重,以及这个连接所关联的上下游节点。
Connections对象,提供Connection集合操作。
Network对象,提供API。
至此,实现了一个基本的全连接神经网络。可以看到,同神经网络的强大学习能力相比,其实现还算是很容易的。
梯度检查
怎么保证自己写的神经网络没有BUG呢?事实上这是一个非常重要的问题。一方面,千辛万苦想到一个算法,结果效果不理想,那么是算法本身错了还是代码实现错了呢?定位这种问题肯定要花费大量的时间和精力。另一方面,由于神经网络的复杂性,我们几乎无法事先知道神经网络的输入和输出,因此类似TDD(测试驱动开发)这样的开发方法似乎也不可行。
办法还是有滴,就是利用梯度检查来确认程序是否正确。梯度检查的思路如下:
对于梯度下降算法:
当然,我们可以重复上面的过程,对每个权重Wji 都进行检查。也可以使用多个样本重复检查。
至此,会推导、会实现、会抓BUG,你已经摸到深度学习的大门了。接下来还需要不断的实践,我们用刚刚写过的神经网络去识别手写数字。
神经网络实战——手写数字识别
针对这个任务,我们采用业界非常流行的MNIST数据集。MNIST大约有60000个手写字母的训练样本,我们使用它训练我们的神经网络,然后再用训练好的网络去识别手写数字。
手写数字识别是个比较简单的任务,数字只可能是0-9中的一个,这是个10分类问题。
超参数的确定
我们首先需要确定网络的层数和每层的节点数。关于第一个问题,实际上并没有什么理论化的方法,大家都是根据经验来拍,如果没有经验的话就随便拍一个。然后,你可以多试几个值,训练不同层数的神经网络,看看哪个效果最好就用哪个。嗯,现在你可能明白为什么说深度学习是个手艺活了,有些手艺很让人无语,而有些手艺还是很有技术含量的。
不过,有些基本道理我们还是明白的,我们知道网络层数越多越好,也知道层数越多训练难度越大。对于全连接网络,隐藏层最好不要超过三层。那么,我们可以先试试仅有一个隐藏层的神经网络效果怎么样。毕竟模型小的话,训练起来也快些(刚开始玩模型的时候,都希望快点看到结果)。