読者です 読者をやめる 読者になる 読者になる

ニートがプログラミングするブログ(はてな出張所)

ニートがプログラミングするブログです。今は主にコンピュータビジョンに関することをやっています。

自己符号化器を用いたDCGANの事前学習

注意:まだ研究中なのでこのページの内容は間違っていたりしてると思います。
2016年12月2日追記: やはり背景画像のようなバリエーションが豊富なものに対してはうまくいきませんでした。そこで今は別の方法を検討中です。

2016年12月16日追記:いろいろと試してみましたが背景画像にはどれもうまくいきませんでした。

今さらですがDCGANに手を出してみました。
ただ試すだけでは味気ないので少しアレンジしてみました。
データはsugyanさんのアイドル画像データセットを使いました。

1.前置き
昨年の今頃にDCGAN(Deep Convolutional Generative Adversarial Networks)が話題となっていました。
DCGANとは、一様乱数の100次元ベクトルからきれいな画像を生成するgeneratorを作る、というものです。
例えばDCGANを使えば、アニメ顔を生成したり、アイドルの顔を生成したりできます。
DCGANの詳しい内容はこちらのページをご覧ください。
ところで、こちらのページによれば、DCGANは学習が難しいらしいです。
うまいこと学習させるにはパラメータを調整したりすることが必要みたいです。

この記事では、自己符号化器(Autoencoder)を使ってDCGANのgeneratorとdiscriminatorを事前学習します。
さらに、DCGANのdiscriminatorの出力層を複数にします。
その結果、調整が必要なパラメータは学習率と畳み込み層のフィルタ数くらいで、batch normalizationがなくてもAdamのbeta1を調整しなくてもDCGANをうまいこと学習させられるようになりました。

2.DCGANと自己符号化器
10年ほど前では、DeepLearningで扱うような深いネットワークはうまいこと学習できませんでした。
しかし、事前学習を行うことで、そういったネットワークでもきちんと学習が進むということが分かりました。
そこで私は、学習が難しいDCGANでも事前学習をすればきちんと学習が進むのではないか、と考えました。
事前学習の方法の一つに自己符号化器があります。
自己符号化器とは、入力と出力が同じになるようにネットワークを学習させるというものです。
ここでは、画像を入力して、100次元ベクトルに圧縮して、元の画像に近い画像を出力する、という自己符号化器を考えます。
入力画像を100次元ベクトルに圧縮する部分をエンコーダ、100次元ベクトルから出力画像を生成する部分をデコーダと呼びます。
このデコーダの部分はDCGANのgeneratorと同じような働きをします。
そのためDCGANのgeneratorの初期値に、通常の方法でトレーニングした自己符号化器のデコーダのウェイトをセットします。
そして一様乱数ベクトルをデコーダに通した出力と本物の画像とを区別できるようにDCGANのdiscriminatorをあらかじめ学習させておけば事前学習が完了します。

3.実装上の注意点
これを実装していく上での注意点は以下の3点です。
1点目はdiscriminatorの出力層の数です。
2点目は活性化関数についてです。
3点目は学習率についてです。
DCGANのdiscriminatorの出力は畳み込み層の数だけ作ります。
これはDeepID2+という方法をぱくったものです。
以前書いた顔認識の記事に書いてあるので、詳しくはそちらをご覧ください。
出力層が一つの状態で、エンコーダやデコーダ、generatorの出力、discriminatorの出力層を除く箇所全てで活性化関数にreluを使うと白い画像しか生成されなくなります。
これを避けるために、discriminatorに複数の出力層を設けて誤差を複数個所から逆伝播させます。
ただし全ての箇所でreluを使わなくても、DeepID2+のように出力数を複数にせずに自己符号化器で事前学習すると、最初はうまくいきますが最終的には白い画像を出力します。
次に、活性化関数についてです。
上で述べたような箇所全てでeluやtanhを使うと、格子状やドット状のノイズが出てきてきれいな画像が生成されません。
ただし一部で使うだけならばノイズが出てこない場合もあります。

最後は学習率についてです。
自己符号化器は、学習率が大きすぎると平均画像しか出力しなくなるので、1e-4程度に小さく設定します。
DCGANのgeneratorは学習率が高すぎると格子状のノイズしか出力されなくなるので、1e-4程度に小さくします。
discriminatorでも学習率が高すぎると、学習が進んでくるとgeneratorがおかしな画像を生成したりするようになるので、学習率を1e-4程度に小さくしておきます。
以下に学習初期のgeneratorの出力画像の一例を示します。
右にあるように、ノイズが多いが顔に見えるようなものが出てきていれば学習はうまくいきます。
ただしバッチサイズが10の場合で、数万ステップ程度待たなければまともな画像が出てこないので、我慢強く待ちましょう。
左にあるような灰色の画像や、顔っぽいものすらない画像が出力された場合は、学習率を下げることが必要になります。

f:id:suzuichiblog:20161201061558j:plain



実装について詳しりたい方はgithubにあるソースコードをご覧ください。
ただし、TensorFlowにもpythonにもまだ不慣れなので、おそらくソースコードは非効率な書き方をしてると思います。
dcgan_with_ae.pyはdiscriminatorの出力層が一つのもの、dcgan_with_ae_multi_output.pyは出力層が複数のものです。

4.結果と感想
実験結果は次の通りです。
事前学習がないと顔すら現れません。
一方で事前学習を行ったほうはきちんと顔が出来ています。

f:id:suzuichiblog:20161201061615j:plain

DCGANでよくある二つのzの遷移画像

f:id:suzuichiblog:20161201062958j:plain



DCGANを使うのは何番煎じかわかりませんが、自己符号化器を使うDCGANの事前学習は二番煎じくらいにはなっているんじゃないかなと思います。
(論文を探してないので分かりませんが、きっとどこかの誰かが既にやっているでしょう。)
もともとDCGANのパラメータ調整をしたくなかったのに、結局学習率というパラメータを調整しなければならなくなってしまいました。
自己符号化器やgeneratorの学習率は高くしすぎるとすぐにおかしな画像が出てくるので調整しやすいです。
一方でdiscriminatorは学習がかなり進んだ後でないとおかしな画像を出力しないので調整が面倒です。

今回うまくいったのは、顔という変化のパターンが少ないものだったからかもしれません。
そのため今後は風景などのパターン数が多そうなものでも学習できるかを確認していきたいです。

5.おまけ
自己符号化器とDCGANの設計図(?)を載せておきます。
まずは自己符号化器です。
自己符号化器は入力画像を5回畳み込んで100次元ベクトルにして、5回逆畳み込みして画像を出力します。
逆畳み込み(deconvolution)は転置畳み込み(?)(transposed convolution)とも呼ばれてるみたいです。
逆畳み込みについては[1]がわかりやすいです。
活性化関数は、エンコーダとデコーダの最後の層ではtanhで、それ以外ではreluです。

f:id:suzuichiblog:20161201061627j:plain



次はDCGANのgeneratorと自己符号化器のデコーダとの違いについてです。
最初にも書きましたが、DCGANのgeneratorと自己符号化器のデコーダは同じような働きをします。
しかし、入力が一様乱数か否か、あるいは生成画像がきれいかどうかという点で異なっています。
generatorの活性化関数は、最後の層ではtanhで、それ以外ではreluです。

f:id:suzuichiblog:20161201061640j:plain



最後はDCGANのdiscriminatorです。
4回畳み込んで、畳み込むごとにシグモイド関数で[0,1]の値を出力させます。
DCGANのdiscriminatorを学習させるときは、オリジナル画像を入力すると1を、偽画像を入力すると0を出力するようにします。
このようにすることで、DCGANのgeneratorが出力する画像を元画像のようにできます。
discriminatorの活性化関数は、畳み込み層ではreluで、出力層ではシグモイド関数(入力にウェイトを掛けてバイアスを足したものをそのままtf.nn.sigmoid_cross_entropy_with_logitsに入れています)です。

f:id:suzuichiblog:20161201061653j:plain



[1]Dumoulin, Vincent, and Francesco Visin. "A guide to convolution arithmetic for deep learning." arXiv preprint arXiv:1603.07285 (2016).