conv2d、separableconv2d、depthwise conv2d的实现与keras中参数个数的区别,Conv2DSeparableConv2DDepthwiseConv2D,数量

发表时间:2020-05-20

Conv2D

对于一个普通卷积来说,通常卷积核的大小,上一层的channel与本层的channel对应了参数的数量。使用了bias的话,普通卷积 (Conv2d) 的参数为:

# p a r a m s = ( f i l t e r _ s i z e × f i l t e r _ s i z e × c h a n n e l l a s t + 1 ) × c h a n n e l t h i s \#params=(filter\_size\times filter\_size\times channel_{last}+1)\times channel_{this}

构建这样一个简单网络可以得到卷积层的参数为: (3*3*10+1)*32=2912

import keras
from keras.layers import Conv2D, Input
def test_model(input_size=(224, 224, 10)):
    _input = Input(shape=input_size, name='input')
    _output = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same')(_input)
    model = keras.models.Model(_input, _output)
    model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           (None, 224, 224, 10)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 224, 224, 32)      2912      
=================================================================
Total params: 2,912
Trainable params: 2,912
Non-trainable params: 0
_________________________________________________________________

DepthwiseConv2D

DepthwiseConv2D是指使用同一个卷积核在不同的channel上进行卷积输出,所以DepthwiseConv2D输出维度始终是与输入维度一样的,也不用指定 filter 大小,都一样的。(如果使用了 depth_multiplier 参数,参数数量乘以 depth_multiplier )。
参数数量计算公式:

# p a r a m s = [ ( f i l t e r _ s i z e × f i l t e r _ s i z e × 1 + 1 ) × c h a n n e l l a s t ] × d e p t h _ m u l t i p l i e r \#params=[(filter\_size\times filter\_size\times 1+1)\times channel_{last}]\times depth\_multiplier

将上面的网络中卷积改成DepthwiseConv2D的话( d e p t h _ m u l t i p l i e r depth\_multiplier 为默认值1),参数量为: (3*3*1+1)*10=100

def test_model(input_size=(224, 224, 10)):
    _input = Input(shape=input_size, name='input')
    _output = DepthwiseConv2D(kernel_size=(3, 3), activation='relu', padding='same')(_input)
    model = keras.models.Model(_input, _output)
    model.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           (None, 224, 224, 10)      0         
_________________________________________________________________
depthwise_conv2d_1 (Depthwis (None, 224, 224, 10)      100       
=================================================================
Total params: 100
Trainable params: 100
Non-trainable params: 0
_________________________________________________________________

SeparableConv2D

使用SeparableConv2D的话,就是在DepthwiseConv2D的基础上,再进行 1x1 的卷积输出。相当于在SeparableConv2D的基础上,还有 (1x1xchanel_last+1)xchannel_this 的参数(但是在实现时候发现,只有当DepthwiseConv2D部分不用bias才是一样的)。
# p a r a m s = ( f i l t e r _ s i z e × f i l t e r _ s i z e × 1 ) × c h a n n e l l a s t + ( 1 × 1 × c h a n n e l l a s t + 1 ) × c h a n n e l t h i s \#params=(filter\_size\times filter\_size\times 1)\times channel_{last}+\\(1\times 1\times channel_{last}+1)\times channel_{this}

以下网络的参数量计算: 3x3x10+(1x1x10+1)x32=442

def test_model(input_size=(224, 224, 10)):
    _input = Input(shape=input_size, name='input')
    _output = SeparableConv2D(32, kernel_size=(3, 3), activation='relu', padding='same')(_input)
    model = keras.models.Model(_input, _output)
    model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           (None, 224, 224, 10)      0         
_________________________________________________________________
separable_conv2d_1 (Separabl (None, 224, 224, 32)      442       
=================================================================
Total params: 442
Trainable params: 442
Non-trainable params: 0
_________________________________________________________________
微配音

文章来源互联网,尊重作者原创,如有侵权,请联系管理员删除。邮箱:417803890@qq.com / QQ:417803890


Python Free

邮箱:417803890@qq.com
QQ:417803890

皖ICP备19001818号
© 2019 copyright www.pythonf.cn - All rights reserved

微信扫一扫关注公众号:

联系方式

Python Free