ニートがプログラミングするブログ(はてな出張所)

ニートがプログラミングするブログです。今は主にコンピュータビジョンに関することをやっています。

LSGAN-AE (Least Squares Generative Adversarial Networks with AutoEncoder)

追記:コードを修正したので記事も修正しました。

追記2:他のデータセットで試したらうまくいかなかったので、もう少し修正します。

 

まだ研究中なので間違ってる部分があるかもしれません。
LSGAN(Least Squares Generative Adversarial Networks)とAE(AutoEncoder)を組み合わせることで、それなりにディープ(generatorだけで14層)なネットワークでも学習させることに成功しました。
ただし単純に組み合わせただけではうまくいかないので、いろいろと工夫が必要になります。
この記事ではその工夫について書いていきます。

・ネットワーク構造
LSGANで使うネットワークはgeneratorとdiscriminator、AEで使うネットワークはencoderとdecoderです。
このうちgeneratorとdecoderの重みを共有させることで、LSGANとAEをくっつけます。
その際、generator/decoderとdiscriminator/encoderは対称にして、学習が必要なパラメータ数を同じにします。
ただし単純に入力チャネル数と出力チャネル数を反対にした重みを使うだけではうまくいきません。
そこで演算の順番も対称にします。
具体的には、generator/decoderではupsample->convolution->bias->activationと進みますが、discriminator/encoderではactivation->bias->convolution->pooling、という順番で演算を行います。

・重みの初期化
重みWはHeの方法で初期化し、bは0で初期化します。

・アップサンプル
オリジナルのDCGANではupsampleの方法としてtransposed convolutionが使われていますが、ここではzero padding upsampleを使います。
(実際はどういう名前なのか知らないので、勝手にこういう名前にしました。)
通常のupsampleで縦横を2倍にする場合には次のようになります。
元のデータ
1 2 3
4 5 6
2倍にしたデータ
1 1 2 2 3 3
1 1 2 2 3 3
4 4 5 5 6 6
4 4 5 5 6 6
一方でzero padding upsampleは次のようになります。
元のデータ
1 2 3
4 5 6
2倍にしたデータ
1 0 2 0 3 0
0 0 0 0 0 0
4 0 5 0 6 0
0 0 0 0 0 0
poolingで例えると、上がaverage poolingで下がmax poolingみたいな感じです。
zero padding upsampleの実装方法は、全てが値が0のチャネルを追加して、Pixel Shufflerで並べ替えることで実現しています。

・プーリング
poolingはmax poolingを使います。

・畳み込み
convolutionはストライドが1のものを使います。
generatorにおいてupsampleした直後は5x5のカーネルを使い、それ以外では3x3カーネルを使います。

・活性化関数
generator/decoderの中間層ではmax(-1.0,x)を使い、出力層ではidentity(x)を使います。
discriminator/encoderの最初の層はidentity(x)を使い、それ以降はmax(0.0,x+0.5)を使います。
reluはmax(0.0,x)と表されるので、上の活性化関数はreluを左下にシフトしたもので、下の方はreluを左にシフトした活性化関数です。

・入力
画像は255.0で除算して[0,1]とします。
ノイズは[-1,1]の一様乱数です。

・損失関数
generatorとdecoderは重み共有しているので、損失関数は3個必要となります。

encoderの損失関数は次のようになります。
loss_enc = (Enc(Gen(z))-z)^2 + weight_decay
weight_decay = W^2*const(W^2)
ノイズzをgeneratorに通して生成した画像をencoderに通してもとのzとの二乗誤差を最小にします。
weight decayはWにしか適用せず、bには適用しません。
また、最終層のWにも適用しません。
weight decayを追加するときに、元の損失関数とのバランスをとるために、適当な値(0.01~0.00001)を掛けます。
この値をどのくらいにすればいいか分からなかったので、重み自身の二乗を定数として掛けることでバランスをとるようにしました。
定数なので、逆伝播の計算時には微分しません。
TensorFlowではtf.stop_gradient()で定数化します。

discriminatorの損失関数は次のようになります。
loss_dis = [(Dis(x)-1.0)^2;(Dis(x_fake)-(-1.0))^2;(Dis(x_dec)-(0.0))^2] + weight_decay
少し見づらいですが、説明していきます。
[x;y;z]はxとyとzを一つのミニバッチまとめにしていることを意味します。
別々に計算して後で足し合わせる、という方法ではうまくいきません。
また、x_fakeはあらかじめノイズzをgeneratorに通してnumpy arrayにしたものを再びネットワークに通しています。
x_decも同様にあらかじめ本物の画像xをencoderに通してdecoderに通したものをnumpy arrayにしています。
そしてxとx_fake、x_decをnp.concatenate()で一つにまとめてdeiscriminatorに入力します。
Dis(x)は1を出力するようにして、Dis(x_fake)とDis(x_dec)は-1を出力するようにします。
LSGANの論文では1と0でもいいように書いてありますが、そちらではうまくいきませんでした。
ちなみにdiscriminatorとgeneratorは対称になっているので、discriminatorの出力はgeneratorへの入力と同じ次元だけあります。
weight decayはencoderと同様です。

generatorの損失関数は次のようになります。
loss_gen = (Dis(Gen([z;Enc(x)]))-[0.0;Dis(x)])^2 + (Dec(Enc(x))-x)^2/const(mean( (Dec(Enc(x))-x)^2)*const(mean( (Dis(x)-0.0)^2)  )
右辺の左項は本物画像xをあらかじめencoderに通したものとノイズzを一つにまとめてgeneratorに通して画像化したものをdiscriminatorに通して、出力値が0と本物画像をdiscriminatorに通したときの値となるようにします。
右辺の右項は本物画像xの復元誤差を最小にしています。
左項とのバランスを取るために二乗誤差の平均の逆数を定数化したものと、本物画像をdiscriminatorに通したものの二乗平均を定数化したものを掛けています。

・最適化手法
学習率が0.0001のRMSPropを使います。

・ミニバッチ
ミニバッチに入れるデータ数と構成も重要です。
下の数字の整数倍となるように、ミニバッチを構成する必要があります。
encoderにはノイズを6とします。
disciminatorには本物画像2、偽物画像2、本物画像を復元したものを2とします。
generator/decoderには本物画像を6、ノイズ3、本物画像をencoderに通したものを3とします。

・結果
以前はsugyanさんのアイドル画像データセットを使っていましたが、もう配布をやめてしまったようなので使わないほうがいいと思い、これからはCelebAデータセットを使うようにしました。
CelebAデータセットは論文でも良く使われている、20万枚もの顔画像がある大規模データセットです。
今回はCelebAデータセットの中でも、画像サイズや位置がそろっているものを使いました。
そのままでは使いづらいので、各画像の中央150x150を切り取り64x64に縮小しました。
以下の画像は上から順番に、本物画像、偽物画像、本物画像のアナロジー、偽物画像のアナロジー、となっています。
本物画像のアナロジーは左端と右端が本物画像で、その隣が本物画像をencoderに通したものをdecoderに通したもので、真ん中が両端の画像のencoderの値を適当な比率で足し合わせた偽物画像となります。
偽物画像の中には微妙なものも混じっています。
一応100万ステップ程度学習させてますが、もっとやればもう少し良くなるかもしれません。

 本物画像

f:id:suzuichiblog:20170921134731j:plain

 偽物画像

f:id:suzuichiblog:20170921134756j:plain

 本物画像のアナロジー

f:id:suzuichiblog:20170921134810j:plain

 偽物画像のアナロジー

f:id:suzuichiblog:20170921134820j:plain


・感想
上で書いた全てのコツは経験則に基づいているので、理論的になぜそうなるのかは説明できません。
DCGANをさらにディープにしようと思い立って、いろいろ試しているうちに半年以上かかってしまいました。
なんか時間をかなり無駄にした感があります。
なぜかというと、この半年程度の間にWGAN-GPのようなディープでもきちんと学習できる方法が発表されてしまったからです。
まあ今更愚痴ってもしょうがないですが。
適当に書いたものですがソースコードgithubに上げておきました。

ただソースコードはpython2と古いtensorflowで書かれているので、python3や最新のtensorflowでは動かないかもしれません。