Xceptionの引数pooling='ave'を使うとエラーが出る!?

自分用のメモ

keras内部にあるxceptionをimportして使用していたときのこと。

このxceptionのモデルの引数にはこのような多くの引数が存在する。

keras.applications.xception.Xception(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, pooling=None, classes=1000)

この中でもpoolingは、

pooling: 特徴量抽出のためのオプショナルなpooling mode,include_topFalseの場合のみ指定可能.

    • None:モデルの出力が,最後のconvolutional layerの4階テンソルであることを意味しています.
    • 'avg':最後のconvolutional layerの出力にglobal average poolingが適用されることで,モデルの出力が2階テンソルになることを意味しています.
    • 'max':global max poolingが適用されることを意味します.

 

 との記述がある。

https://keras.io/ja/applications/

 

このpooling='avg'と指定したところ、エラーが発生。

どうやらこの層の前後で次元数が一致していないとのこと。

 

そこで、自ら

from keras.layers import GlobalAveragePooling2D

をimportし、

x = GlobalAveragePooling2D()(x)をプログラムに追加することで、

プログラムが無事通るようになった。

 

デフォルトのものはどこかに欠陥があるのだろうか。。。

 

 

 

 

 【アイデミープレミアムプラン】

人工知能を一から学びたい方はアイデミーがおすすめ!!

機械学習ディープラーニングpythonなどを基礎から学ぶことができます。