WonderPlanet Tech Blog

アドバンストテクノロジー部やCTO室のメンバーを中心に最新の技術情報を発信しています。

ファインチューニングをやってみた

こんにちは。アドバンストテクノロジー部のR&Dチーム所属岩原です。
今回はファインチューニングについて色々と調査しました。

ファインチューニング(fine tuning)とは

既存のモデルの一部を再利用して、新しいモデルを構築する手法です。
優秀な汎用モデル(VGG16など)を使い、自分たち用のモデルを構築したり出来ます。
少ないデータ(といっても数十〜数百ぐらいは必要ですが)で、結構精度の良いモデルが構築できたりします。

全く違う方向性(写真画像系のモデルを元に、イラストの判定モデルを作るなど)だと余り効果が出てこないようですが、
元のモデルより更に詳細な特徴を抽出したい、などの用途だと効果が高いようです。

転移学習(transfer learning)という呼び方もされるみたいですが、使い分けとかどんな感じなんでしょうね?

実際にやってみた

環境

  • Ubuntu16.04(GTX1080Ti)
  • Keras 2.0.8
  • Tensorflow 1.3.0
  • nvidia-docker 1.0.1

AWSのGPU Computeインスタンス(p2やp3)でも使いまわせるようにDockerizeしています。

使用するデータセット

Food-101 -- Mining Discriminative Components with Random Forests
101ラベルの料理画像を計101000枚(1ラベル1000枚)用意しているデータセットです。
枚数としては心もとない気がしますが、ファインチューニングを試してみるにはちょうどよい枚数かと思います。

3つのデータパターン

データの数によってどこまで変わるか、という検証のため、パターンを3つ用意しました。

  • データ数を訓練データ100枚、検証データ25で行うパターン(パターン1)
    データ数を少なくし、ファインチューニングの効果を検証するパターンです。

  • パターン1のデータをImageDataGeneratorで水増ししたパターン(パターン2)
    限られたデータ数を水増しし、どこまで効果が出るのかを検証するパターンです。
    fit_generatorの引数steps_per_epochを2000(batch_sizeは32なので、2000 * 32で64000枚)、引数validation_stepsを500(同じく500 * 32で16000枚)に設定しました。

  • データセットの全てのデータ(訓練データ75750枚、検証データ25250枚)で学習を行うパターン(パターン3)
    データセット全てを学習&検証に回し、水増しとの違いを検証するパターンです。
    バッチサイズは32を設定しました。

3つのモデル

ファインチューニングの方法によってどこまで差が出るのかの検証のため、さらにモデルを3つ用意しました。

  • ピュアなVGG16(モデル1)
    重みを初期化したVGG16構造のモデルです。
    ファインチューニングしない場合の検証を行うパターンになります。

f:id:m_iwahara:20171212170350p:plain

Kerasを使用したコードはこんな感じになります。

def create_none_weight_vgg16_model(size):
    model_path = "./models/vgg16_none_weight.h5py"
    if not os.path.exists(model_path):
        input_tensor = Input(shape=(224,224,3))
        model = VGG16(weights=None, include_top=True, input_tensor=input_tensor, classes=size)
        model.save(model_path) # 毎回ダウンロードすると重いので、ダウンロードしたら保存する
    else:
        model = load_model(model_path) 
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model
  • VGG16の全結合層のみを取り替え(モデル2)
    全結合層を取っ払い、新たに全結合層をくっつけたモデルです。
    全結合層以外の層の重みは学習済みのパラメータをそのまま使用し、学習しないようにします。
    全結合層のみファインチューニングした場合の検証を行うパターンになります。
    ついでに、全結合層にDropoutを付けてみたりしています。
    なお、最適化関数は学習率を極端に抑えたSGDを使用しています。

f:id:m_iwahara:20171212170417p:plain

Kerasを使用したコードはこんな感じになります。

def get_vgg16_model():
    model_path = "./models/vgg16.h5py"
    if not os.path.exists(model_path):
        input_tensor = Input(shape=(224,224,3))
        # 出力層側の全結合層3つをモデルから省く
        model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
        model.save(model_path) # 毎回ダウンロードすると重いので、ダウンロードしたら保存する
    else:
        model = load_model(model_path)
    return model

def create_fullconnected_fine_tuning(classes):
    # vgg16モデルを作る
    vgg16_model = get_vgg16_model()

    input_tensor = Input(shape=(224,224,3))

    for layer in vgg16_model.layers:
        layer.trainable = False

    x = vgg16_model.output
    x = Flatten()(x)
    x = Dense(2048, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(classes, activation='softmax')(x)
    model = Model(inputs=vgg16_model.input, outputs=predictions)


    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=1e-4, momentum=0.9),
                  metrics=['accuracy'])
    return model
  • VGG16の最後の畳み込み層と全結合層を取り替え(モデル3)
    最後の畳み込み層の重みを初期化して学習するようにし、全結合層を取り替えたモデルです。
    上記以外の層は再学習をしないようにします。
    最後の畳み込み層と全結合層をファインチューニングした場合の検証を行うパターンになります。

f:id:m_iwahara:20171212170432p:plain

Kerasを使用したコードはこんな感じになります。

def create_last_conv2d_fine_tuning(classes):
    # vgg16モデルを作る
    vgg16_model = get_vgg16_model()

    input_tensor = Input(shape=(224,224,3))

    x = vgg16_model.output
    x = Flatten()(x)
    x = Dense(2048, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(classes, activation='softmax')(x)
    model = Model(inputs=vgg16_model.input, outputs=predictions)
    # 最後の畳み込み層より前の層の再学習を防止
    for layer in model.layers[:15]: 
        layer.trainable = False

    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=1e-4, momentum=0.9),
                  metrics=['accuracy'])
    return model

その他

過学習に陥る前に学習を止めるEarlyStoppingを導入済みです。

結果

上記データパターンとモデルパターンの組み合わせを検証。

パターン1

モデル1

精度

f:id:m_iwahara:20171213091424p:plain

損失

f:id:m_iwahara:20171213091443p:plain

モデル2

精度

f:id:m_iwahara:20171213091528p:plain

損失

f:id:m_iwahara:20171213091541p:plain

モデル3

精度

f:id:m_iwahara:20171213091609p:plain

損失

f:id:m_iwahara:20171213091622p:plain

パターン2

モデル1

精度

f:id:m_iwahara:20171213091648p:plain

損失

f:id:m_iwahara:20171213091700p:plain

モデル2

精度

f:id:m_iwahara:20171213091714p:plain

損失

f:id:m_iwahara:20171213091725p:plain

モデル3

精度

f:id:m_iwahara:20171213091740p:plain

損失

f:id:m_iwahara:20171213091749p:plain

パターン3

モデル1

精度

f:id:m_iwahara:20171213091803p:plain

損失

f:id:m_iwahara:20171213091813p:plain

モデル2

精度

f:id:m_iwahara:20171213091825p:plain

損失

f:id:m_iwahara:20171213091837p:plain

モデル3

精度

f:id:m_iwahara:20171213091902p:plain

損失

f:id:m_iwahara:20171213091914p:plain

結果まとめ

「データ数は多いほうが良い。ファインチューニングはゼロから学習させるよりもかなり有効で、収束も早い。最後の畳み込み層からファインチューニングした方が精度が良い。」という感じですね。
枚数が足りないのか、過学習の傾向はどれも見られますが…。
なお、1 epoch辺りの学習時間は モデル1 > モデル3 > モデル2 といった結果になりました。
ファインチューニングを行うと学習時間も節約できるのでおすすめです。

参考

VGG16のFine-tuningによる犬猫認識 (2) - 人工知能に関する断創録