2017年3月16日 更新

【機械学習勉強会】パターン認識と機械学習(PRML)第1章のまとめ Part.1

機械学習理論の名著である『パターン認識と機械学習』(Pattern Recognition and Machine Learning, PRML)の内容をまとめます.

22,296 view お気に入り 1

皆さんの予想はいかがでしたか?グラフで見てみると,一目瞭然ですね! 緑の曲線($y=\sin(2\pi x)$)を最もよく近似しているのは, 3つの選択肢の中では $M=3$ のとき,という結果になりました.

しかし,直感的には,「次数が高い方がよく近似できるだろう」と思いませんか. 調整できる多項式の係数 $w_i$ の数が増えるからです. また,「$M=9$ のグラフはたしかに緑の線からは外れているかもしれないが, 青いデータ点にはピッタリ誤差なく適合できている.それなら問題ないじゃないか」 と思うかもしれません.

しかし,先程述べた「汎化」という観点から考えると, 青い点(訓練データ)だけに特化して高い性能を発揮するのは,良いことではありません. 訓練データ以外のテストデータが来た時に,誤差が大きくなってしまうためです. 実際,次数ごとに多項式フィッティングを行うと,誤差は次のようになります.

次数ごとの訓練データ・テストデータの誤差

次数ごとの訓練データ・テストデータの誤差

$M$ が0から大きくなると,訓練データ(青)の誤差は単調に減っていく.
しかし,テストデータ(赤)の誤差は,途中までは誤差が減るものの,$M=9$のときは誤差が急激に大きくなる.
誤差は平均二乗平方根誤差 $E_{\text{RMS}}=\sqrt{2E(\bm w^*)/N}$.

$M$ が大きくなると自由度が高くなるため, 多少複雑なモデルでも柔軟に処理できます. ところが,自由度が高すぎると訓練データのノイズに引きずられ, 訓練データに合わせようと,係数 $w_i$ の値が非常に大きな値を取るようになります. その結果,テストデータにはとても合致しないような「変な」曲線を出すことになります. このような現象のことを過学習といいます. 機械学習において,過学習は避けるべき問題です.どうしたら避けられるのでしょうか.

過学習を避けるには

過学習を避ける最も単純な方法は,テストデータを増やすことです. 先程示した 9次式(自由度10) というのは,テスト点が10個しかなかったため, 相対的に見るとかなり多い自由度でした. しかし,テスト点を100個に増やせば,相対的に見れば 9次 というのは少ない自由度です. そのため,過学習のリスクを減らすことができます.

テスト点が100個のときの多項式フィッティング

テスト点が100個のときの多項式フィッティング

データ点が10個の時に比べ,フィッティング性能が高くなっている様子が分かる(過学習の問題が改善されている).

もちろんテストデータがいつでも潤沢にあるとは限りませんから, その他にも過学習を避ける方法としてベイズ的なアプローチを取る方法, 正則化項を付ける方法などが考えられています. ベイズ的な方法については後に詳しくお話しするとして, 今回は正則化について少しだけ紹介しましょう. (とはいっても,ベイズ的方法も,正則化と大きくつながっていますが…)

多項式曲線フィッティングでは,二乗誤差 \begin{align} E(\bm{w})=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2 \end{align} を最小化するんでした. $E$ は小さければ小さいほど好ましい推定ということです.

ところで,今までの議論から他にも「大きすぎると好ましくないもの」があります. それは,多項式の係数 $w_i$.係数が大きすぎると, 値の上下の激しい,複雑なモデルになってしまうんでした.

そこで二乗誤差と多項式の係数を抑制するために, \begin{align} \tilde{E}(\bm{w})&=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2+\dfrac{\lambda}{2}\|\bm{w}\|^2\\ \text{ただし,}\quad\|\bm w\|^2&={w_0}^2+{w_1}^2+\cdots+{w_M}^2 \end{align} という値を最小化することを考えます. こうすると,たとえ次数が大きくとも $\|\bm w\|^2$ の値が大きくなりすぎないため, 比較的モデルの形が単純になります. 式に登場する $\lambda$ は重みづけパラメータで, 「二乗誤差を小さくする」ことと「データを単純化する」ことの, どちらをどの程度優先するか,その程度を表しています. $\lambda$が大きければ大きいほど,$\|\bm{w}\|^2$ を小さくすることが優先されます.

正則化を取り入れて訓練データから多項式フィッティングを行った結果は,次の通り.

正則化を取り入れたときの9次式フィッティング

正則化を取り入れたときの9次式フィッティング

上のグラフは $\lambda=e^{-18}$ としたときのもの.$\lambda$ を変えると結果も大きく変わる.

完璧にフィットしているわけではないものの,過学習が大幅に改善されていますね! ただし,ここでも $\lambda$ の値の選び方が重要になります. 小さすぎると過学習は改善されませんし,大きすぎてもモデルが単純化されすぎてうまくいきません. $\lambda$ の値の調整方法は,また別の機会にお話ししましょう.


今回の話は以上です. 機械学習一番目の記事にしては,かなり内容の詰まった話だったため, 説明を割愛した部分もたくさんあります. 詳しいことはまた後日記事にしようと思っていますから,是非ご期待ください!

(Part2. へ続く…)
26 件

関連する記事 こんな記事も人気です♪

【機械学習勉強会】パターン認識と機械学習(PRML)第1章のまとめ Part.2 ~モデル選択・次元の呪い~

【機械学習勉強会】パターン認識と機械学習(PRML)第1章のまとめ Part.2 ~モデル選択・次元の呪い~

機械学習理論の名著である『パターン認識と機械学習』(Pattern Recognition and Machine Learning, PRML)の内容をまとめます. 第1章の Part.2 では,「モデル選択」「次元の呪い」について説明します.
最上 伸一 | 10,285 view
【機械学習勉強会】パターン認識と機械学習(PRML)第1章のまとめ Part.3 ~決定理論~

【機械学習勉強会】パターン認識と機械学習(PRML)第1章のまとめ Part.3 ~決定理論~

機械学習理論の名著である『パターン認識と機械学習』(Pattern Recognition and Machine Learning, PRML)の内容をまとめます. 第1章の Part.3 では,「決定理論」について説明します.
最上 伸一 | 9,982 view
KaggleチュートリアルTitanicで上位3%以内に入るには。(0.82297)

KaggleチュートリアルTitanicで上位3%以内に入るには。(0.82297)

まだ機械学習の勉強を初めて4ヶ月ですが、色々やってみた結果、約7000人のうち200位ぐらいの0.82297という記録を出せたので、色々振り返りながら書いていきます。
Takumi Ihara | 185,801 view
LSTMとは〜概要と応用について〜

LSTMとは〜概要と応用について〜

音声信号処理や文章・対話の生成に用いられているLSTM(Long Short Term Memory)についてまとめました。
pythonによるtensorflow〜deepdreamによる画像変換〜

pythonによるtensorflow〜deepdreamによる画像変換〜

今回は前回のtensorflowの記事に引き続き、deepdreamによる画像変換についてご紹介します。

この記事のキーワード

この記事のキュレーター

最上 伸一 最上 伸一