CapsNet
Code
1.# 以下定义整个 CapsNet 的架构与正向传播过程
2.class CapsNet():
3. def __init__(self, is_training=True):
4. pass
5. # CapsNet 类中的build_arch方法能构建整个网络的架构
6. def build_arch(self):
7. # 以下构建第一个常规卷积层
8. with tf.variable_scope('Conv1_layer'):
9. # 第一个卷积层的输出张量为: [batch_size, 20, 20, 256]
10. # 以下卷积输入图像X,采用256个9×9的卷积核,步幅为1,且不使用
11. conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256,
12. kernel_size=9, stride=1,
13. padding='VALID')
14. assert conv1.get_shape() == [batch_size, 20, 20, 256]
15. # 以下是原论文中PrimaryCaps层的构建过程,该层的输出维度为 [batch_size, 1152, 8, 1]
16. with tf.variable_scope('PrimaryCaps_layer'):
17. # 调用前面定义的CapLayer函数构建第二个卷积层,该过程相当于执行八次常规卷积,
18. # 然后将各对应位置的元素组合成一个长度为8的向量,这八次常规卷积都是采用32个9×9的卷积核、步幅为2
19. primaryCaps = CapsLayer(num_outputs=32, vec_len=8, with_routing=False, layer_type='CONV')
20. caps1 = primaryCaps(conv1, kernel_size=9, stride=2)
21. assert caps1.get_shape() == [batch_size, 1152, 8, 1]
22. # 以下构建 DigitCaps 层, 该层返回的张量维度为 [batch_size, 10, 16, 1]
23. with tf.variable_scope('DigitCaps_layer'):
24. # DigitCaps是最后一层,它返回对应10个类别的向量(每个有16个元素),该层的构建带有Routing过程
25. digitCaps = CapsLayer(num_outputs=10, vec_len=16, with_routing=True, layer_type='FC')
26. self.caps2 = digitCaps(caps1)
27. # 以下构建论文图2中的解码结构,即由16维向量重构出对应类别的整个图像
28. # 除了特定的 Capsule 输出向量,我们需要蒙住其它所有的输出向量
29. with tf.variable_scope('Masking'):
30. #mask_with_y是否用真实标签蒙住目标Capsule
31. mask_with_y=True
32. if mask_with_y:
33. self.masked_v = tf.matmul(tf.squeeze(self.caps2), tf.reshape(self.Y, (-1, 10, 1)), transpose_a=True)
34. self.v_length = tf.sqrt(tf.reduce_sum(tf.square(self.caps2), axis=2, keep_dims=True) + epsilon)
35. # 通过3个全连接层重构MNIST图像,这三个全连接层的神经元数分别为512、1024、784
36. # [batch_size, 1, 16, 1] => [batch_size, 16] => [batch_size, 512]
37. with tf.variable_scope('Decoder'):
38. vector_j = tf.reshape(self.masked_v, shape=(batch_size, -1))
39. fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
40. assert fc1.get_shape() == [batch_size, 512]
41. fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
42. assert fc2.get_shape() == [batch_size, 1024]
43. self.decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)
44. # 定义 CapsNet 的损失函数,损失函数一共分为衡量 CapsNet准确度的Margin loss
45. # 和衡量重构图像准确度的 Reconstruction loss
46. def loss(self):
47. pass