Using tf.layers.batch_normalization() takes three steps:
- Set activation=None in the convolution layer.
- Apply batch_normalization.
- Apply the activation function yourself.
One thing to watch out for: pass training=True for the second argument during training, and training=False at test time. Also, when minimizing the loss, the batch-norm update ops must run before the train step:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)  # minimize the loss with the Adam optimizer

Note that if you omit the update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) dependency, the moving mean and variance are never updated: prediction accuracy looks fine in training mode, but the accuracy computed on the test set in inference mode ends up very low. The full pattern is sketched below.
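A minimal sketch of the three-step conv → BN → activation pattern described above. The tensor names and the placeholder loss here are illustrative only, not taken from the full network code later in the post:

import tensorflow as tf

is_training = tf.placeholder(tf.bool, [])  # True while training, False at test time
inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])

# Step 1: convolution with activation=None
conv = tf.layers.conv2d(inputs, 64, (3, 3), padding='same', activation=None)
# Step 2: batch normalization, controlled by the training flag
bn = tf.layers.batch_normalization(conv, training=is_training)
# Step 3: apply the activation afterwards
out = tf.nn.relu(bn)

loss = tf.reduce_mean(out)  # stand-in loss, just to make the sketch runnable
# The moving mean/variance updates live in UPDATE_OPS and must run before the train step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)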
TensorFlow Implementation of AlexNet
Environment
- python 3.6
- tensorflow 1.10.0 (the tf.layers API used below requires a TensorFlow 1.x release)
Anaconda is recommended for the installation; it saves a fair amount of time and hassle.
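For example, a fresh environment could be set up like this (the environment name is arbitrary; the version pins match the list above):

conda create -n alexnet python=3.6
conda activate alexnet
pip install tensorflow==1.10.0 numpy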
Loading the CIFAR-10 dataset
The CIFAR-10 dataset is used; you need to download it yourself (the Python pickled version, whose files are data_batch_1 through data_batch_5 plus test_batch).
import os
import pickle
import numpy as np
import tensorflow as tf

def load_data(filename):
    """Read one batch of CIFAR-10 data from a pickled data file."""
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='latin1')
        return data['data'], data['labels']

class CifarData:
    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            all_data.append(data)
            all_labels.append(labels)
        self._data = np.vstack(all_data)
        # scale pixel values from [0, 255] to [-1, 1]
        self._data = self._data / 127.5 - 1
        self._labels = np.hstack(all_labels)
        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        if self._need_shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self, batch_size):
        """Return batch_size examples as a batch."""
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = batch_size
            else:
                raise Exception("Have no more examples")
        if end_indicator > self._num_examples:
            raise Exception("Batch size is larger than all examples")
        batch_data = self._data[self._indicator:end_indicator]
        batch_labels = self._labels[self._indicator:end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels
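A quick smoke test for the loader. CIFAR_DIR is an assumed path to the extracted cifar-10-batches-py directory (the original post does not define it); it is reused by the network code further down:

CIFAR_DIR = './cifar-10-batches-py'  # assumed path to the extracted dataset
sample = CifarData([os.path.join(CIFAR_DIR, 'data_batch_1')], True)
batch_data, batch_labels = sample.next_batch(20)
print(batch_data.shape, batch_labels.shape)  # expected: (20, 3072) (20,)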
Below is the network code. Here I also implemented batch normalization by hand, i.e. the batch-normalization procedure itself:
def batch_normal(xs, out_size, is_training):
    # batch mean/variance over all axes except the channel axis
    axis = list(range(len(xs.get_shape()) - 1))
    n_mean, n_var = tf.nn.moments(xs, axes=axis)
    # learnable scale (gamma) and shift (beta)
    scale = tf.Variable(tf.ones([out_size]))
    shift = tf.Variable(tf.zeros([out_size]))
    epsilon = 0.001
    # moving averages of the batch statistics, for use at test time
    ema = tf.train.ExponentialMovingAverage(decay=0.9)
    def mean_var_with_update():
        ema_apply_op = ema.apply([n_mean, n_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(n_mean), tf.identity(n_var)
    # training: use (and record) batch statistics; testing: use the moving averages
    mean, var = tf.cond(is_training,
                        mean_var_with_update,
                        lambda: (ema.average(n_mean), ema.average(n_var)))
    bn = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)
    return bn
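The design mirrors what tf.layers.batch_normalization does internally: tf.cond picks the batch statistics during training (and, as a side effect of mean_var_with_update, folds them into the exponential moving averages), while at test time it reads back the accumulated averages via ema.average. It is used on the fully connected layers further down, e.g. bn6 = batch_normal(y_1, 1024, is_training).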
Below is the usual network flow.
Here is a diagram of the network structure. Most search results today only show the figure without code for this earliest of CNN networks, so I tried writing it myself, partly as a learning exercise.
AlexNet network structure:
The original design was split across two GPUs, exchanging data between them, purely because of the graphics-card bottleneck of the era; a single modern GPU can handle the whole job, so this implementation differs slightly from the structure in the figure.
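For orientation, here is the feature-map shape trace through the code below, as written (32*32 CIFAR-10 input rather than AlexNet's original 224*224):

input 32*32*3
→ conv1 (48 ch, ReLU) → LRN → pool1 → 16*16*48
→ conv2 (96 ch) → BN → ReLU → pool2 → 8*8*96
→ conv3 (192 ch) → conv4 (192 ch) → conv5 (96 ch) → BN → ReLU → pool5 → 4*4*96
→ flatten (1536) → fc1 (1024, BN) → fc2 (1024, BN) → logits (10)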
train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i) for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
train_data = CifarData(train_filenames, True)

x = tf.placeholder(tf.float32, [None, 3072])
# [None], integer class labels, e.g. [0, 6, 5, 3]
y = tf.placeholder(tf.int64, [None])
is_training = tf.placeholder(tf.bool, [])
# reshape the flat 3072-vector to CHW, then transpose to NHWC
x_image = tf.reshape(x, [-1, 3, 32, 32])
x_image = tf.transpose(x_image, perm=[0, 2, 3, 1])
# feature maps / output image sizes
# 32*32
conv1 = tf.layers.conv2d(x_image,
                         48,  # output channel number
                         (3, 3),  # kernel size
                         padding='same',
                         activation=tf.nn.relu,
                         name='conv1')
lrn1 = tf.nn.lrn(conv1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
# 16*16
pooling1 = tf.layers.max_pooling2d(lrn1,
                                   (2, 2),  # pool size
                                   (2, 2),  # stride
                                   name='pool1')
# 16*16
conv2 = tf.layers.conv2d(pooling1,
                         96,  # output channel number
                         (3, 3),  # kernel size
                         padding='same',
                         activation=None,  # activation applied after BN
                         name='conv2')
# batch_normalization
bn = tf.layers.batch_normalization(conv2, training=is_training)
conv2 = tf.nn.relu(bn)
# dropped the LRN layer here in favor of BN
# lrn2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
# 8*8
pooling2 = tf.layers.max_pooling2d(conv2,
                                   (2, 2),
                                   (2, 2),
                                   name='pool2')
# 8*8
conv3 = tf.layers.conv2d(pooling2,
                         192,  # output channel number
                         (3, 3),  # kernel size
                         padding='same',
                         activation=tf.nn.relu,
                         name='conv3')
# 8*8
conv4 = tf.layers.conv2d(conv3,
                         192,  # output channel number
                         (3, 3),  # kernel size
                         padding='same',
                         activation=tf.nn.relu,
                         name='conv4')
# 8*8
conv5 = tf.layers.conv2d(conv4,
                         96,  # output channel number
                         (3, 3),  # kernel size
                         padding='same',
                         activation=None,  # activation applied after BN, per the three-step pattern
                         name='conv5')
bn = tf.layers.batch_normalization(conv5, training=is_training)
conv5 = tf.nn.relu(bn)
# 4*4
pooling5 = tf.layers.max_pooling2d(conv5,
                                   (2, 2),
                                   (2, 2),
                                   name='pool5')
# [None, 4*4*96] = [None, 1536]
flatten = tf.layers.flatten(pooling5)
y_1 = tf.layers.dense(flatten, 1024)
bn6 = batch_normal(y_1, 1024, is_training)
fc1 = tf.nn.relu(bn6)
# [None, 1024]
y_2 = tf.layers.dense(fc1, 1024)
bn7 = batch_normal(y_2, 1024, is_training)
fc2 = tf.nn.relu(bn7)
# [None, 10]
y_ = tf.layers.dense(fc2, 10)
# cross-entropy loss
# y_ -> softmax probabilities
# y  -> one-hot labels
# loss = -sum(y * log(y_))
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)
# predicted class index; comparing with y gives a bool vector
predict = tf.argmax(y_, 1)
correct_prediction = tf.equal(predict, y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float64))
with tf.name_scope('train_op'):
    # run the batch-norm moving-average updates before each training step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
init = tf.global_variables_initializer()
batch_size = 20
test_steps = 100
train_steps = 10000

with tf.Session() as sess:
    sess.run(init)
    for i in range(train_steps):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        loss_val, accu_val, _ = sess.run([loss, accuracy, train_op],
                                         feed_dict={x: batch_data, y: batch_labels, is_training: True})
        if (i+1) % 500 == 0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % (i+1, loss_val, accu_val))
        if (i+1) % 5000 == 0:
            test_data = CifarData(test_filenames, False)
            all_test_acc_val = []
            for j in range(test_steps):
                test_batch_data, test_batch_labels = test_data.next_batch(batch_size)
                test_acc_val = sess.run([accuracy],
                                        feed_dict={x: test_batch_data, y: test_batch_labels, is_training: False})
                all_test_acc_val.append(test_acc_val)
            test_acc = np.mean(all_test_acc_val)
            print("[Test] Step: %d, acc: %4.5f" % (i+1, test_acc))
Below are the test results after running 10,000 training steps:
The test-set accuracy comes out at around 73.7%, which is fairly good for AlexNet on this dataset.