機械学習

StyleGAN2-adaの概要と実装

Nvidia社が開発したADA(Adaptive Discriminator Augumentation)という技術をStlyeGAN2に組み込んだものがStyleGAN2-adaになります。

GAN系の問題点として、何千枚もの画像で構成されたデータセットの用意と過学習をする前に学習を停止する必要があります。

過学習してしまうと、識別器のフィードバックが意味をなさなくなり、生成される画像が悪化します。これをモード崩壊と呼んだりもします。

これは、データ量に対してネットワークを過剰に学習させた場合に起こる現象で、下記の画像の黒い点の後のように品質は劣化するばかりです。

これらの解決策がADA(Adaptive Discriminator Augumentation)と呼ばれる手法です。

この手法が優れている点として、既存のGANのアーキテクチャに手を加える必要がなく適用できます。

既に知っている方もいるかと思いますが、深層学習のほとんどの分野で過学習対策として、data augumentationと呼ばれる作業をします。

具体的には、学習データ量を増やすために画像に回転やノイズ、色の変換などをします。

これをすることで、画像を新たに収集することなく多様なデータセットで学習することができるのですが、GANはそうはいきません。

生成器はこれらのdata augumentationに従った画像を生成するからです。(生成する画像の色が現実的におかしかったら偽物とバレてしまいます)

モデルの過学習を防ぐために、データ拡張しつつ同時にこれらの変換が生成された画像に影響を与えないようにしています。

ADA(Adaptive Discriminator Augumentation)とは?

端的にいってしまうと、生成モデルの新しい学習方法です。ADAでは10分の1の学習画像量で強力な生成モデルを学習することができます。

これによって多くの画像を必要としないアプリケーションが作成可能になります。

加えて、上記で紹介したGANでもdata augumentationが使える手法です。

基本的に、識別器にはdata augumentationされた全ての画像を使用することになりますが、ランダムに発生する確率で適用します。

また、この画像を使って識別器の性能も評価します。

生成器については変換していない画像だけを生成するように学習します。

data augumentationを使用してGANを学習させる方法は各変換の発生率が80%以下の場合のみ有効であると論文で結論づけられました。

確率が高いほど、より多くの拡張処理がされ、多様なデータセットが得られます。

この方法で、データセット数が足りない問題はカバーしたものの、初期のデータセット数に応じて異なるタイミングで発生する過学習の問題が残っています。

そこで、セグメンテーションを適応的に行う方法を考えました。

理想的なaugumentationの確率を決定するハイパーパラメーターを持つ代わりに、学習中にaugumentationの強さを制御するようにしました。

0から始めて、訓練データと検証データの差に基づいて値を調整します。

この値は過学習が起きているかどうかを示すものです。

検証データは、ネットワークが学習していない同じ種類の画像を集めた別物に過ぎません。

つまり、識別器がまだ見たことの無い画像で構成されてればよいです。

検証データの利用法としては結果の質を測定し、ネットワークの発散度を定量化し、過学習がどの程度進んでいるかを定量化するために使用されます。

下記の画像ではFFHQデータセットにおける学習サイズに対するADAの結果を表示しています。

FID指標を使用していますが、時間の経過とともによくなっており、過学習の問題にも至っていないのがわかります。

FIDは基本的に生成された画像と実際の画像の分布の間の距離を測定するものです。

上記の図は、生成された画像の品質を測定しています。低ければ低いほど良い結果が得られます。

使用しているFFHQデータセットはflickrから取得した7000人の高品質な顔画像が含まれています。

このデータセットはGANのベンチマークとして作成されたものです。

実際に使用する画像数が桁違いに少なくても、StyleGAN2の結果と一致することに成功しています。

この図は、1000~14000のデータ数に対する結果をプロットしたものです。

もっと深く内容を理解したい方は論文を参照してください。

また、コードの実装についてもtensorflowとpytorchどちらも公開されているので得意な方で実装することができます。

pytorch: https://github.com/NVlabs/stylegan2-ada-pytorch

tensorflow: https://github.com/NVlabs/stylegan2-ada

論文: https://arxiv.org/abs/2006.06676

実装については次回やるかも