損失関数の平坦性の種々の定義について

以前,損失関数の平坦性の定義について考えたことを雑記としてメモとして残す.

損失地形の平坦性

損失関数  \mathcal{L}(\theta) はモデルのパラメータ  \theta \in \mathbb{R}^d に対して定義されるスカラー値関数である.損失地形とは,パラメータ空間における幾何学的形状を指す.次のような最適化問題を解く:

\begin{align*}
\min_{\theta \in \mathbb{R}^d} \mathcal{L}(\theta) = \frac{1}{n}\sum_{i=1}^n l(f(x_i; \theta), y_i)
\end{align*}

ただし,  l(f, y) は二乗誤差関数やクロスエントロピー関数による損失であり,  f はモデルの出力を表す.

損失地形に関して重要な観点がいくつかあるが,今回は特に「平坦性」に注目する.

平坦性とは直感的に言えば,パラメータの摂動  \delta に対して損失関数が鋭敏に反応しない状態を指す.これを数式的に定式化すれば,次の差分

 \Delta \mathcal{L} = \mathcal{L}(\theta + \delta) - \mathcal{L}(\theta)

が十分に小さい場合,平坦であると言える.

また,損失関数のヘッセ行列  H(\theta)固有値が小さい場合にも平坦性があると見なされる.

これらが等しいことは自分には非自明だったので,理解のために式を書き下した.

 

以下数式

テイラー展開を用いて損失関数  \mathcal{L}(\theta) を展開する:

\begin{align*}
\mathcal{L}(\theta + \delta) &= \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^T \delta + \frac{1}{2}\delta^T H(\theta) \delta + O(|\delta|^3) \\ &\approx \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^T \delta + \frac{1}{2}\delta^T H(\theta) \delta
\end{align*}

ここで  \delta が十分小さいため,三次以上の項を無視した.

最適化により局所解  \theta において  \nabla_\theta \mathcal{L}(\theta) = 0 となるため,

\begin{align*}
\mathcal{L}(\theta + \delta) - \mathcal{L}(\theta) \approx \frac{1}{2}\delta^T H(\theta) \delta \end{align*}

が得られる.

 H(\theta) が対称行列であるため,固有値分解可能である.  Q固有ベクトルを列に持つ直交行列,  \Lambda = \text{diag}(\lambda_1, \lambda_2, \dots, \lambda_d)固有値の対角行列とすると,

 H(\theta) = Q \Lambda Q^T

と分解できる.このとき,  \delta = Qz と基底展開すれば,損失の変化は次のようになる:

\begin{align*}
    \Delta L&=\dfrac{1}{2}\delta^TH(\theta)\delta\\
    &=\dfrac{1}{2}z^TQ^TQ\Lambda Q^T Qz\\
    &=\dfrac{1}{2}z^T\Lambda z\\
    &=\dfrac{1}{2}\sum_{i}\lambda_i z_i^2
\end{align*}

ここで直交行列の性質  Q^T Q = I を利用した.  z_i^2固有ベクトルを基底としたときの  \deltaユークリッドノルムであり,損失の変化は固有値  \lambda_i とその方向の摂動の大きさ  z_i^2 に依存する.