こんにちは!新しく記事を書かせていただくことになった,最上伸一です. 今まさに私が「学習」している最中の機械学習理論について,これから くわしくお話したいと思います.
機械学習やAI,最近何かと話題になっているから少し勉強してみたけど, 何が何やらさっぱり分からない…という人向けに, 機械学習のアレやコレを,なるべく分かりやすく解説していきます!
PRML とは
機械学習理論をまとめる上で軸にした本が,名著と名高い『パターン認識と機械学習』. 機械学習に興味があって色々調べていれば,ご存知の方も多いかと思います. 英語で ``Pattern Recognition and Machine Learning" だから,PRMLという通称で呼ばれています.
パターン認識と機械学習(上巻)
【機械学習の手順 ~第1章冒頭~】
そもそも,「機械学習」とは何でしょうか. また,機械学習を実行するということは,具体的にどのような手続きを指すのでしょうか.
私たち人間は,とても自然に学習を行っています.たとえば,ここに 手書き数字の「4」があったとしましょう. それが自分の書いた字ではなく,自分のクセとは多少違っていたとしても (たとえば,4の上をぴったり付けて書くか離して書くかなど), 私たちは「4」という数字だと認識することができます.
図1.1
同じことを機械が行うのは大変です. 4という数字の書き方のルールは,厳密に定まっていないからです. 無理にルールを教えようとしても,そのルールに当てはまらない「4」や, そのルールに当てはまってしまう「4でないもの」が沢山出てきてしまい, ルールと例外を果てしなく追加しなければなりません (これをルール数が発散すると表現しています).
人力で見つけ出された経験的な規則 (こういう発見的規則のことをヒューリスティックな規則といいます) では,機械にうまく伝わりません. そこで,機械が扱う規則性は,機械に発見してもらうことにします. これが機械学習です.
機械学習の目標は,ある入力ベクトルに対して,適切な出力を返せるようになること. 上の例では,「手書き文字」という入力ベクトルから,「数字」という出力を返すのが目標です. 何か入力が与えられたときに出力が定まるもののことを, 数学の世界では関数と言うのでした.まさに,
機械学習は多くの場合,「学ぶ」段階を経て, 「学んだことを活用する」段階に移ります. 学ぶ段階を訓練段階や学習段階といい, 学習に用いるデータを訓練データといいます. 学んだことを活用するときに用いるデータをテストデータといいます. 人間も,授業で学んだ(訓練した)ことを確かめるとき,テストを行いますよね. テストデータに用いられるものは,訓練データと全く同じものばかりではありません. 訓練と異なる新たな事例を分類する能力のことを,汎化といいます.
機械学習には大きく分けて3種類あります. 教師あり学習と教師なし学習, そしてその中間ぐらいの立場をとる強化学習です.
ここでいう「教師」とは,「正解」とか「お手本」ぐらいの意味でとらえると分かりやすいと思います. 先程の文字認識の例では,「これは4と認識してほしい」という正解があるため, 教師あり学習を行うのが妥当です.
【多項式を用いた回帰の話 ~1.1節~】
まずは代表的な機械学習の例として,回帰問題を考えましょう. 回帰とは,与えられた入力ベクトルから1つ以上の連続的な数値を予測することです. 機械学習的な考え方でいうと, 「予め与えられた入力ベクトル(訓練データ)から,目標値が従う規則性を見つけ出し, 目標値が分からない入力ベクトル(テストデータ)にその規則性に当てはめて,未知の目標値を予測する」 わけですから,教師あり学習を行うことになります.ここでいう教師は,目標値のことですね.
たとえば最高気温と湿度のデータから,あるアイスクリームの売り上げを予測したいとします. このとき入力ベクトルは「最高気温と湿度」,目標値はアイスクリームの売り上げです. そこから,たとえば気温と湿度が高ければアイスクリームが良く売れる,といった規則性を見つけ出し, およそどういう数式に従うのか予測するのが回帰です.
もう少し具体的に話をするために,人工的に作られた,次のデータを考えてみましょう.
図1.2
横軸は入力変数 $x$ であり,縦軸が予測したい目的変数 $t$.
緑の曲線は $y=\sin(2\pi x)$.
この青い点は,$y=\sin(2\pi x)$ にランダムなノイズを入れて作られています. もっとたくさんデータ点を取れば,青い点だけから $y=\sin(2\pi x)$ がうっすら見えてきますが, ここではたった10個の点から,$y=\sin(2\pi x)$ のような規則性をできるかぎり正確に予測することを目標にしましょう.
回帰問題で最も単純なのは線形回帰と呼ばれる,一次関数で近似する方法ですが, 今回予測したいデータは(青い点だけを見ても)ぐにゃぐにゃと折れ曲がっているように見えますから, ここは多項式曲線フィッティングを考えることにします.つまり,規則性を \begin{align} y(x,\bm{w})=w_0+w_1x+w_2x^2+\cdots+w_Mx^M \end{align} という形の式,$M$ 次関数で予想してみます. $\bm w$ というのはベクトル $(w_0,w_1,\ldots,w_M)$ のことです. 慣習的に,太字は基本的にベクトルを表します.
具体的にどのように予測するかですが,次の二乗誤差と呼ばれる値: \begin{align} E(\bm{w})=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2 \end{align} を最小化するように,うまく $\bm{w}$ の値を決めます. $y(x_n,\bm{w})-t_n$ というのは,真の値($t_n$)と予測値($y(x_n,\bm{w})$)とのズレ. 直観的には,「なるべくデータ点とのズレの総和が小さくなるような $\bm w$ を見つける」ということです. この $\bm w$ をどうやって見つけるかについては…今回は割愛します.
さて,ここで問題です. $M$ 次関数でフィッティングするとき,私たちはまず $M$ の値を決める必要があります. 次の3つのうち,$y=\sin(2\pi x)$ への当てはまりが最も良くなるのはどれでしょうか.
- 1次関数でフィッティング(つまり線形回帰)
- 3次関数でフィッティング
- 9次関数でフィッティング
皆さん,予想は出来ましたか?
百聞は一見に如かず,ということで, 実際に多項式をあてはめた結果を見てみましょう.
1次式で近似した場合
3次式で近似した場合
9次式で近似した場合
皆さんの予想はいかがでしたか?グラフで見てみると,一目瞭然ですね! 緑の曲線($y=\sin(2\pi x)$)を最もよく近似しているのは, 3つの選択肢の中では $M=3$ のとき,という結果になりました.
しかし,直感的には,「次数が高い方がよく近似できるだろう」と思いませんか. 調整できる多項式の係数 $w_i$ の数が増えるからです. また,「$M=9$ のグラフはたしかに緑の線からは外れているかもしれないが, 青いデータ点にはピッタリ誤差なく適合できている.それなら問題ないじゃないか」 と思うかもしれません.
しかし,先程述べた「汎化」という観点から考えると, 青い点(訓練データ)だけに特化して高い性能を発揮するのは,良いことではありません. 訓練データ以外のテストデータが来た時に,誤差が大きくなってしまうためです. 実際,次数ごとに多項式フィッティングを行うと,誤差は次のようになります.
次数ごとの訓練データ・テストデータの誤差
しかし,テストデータ(赤)の誤差は,途中までは誤差が減るものの,$M=9$のときは誤差が急激に大きくなる.
誤差は平均二乗平方根誤差 $E_{\text{RMS}}=\sqrt{2E(\bm w^*)/N}$.
$M$ が大きくなると自由度が高くなるため, 多少複雑なモデルでも柔軟に処理できます. ところが,自由度が高すぎると訓練データのノイズに引きずられ, 訓練データに合わせようと,係数 $w_i$ の値が非常に大きな値を取るようになります. その結果,テストデータにはとても合致しないような「変な」曲線を出すことになります. このような現象のことを過学習といいます. 機械学習において,過学習は避けるべき問題です.どうしたら避けられるのでしょうか.
過学習を避けるには
過学習を避ける最も単純な方法は,テストデータを増やすことです. 先程示した 9次式(自由度10) というのは,テスト点が10個しかなかったため, 相対的に見るとかなり多い自由度でした. しかし,テスト点を100個に増やせば,相対的に見れば 9次 というのは少ない自由度です. そのため,過学習のリスクを減らすことができます.
テスト点が100個のときの多項式フィッティング
もちろんテストデータがいつでも潤沢にあるとは限りませんから, その他にも過学習を避ける方法としてベイズ的なアプローチを取る方法, 正則化項を付ける方法などが考えられています. ベイズ的な方法については後に詳しくお話しするとして, 今回は正則化について少しだけ紹介しましょう. (とはいっても,ベイズ的方法も,正則化と大きくつながっていますが…)
多項式曲線フィッティングでは,二乗誤差 \begin{align} E(\bm{w})=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2 \end{align} を最小化するんでした. $E$ は小さければ小さいほど好ましい推定ということです.
ところで,今までの議論から他にも「大きすぎると好ましくないもの」があります. それは,多項式の係数 $w_i$.係数が大きすぎると, 値の上下の激しい,複雑なモデルになってしまうんでした.
そこで二乗誤差と多項式の係数を抑制するために, \begin{align} \tilde{E}(\bm{w})&=\dfrac{1}{2}\sum_{n=1}^N\bigl(y(x_n,\bm{w})-t_n\bigr)^2+\dfrac{\lambda}{2}\|\bm{w}\|^2\\ \text{ただし,}\quad\|\bm w\|^2&={w_0}^2+{w_1}^2+\cdots+{w_M}^2 \end{align} という値を最小化することを考えます. こうすると,たとえ次数が大きくとも $\|\bm w\|^2$ の値が大きくなりすぎないため, 比較的モデルの形が単純になります. 式に登場する $\lambda$ は重みづけパラメータで, 「二乗誤差を小さくする」ことと「データを単純化する」ことの, どちらをどの程度優先するか,その程度を表しています. $\lambda$が大きければ大きいほど,$\|\bm{w}\|^2$ を小さくすることが優先されます.
正則化を取り入れて訓練データから多項式フィッティングを行った結果は,次の通り.
正則化を取り入れたときの9次式フィッティング
完璧にフィットしているわけではないものの,過学習が大幅に改善されていますね! ただし,ここでも $\lambda$ の値の選び方が重要になります. 小さすぎると過学習は改善されませんし,大きすぎてもモデルが単純化されすぎてうまくいきません. $\lambda$ の値の調整方法は,また別の機会にお話ししましょう.
今回の話は以上です. 機械学習一番目の記事にしては,かなり内容の詰まった話だったため, 説明を割愛した部分もたくさんあります. 詳しいことはまた後日記事にしようと思っていますから,是非ご期待ください!
名著と名高い一方,「難しい」との声も多い.