皆さんの予想はいかがでしたか?グラフで見てみると,一目瞭然ですね! 緑の曲線($y=\sin(2\pi x)$)を最もよく近似しているのは, 3つの選択肢の中では $M=3$ のとき,という結果になりました.
しかし,直感的には,「次数が高い方がよく近似できるだろう」と思いませんか. 調整できる多項式の係数 $w_i$ の数が増えるからです. また,「$M=9$ のグラフはたしかに緑の線からは外れているかもしれないが, 青いデータ点にはピッタリ誤差なく適合できている.それなら問題ないじゃないか」 と思うかもしれません.
しかし,先程述べた「汎化」という観点から考えると, 青い点(訓練データ)だけに特化して高い性能を発揮するのは,良いことではありません. 訓練データ以外のテストデータが来た時に,誤差が大きくなってしまうためです. 実際,次数ごとに多項式フィッティングを行うと,誤差は次のようになります.
$M$ が大きくなると自由度が高くなるため, 多少複雑なモデルでも柔軟に処理できます. ところが,自由度が高すぎると訓練データのノイズに引きずられ, 訓練データに合わせようと,係数 $w_i$ の値が非常に大きな値を取るようになります. その結果,テストデータにはとても合致しないような「変な」曲線を出すことになります. このような現象のことを過学習といいます. 機械学習において,過学習は避けるべき問題です.どうしたら避けられるのでしょうか.
過学習を避けるには
過学習を避ける最も単純な方法は,テストデータを増やすことです. 先程示した 9次式(自由度10) というのは,テスト点が10個しかなかったため, 相対的に見るとかなり多い自由度でした. しかし,テスト点を100個に増やせば,相対的に見れば 9次 というのは少ない自由度です. そのため,過学習のリスクを減らすことができます.
もちろんテストデータがいつでも潤沢にあるとは限りませんから, その他にも過学習を避ける方法としてベイズ的なアプローチを取る方法, 正則化項を付ける方法などが考えられています. ベイズ的な方法については後に詳しくお話しするとして, 今回は正則化について少しだけ紹介しましょう. (とはいっても,ベイズ的方法も,正則化と大きくつながっていますが…)
多項式曲線フィッティングでは,二乗誤差 \begin{align} E(\bm{w})=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2 \end{align} を最小化するんでした. $E$ は小さければ小さいほど好ましい推定ということです.
ところで,今までの議論から他にも「大きすぎると好ましくないもの」があります. それは,多項式の係数 $w_i$.係数が大きすぎると, 値の上下の激しい,複雑なモデルになってしまうんでした.
そこで二乗誤差と多項式の係数を抑制するために, \begin{align} \tilde{E}(\bm{w})&=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2+\dfrac{\lambda}{2}\|\bm{w}\|^2\\ \text{ただし,}\quad\|\bm w\|^2&={w_0}^2+{w_1}^2+\cdots+{w_M}^2 \end{align} という値を最小化することを考えます. こうすると,たとえ次数が大きくとも $\|\bm w\|^2$ の値が大きくなりすぎないため, 比較的モデルの形が単純になります. 式に登場する $\lambda$ は重みづけパラメータで, 「二乗誤差を小さくする」ことと「データを単純化する」ことの, どちらをどの程度優先するか,その程度を表しています. $\lambda$が大きければ大きいほど,$\|\bm{w}\|^2$ を小さくすることが優先されます.
正則化を取り入れて訓練データから多項式フィッティングを行った結果は,次の通り.
完璧にフィットしているわけではないものの,過学習が大幅に改善されていますね! ただし,ここでも $\lambda$ の値の選び方が重要になります. 小さすぎると過学習は改善されませんし,大きすぎてもモデルが単純化されすぎてうまくいきません. $\lambda$ の値の調整方法は,また別の機会にお話ししましょう.
今回の話は以上です. 機械学習一番目の記事にしては,かなり内容の詰まった話だったため, 説明を割愛した部分もたくさんあります. 詳しいことはまた後日記事にしようと思っていますから,是非ご期待ください!
しかし,テストデータ(赤)の誤差は,途中までは誤差が減るものの,$M=9$のときは誤差が急激に大きくなる.
誤差は平均二乗平方根誤差 $E_{\text{RMS}}=\sqrt{2E(\bm w^*)/N}$.