WonderPlanet Tech Blog

アドバンストテクノロジー部やCTO室のメンバーを中心に最新の技術情報を発信しています。

Kerasでメモリに乗り切らないぐらいの大量データを学習するたった1つの方法

こんにちは。アドバンストテクノロジー部のR&Dチーム所属岩原です。
今回はKerasを使ってメモリに乗り切らないぐらいの大量データを学習させる方法について紹介したいと思います。
個人的にハマったポイントなので、同じように困ってる方々の力になれれば…と思ってます。
Kerasとは何ぞや、とか使い方云々はまた別途記事を書きたいと思います。

対象読者

  • Kerasを使ってある程度の学習は出来る人
  • Pythonがある程度読める人
  • Unix系OSでKerasを動かしている人

今回はモデルの構築などは省略しています。

確認環境

  • Python:3.6.1
  • Keras:2.0.8
  • tensorflow-gpu:1.3.0 (今回は特に関係ありません)

想定場面

それでは、具体的な場面を想定してみましょう。

  • クエストのログを使って学習したい。クリア or notのデータがcsvデータとして手元にある。
  • クレンジングや怪しいデータの削ぎ落としをしても、データサイズが100GB近くになってしまっている。
  • 使えるマシンのメモリは16GBしかない。
  • AWSやGCPのGPUインスタンスが使えない環境である。

メモリ16GBのマシンでは、とてもじゃないけど100GBのデータは乗りませんね。どうしましょう。

いきなり結論

モデルのfit_generatorメソッドと、keras.utils.data_utils.Sequenceを実装したクラスを使います。
それぞれの定義はSequentialモデル - Keras Documentationfit_generatorと、ユーティリティ - Keras Documentationを参照していただきたいです。

いきなりコード(抜粋)

keras.utils.data_utils.Sequenceを実装したクラス

from keras.utils import Sequence
from pathlib import Path
import pandas
import numpy as np
from keras.utils import np_utils

class CSVSequence(Sequence):
    def __init__(self, kind, length):
        # コンストラクタ
        self.kind = kind
        self.length = length
        self.data_file_path = str(Path(download_path) / self.kind / "splited" / "split_data_{0:05d}.csv")

    def __getitem__(self, idx):
        # データの取得実装
        data = pandas.read_csv(self.data_file_path.format(idx), encoding="utf-8")
        data = data.fillna(0)
        
        # 訓練データと教師データに分ける
        x_rows, y_rows = get_data(data)
        
        # ラベルデータのカテゴリカル変数化
        Y = np_utils.to_categorical(y_rows, nb_classes) 
        X = np.array(x_rows.values)
        
        return X, Y

    def __len__(self):
        # 全データの長さ
        return self.length

    def on_epoch_end(self):
        # epoch終了時の処理
        pass

fit_generatorの呼び出し

from pathlib import Path
import multiprocessing

# csvダウンロード先パス
download_path = "/data"

# 同時実行プロセス数
process_count = multiprocessing.cpu_count() - 1


base_dir = Path(download_path)

# 訓練データ
train_data_dir = base_dir / "log_quest" / "splited"
train_data_file_list = list(train_data_dir.glob('split_data_*.csv'))
train_data_file_list = train_data_file_list

#検証用データ
val_data_dir = base_dir / "log_quest_validate" / "splited" 
val_data_file_list = list(val_data_dir.glob('split_data_*.csv'))
val_data_file_list = val_data_file_list

history = model.fit_generator(CSVSequence("log_quest", len(train_data_file_list)),
                steps_per_epoch=len(train_data_file_list), epochs=1, max_queue_size=process_count * 10,
                validation_data=CSVSequence("log_quest_validate", len(val_data_file_list)), validation_steps=len(val_data_file_list),
                use_multiprocessing=True, workers=process_count)

解説

前提

このコードでは、csvファイルは訓練用データディレクトリlog_questと検証用データディレクトリlog_quest_validateに分かれて入っている状態です。
さらに、それを並列で処理しやすいように、1000レコードずつ分割してsplitedディレクトリに入っています。
構造としてはこんな感じです。

/data/
├── log_quest
│   └── splited
│       ├── split_data_**.csv
└── log_quest_validate
     └── splited
        ├── split_data_**.csv

またsplit_data_**.csv は、正確にはsplit_data_{0:05d}.csvという命名形式に沿ってファイルを分割しています。
そのため、ファイルリストをCSVSequenceに渡さず、並列で処理を行っても競合したりせずに処理が可能になっています。

fit_generatorについて

このfit_generatorメソッドは、Pythonのジェネレータが生成するデータを受け取って学習を行うメソッドです。
学習画像の無限生成(画像の前処理 - Keras Documentation)に使われているのをネットで目にしますが、
自前のジェネレータ関数などでも問題なかったりします。
したがって、ジェネレータ関数で動的にデータを読み込んで返すようにすれば、メモリに乗り切らないぐらいの大量データでもどうにか返すことが出来ます。
しかし、今回はジェネレータではなく、CSVSequenceというkeras.utils.data_utils.Sequenceクラスを実装した独自クラスを指定しています。

また、fit_generatorメソッドは引数workersに1より大きい数字を指定すると並列処理をしてくれるようになります。
その場合、引数use_multiprocessingがTrueならマルチプロセス、Falseならマルチスレッドで並列処理を行います。

keras.utils.data_utils.Sequence とは?

fit_generatorメソッドで並列処理をしやすいようにしてくれるユーティリティクラスです。
ジェネレータでも並列で処理が呼ばれるのですが、並列処理なので呼び出し順序が不定であり、気をつけないと同じデータを学習してしまったりします。
それを防ぐユーティリティクラスがkeras.utils.data_utils.Sequenceです。
keras.utils.data_utils.Sequence クラスには実装すべきメソッドが3つあります(コンストラクタ合わせると4つ)。
各処理の詳細を解説します。

  • def __init__(self, kind, length)
    コンストラクタです。
    今回の例では、引数としてkind:ファイルの種別(訓練、検証)、length:ファイルリストの長さを受け取ります。
    また、コンストラクタ内でファイルパスの雛形を作成してます。

  • def __getitem__(self, idx)
    学習データを返すメソッドです。
    idxは要求されたデータが何番目かを示すインデックス値が入ります。
    (訓練データ, 教師データ)のタプルか、(訓練データ, 教師データ, sample_weights)のタプルで値を返す必要があります。
    また、タプルのそれぞれの要素はnumpy配列である必要があり、同じ要素数で揃える必要があります。
    このnumpy配列のサイズがバッチサイズとなります。
    今回の例では、コンストラクタで生成したファイルパスの雛形を元にファイルパスを作成し、
    それに該当するcsvデータをpandasで読み込み、訓練データと教師データに分割して返しています。

  • def __len__(self)
    学習データ全体の長さを返すメソッドです。
    __getitem__メソッドのidxの最大値は、ここで返した長さ - 1が設定されます。
    今回の例では、コンストラクタで渡されたファイルリストの長さをそのまま返しています。

  • def on_epoch_end(self)
    epoch終了時に呼び出されるメソッドです。
    epoch終了ごとに何か処理をしたい場合はここに記述します。
    今回の例では何もしていません。

マルチプロセス並列処理時の注意!

コンストラクタはメインプロセスで実行されますが、他のメソッドは子プロセスで実行されます。
従って、コンストラクタで設定するインスタンス変数は読み取り専用(各インスタンスメソッド内でインスタンス変数を書き換えても反映されない)で、かつpickleでシリアライズ可能である必要があります。
また、あまりにも大きすぎるデータだと、子プロセスに渡す際のオーバーヘッドが大きすぎて、時間がかかりすぎることがあります。

その他注意点など

  • fit_generatorメソッドの引数max_queue_sizeはデータ生成処理を最大いくつキューイングしておくか、という設定になります。
    メモリと相談しながらサイズを引き上げると良いでしょう。
  • fit_generatorメソッドの引数steps_per_epochはfit関数と同じく1epochでの学習回数を表します。
    基本的にSequenceの__len__と同じ値を指定すると良いでしょう。
  • マルチプロセスによる並列処理は、Windowsでは動作しません
    test_muliprocessing failed on windows · Issue #6582 · fchollet/kerasなどたくさんissueも上がってますが、対応される気配はありません…。UbuntuなどのLinux系OSを使いましょう。