PRML 7. Sparse Kernel Machines (Classification with SVM)

パターン認識と機械学習 上

パターン認識と機械学習 上

概要

  • 分類の境界線から一番違い要素との距離を最大化するように最適化を行う。
  • 2クラスの分類を扱うが、多クラスへの応用はストレートにできる。
  • 最大化問題をラグランジュ未定乗数法を用いて双対問題を作り、双対問題を解く。
  • 予測時に一部の学習データのみを用いる。それらをサポートベクターと呼ぶ。
  • 式変形は単純なので、メモなし。

図7.2の再現

  • Gaussianカーネルを用いる
  • パラメータ sigmaを決める必要があるが、最初から固定されていると仮定する。
  • sigma=0.2のときの結果は下図の通り。 f:id:nsb248:20170130220941p:plain

sigmaによる結果への影響について。

  • 今回、Gaussianカーネルを使っているので、sigmaは各学習用データがその周りに与える影響の大きさを表している。
  • 小さすぎるとサブセットの周りを囲むような境界線になってしまい、汎用性がさがる。 f:id:nsb248:20170130221741p:plain
  • 逆に大きくするとデータから遠いところが大雑把になりそう。 f:id:nsb248:20170130221922p:plain

外れ値

  • ナイーブに最大化問題を解くと、外れ値に対応できない。
  • アルゴリズム上、すべての学習データを正しく分類する。汎用性が低くなる可能性がある。 f:id:nsb248:20170130222145p:plain
  • そこで、スラック変数を導入して、誤って分類させることを許し、そのかわりにペナルティを与えるようにする。 f:id:nsb248:20170130222239p:plain
  • スラック変数導入前は、左下の青い一点を囲むような境界ができていたが、導入後は左下の青い点は間違った分類(赤)になったままになった。

Python

  • Python初心者なので、まだまだ汚いです
  • SVN自体は既存モジュールに実装されているが、勉強のため一部実装。
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import math


class DualRepresentation:
    def __init__(self, x, t, sigma, c):
        self.x = x
        self.t = t.flatten()
        self.sigma = sigma
        self.a = None
        self.b = 0
        self.c = c

    def kernel(self, xn, xm):
        return math.exp(- 0.5 * np.linalg.norm(xn - xm) ** 2 / (self.sigma ** 2))

    def __call__(self, input):
        n = self.t.shape[0]
        a = np.array(input).flatten()
        k = np.array([self.kernel(xn, xm) for xn in self.x for xm in self.x]).reshape(n, n)
        v = np.sum(a)
        for i in range(n):
            for j in range(n):
                v -= 0.5 * a[i] * a[j] * self.t[i] * self.t[j] * k[i][j]
        return -v

    def set(self, a):
        n = self.t.shape[0]
        self.a = a
        idx = np.where(a != 0)
        ns = idx[0].shape[0]
        k = np.array([self.kernel(xn, xm) for xn in self.x for xm in self.x]).reshape(n, n)
        b = 0.0
        for i in idx[0]:
            b += self.t[i]
            for j in idx[0]:
                b -= a[j] * self.t[j] * k[i][j]
        self.b = b / ns

    def learn(self):
        n = self.t.shape[0]
        a0 = np.random.uniform(0, 1, n)
        cons = ({'type': 'ineq', 'fun': lambda a: a},
                {'type': 'eq', 'fun': lambda a: np.sum(self.t * a) - 1},
                {'type': 'ineq', 'fun': lambda a: self.c - a})
        res = minimize(self, a0, method='SLSQP', tol=1e-6, constraints=cons)
        a = np.array([x if x > 1e-12 else 0.0 for x in res.x])
        self.set(a)
        return a

    def predict(self, x):
        n = self.t.shape[0]
        k = np.array([self.kernel(x, xm) for xm in self.x])
        return np.sum(self.a * self.t * k) + self.b


def plot72(x, t, sigma, c=1e+16, do_save=None):
    n = data.shape[0]

    f = DualRepresentation(data[:, 0:2], t, sigma, c)
    a = f.learn()

    n_plot = 100
    axis_x0 = np.linspace(np.min(x[:, 0]), np.max(x[:, 0]), n_plot)
    axis_x1 = np.linspace(np.min(x[:, 1]), np.max(x[:, 1]), n_plot)
    boundary = np.array([f.predict(np.array([x0, x1])) for x0 in axis_x0 for x1 in axis_x1]).reshape(n_plot, n_plot)

    idx = np.where(a != 0.0)
    plt.clf()
    plt.scatter(x[idx, 0], x[idx, 1], c='w', marker='o', s=100)
    plt.scatter(x[:, 0], x[:, 1], c=t, marker='+', s=60)
    plt.contour(axis_x0, axis_x1, boundary.T, np.array([-1, 0, 1]))
    if do_save is None:
        plt.show()
    else:
        plt.savefig(do_save)


if __name__ == '__main__':
    data = np.genfromtxt('dataset/classification_7_2.csv', delimiter=',').astype(np.float32)
    plot72(x=data[:, 0:2], t=data[:, 2], sigma=0.1, do_save='img/f7.2_0.1.png')
    plot72(x=data[:, 0:2], t=data[:, 2], sigma=0.2, do_save='img/f7.2_0.2.png')
    plot72(x=data[:, 0:2], t=data[:, 2], sigma=0.5, do_save='img/f7.2_0.5.png')

    data = np.vstack([data, np.array([[0.15, 0.35, -1]])])
    plot72(x=data[:, 0:2], t=data[:, 2], sigma=0.2, do_save='img/f7.2_0.2_outlier.png')
    plot72(x=data[:, 0:2], t=data[:, 2], sigma=0.2, c=1, do_save='img/f7.2_0.2_outlier_slack.png')

パターン認識と機械学習 上

パターン認識と機械学習 上