2020年4月25日土曜日

TensorFlow Model Optimization ToolkitでQuantization aware training in Keras

変更


2020.5.16 - QATでkeras modelとTF-Lite modelの精度の差がなくなった(問題が解消した)ので修正。


目的



tf.kerasでQuantization aware training(QAT)が実行できるTensorFlow Model Optimization Toolkit(tensorflow-model-optimization)のAPIがリリースされたので試してみる。




注意


現時点(2020.4.25)ではtensorflow-model-optimizationのQATのAPIは1.0リリース前(v0.3.0)で、experimentalなAPIも含まれている。このため、将来ここに記載された内容は変更になっている可能性がある。

また、TensorFlow 2.2.0-rc3以下では、TF-Lite integet quant modelに変換できないためtf-nightlyを使うことにも注意(参考資料のissue#38285も参照) 。

詳細はリファレンスも参照。


動機


以前、TF2.0のKerasでPost-training quantizationをブログに書いた。

しかし、Post-training quantizationは扱いやすい反面、精度が落ちる場合がある。精度をなるべく落とさない場合は、QATをしたいが、tf.kerasではサポートされていなかった。

今回、tensorflow-model-optimizationでQATのAPIがリリースされた。まずはImage ClassificationのモデルでQATを行い、TF-Lite Integer quant modelやEdge TPU Modelへの変換を試してみたかったので、今回やったことをまとめてみる。


参考資料



環境・バージョン情報


Google Colaboratoryで確認できるnotebookを作成した。
ただし、TensorFlow 2.2.0-rc3以下では、TF-Lite modelに変換できないためtf-nightly-gpuを使う。
  • tf-nightly-gpu: 2.2.0.dev20200423
  • tensorflow-model-optimization: 0.3.0


KerasのMobileNet v2でQAT


今回、tf.kerasのMobilenet v2をQATし、TF-Lite modelに変換する。
MobileNet v2のモデルでtf_flower datasetを学習する。tfsdなどの使い方はTF2.0のKerasでPost-training quantizationも参照。
Google Colaboratoryで実行できるnotebookは以下。

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

インストール


tensorflow-model-optimizationのインストールはpipでインストール可能。
$ pip install tensorflow-model-optimization

Import


以下でtensorflow-model-optimizationをimport。
import tensorflow_model_optimization as tfmo

モデルの生成


Functional APIやSequential modelを使ってモデルを生成する。
keras.applicationsで用意されているモデルを使う場合、Functional APIで構築する(下記の注意を参照)。
  1. def setup_mobilenet_v2_model():
  2. base_model = MobileNetV2(include_top=False,
  3. weights='imagenet',
  4. pooling='avg',
  5. input_shape=(IMG_SIZE, IMG_SIZE, 3))
  6. x = base_model.output
  7. x = tf.keras.layers.Dense(info.features['label'].num_classes,
  8. activation="softmax")(x)
  9. model = tf.keras.Model(inputs=base_model.input, outputs=x)
  10.  
  11. return model

モデル生成時の注意


Sequential modelの場合、入れ子となるようなモデルを構築すると量子化モデルの生成時にエラーとなる。
下記のようにモデルを定義するとエラー(ValueError: Quantizing a tf.keras Model inside another tf.keras Model is not supported.)が発生してしまうので注意。
  1. def setup_mobilenet_v2_model():
  2. base_model = MobileNetV2(include_top=False,
  3. weights='imagenet',
  4. pooling='avg',
  5. input_shape=(IMG_SIZE, IMG_SIZE, 3))
  6. model = tf.keras.Sequential([base_model,
  7. tf.keras.layers.Dense(info.features['label'].num_classes,
  8. activation='softmax')])
  9.  
  10. return model

サポートモデルの注意


現時点でサポートされているImage classification modelはここ(Image classification with tools)を参照。MobileNet、ResNetあたり。
DensNetは量子化モデルの生成時にエラーとなってしまう。サポート状況は改善されるので要確認。


モデルの学習(Not QAT)


QATのを行う前にまずは通常の学習を行う。
QATは通常より学習の時間がかかるため事前の学習を行っておいた方がよい。
  1. history = model.fit(train.repeat(),
  2. epochs=10,
  3. steps_per_epoch=steps_per_epoch,
  4. validation_data=validation.repeat(),
  5. validation_steps=validation_steps)

なお、MobileNet v2で、おおよそ1Stepあたりの学習時間は
  • QATでない通常の学習時: 20秒
  • QAT時: 40秒
なので、約2倍の差があった。


モデルの学習(QAT)


学習済みのモデルを使って量子化モデルへの生成、QATを行う。

まず、quantize_model APIでモデル全体を量子化する。
なお、APIとしては一部のレイヤーのみを量子化することも可能。
量子化したあとは、compileが必要になる。
q_aware_model = tfmo.quantization.keras.quantize_model(model)
q_aware_model.compile(optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate),
                      loss = 'sparse_categorical_crossentropy',
                      metrics = ["accuracy"])

モデルのサマリを表示してみる。
各Layerがquantize_xxxとなり、量子化モデルになっていることがわかる。
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
quantize_layer (QuantizeLayer)  (None, 224, 224, 3)  3           input_1[0][0]                    
__________________________________________________________________________________________________
quant_Conv1_pad (QuantizeWrappe (None, 225, 225, 3)  1           quantize_layer[1][0]             
__________________________________________________________________________________________________
quant_Conv1 (QuantizeWrapper)   (None, 112, 112, 32) 929         quant_Conv1_pad[0][0]            
__________________________________________________________________________________________________
quant_bn_Conv1 (QuantizeWrapper (None, 112, 112, 32) 129         quant_Conv1[0][0]                

 中略 

__________________________________________________________________________________________________
quant_block_16_project (Quantiz (None, 7, 7, 320)    307841      quant_block_16_depthwise_relu[0][
__________________________________________________________________________________________________
quant_block_16_project_BN (Quan (None, 7, 7, 320)    1283        quant_block_16_project[0][0]     
__________________________________________________________________________________________________
quant_Conv_1 (QuantizeWrapper)  (None, 7, 7, 1280)   412161      quant_block_16_project_BN[0][0]  
__________________________________________________________________________________________________
quant_Conv_1_bn (QuantizeWrappe (None, 7, 7, 1280)   5121        quant_Conv_1[0][0]               
__________________________________________________________________________________________________
quant_out_relu (QuantizeWrapper (None, 7, 7, 1280)   3           quant_Conv_1_bn[0][0]            
__________________________________________________________________________________________________
quant_global_average_pooling2d  (None, 1280)         3           quant_out_relu[0][0]             
__________________________________________________________________________________________________
quant_dense (QuantizeWrapper)   (None, 5)            6410        quant_global_average_pooling2d[0]
==================================================================================================

QAT(Quantization aware training)を行う。
これは、生成した量子化モデルで学習を行う。
  1. q_aware_history = q_aware_model.fit(train.repeat(),
  2. initial_epoch=10,
  3. epochs=70,
  4. steps_per_epoch=steps_per_epoch,
  5. validation_data=validation.repeat(),
  6. validation_steps=validation_steps)

学習開始直後は、loss, accuracyが悪化するが学習を進めるとQAT直前とほぼ同等になる。
事前に学習を行っていると精度がすぐに回復するようである。

学習の結果は以下。マゼンダの縦線より左が通常の学習、右がQATである。


QATの注意


(2020.5.16 変更)
Issue #368 でQATのKeras modelとTF-Lite modelで精度の差が異なる問題が修正された。これによってQATの収束とTF-Lite modelの精度は一致する。

自分があげたissueにあるとおり、QATの学習が少ないとKeras model(量子化モデル)とTF-Lite modelの精度に差が出てしまう。

作成したサンプルは10epochs程度でloss, accuracyが収束したように見え、keras modelではtestセットで0.98〜の精度がでる。しかし、TF-Lite modelに変換すると精度が0.20(tf_flowerは5クラス!)となってしまう。

原因はこちらのissueのコメントにあるとおり、量子化のMin / Maxの初期値が-6 / 6でかなり広く、Min / Maxが収束していない状態だと量子化時に損失が発生してしまうとのこと。
また、量子化の更新はMovingAverageQuantizerに定義されていて、tensorflowのmoving_averages(EMA: 移動平均)を使って値が更新される。このときのema_decay(減衰)が0.999で収束するまでにそれなりのstep数がかかってしまうとのことである。

複雑なタスクの場合、QATでloss, accuracyが回復するまでとEMAが収束するまではほぼ同じぐらいかもしれない。しかし、tf_flowerのような簡単?なタスクの場合は、loss, accuracyのほうがはやく収束してしまうため、このような結果となってしまう。

今後、このコメントのとおり、EMAの収束も監視できるようになるといいと思う。

また、モデルによってもMEAの収束は異なる模様。以下のようなモデルの場合は、10epochsでほぼ収束していた。
  1. def setup_cnn_model():
  2. # extract image features by convolution and max pooling layers
  3. inputs = tf.keras.Input(shape = (IMG_SIZE, IMG_SIZE, 3))
  4. x = tf.keras.layers.Conv2D(32, kernel_size=3, padding="same", activation="relu")(inputs)
  5. x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
  6. x = tf.keras.layers.Conv2D(64, kernel_size=3, padding="same", activation="relu")(x)
  7. x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
  8. # classify the class by fully-connected layers
  9. x = tf.keras.layers.Flatten()(x)
  10. x = tf.keras.layers.Dense(512, activation="relu")(x)
  11. x = tf.keras.layers.Dense(info.features['label'].num_classes)(x)
  12. x = tf.keras.layers.Activation("softmax")(x)
  13. model_functional = tf.keras.Model(inputs=inputs, outputs=x)
  14. return model_functional

QATのパラメータは変更が可能なので試してみてもよい。


かならず、TF-Lite modelでも精度を確認したほうがよいと思う。


TF-Lite Integer quantization Modelの変換


TF-Lit modelに変換する際は、MLIRのコンバーターが必要(ここでは、experimental_new_converter フラグをセットしているが不要になるはず)。
以下のパラメータで変換ができる。
  1. converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
  2.  
  3. converter.experimental_new_converter = True
  4. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  5. tflite_q_aware_integer_quant_model = converter.convert()
  6. with open(os.path.join(models_dir, 'mobilenet_v2_q_aware_integer_quant.tflite'), 'wb') as f:
  7. f.write(tflite_q_aware_integer_quant_model)

TF-Lite Modelのテスト


前回のPost-training quantizationで使用したコードがそのまま使用できる(ここでは省略)。
Post-training quantizationとQATで、Keras modelからの精度低下を確認。

Test accuracykeras modelTF-Lite model
Baseline0.9730.946
QAT1.0001.000

一応、Post-training quantization ではTF-Lite modelでの精度の低下があるが、QATでは低下がない。
ただ、今回のデータセット(tf_flower)だとちょっと適切でないかも。。。


Edge TPU Modelの変換


TF-Lite Integer quant modelに変換できたので、Edge TPU Modelにも変換できるはずである。
ただし、Edge TPU Compiler version 2.1.302470888では、
  • MobileNet v1は成功
  • MobileNet v2は失敗( Internal compiler error. Aborting! )
であった。。。

今後のEdge TPU Compilerのアップデートに期待!


最後に


TensorFlow Model Optimization Toolkitを使ってKerasモデルのQuantization aware trainingを行ってみた。少しハマったポイントはあるけど、簡単にQATができる。

KerasモデルでQATができるとモバイルやEdge TPUでの利用がさらに面白くなると思う。

これからもウォッチしてみたい。