1.1 Example: Polynomial Curve Fitting
์ ์ผ ๊ฐ๋จํ ํ๊ท๋ฌธ์ ๋ฅผ ์์๋ก ๋ ๋ค.
์ค์ ์ ๋ ฅ๋ณ์ ๋ก ์ค์ ํ๊ฒ๋ณ์ ๋ฅผ ์์ธกํ๋ ๋ฌธ์ ๋ค.
ํ๋ จ ๋ฐ์ดํฐ ์ดํด๋ณด๊ธฐ
๊ฐ๋ น 10๊ฐ์ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ๋ง๋ค์ด๋ณด๋๋ฐ, ์ ๋ ฅ๋ณ์ ๋ 0๊ณผ 1์ฌ์ด์ ์ค์, ํ๊ฒ๋ณ์ ๋ ์์ ๊ฐ์ฐ์์ ๋ถํฌ์์ ์ํ๋งํ ์์ ๋ ธ์ด์ฆ(Noise) ๋ฅผ ์ค์ ์ฝ๊ฐ์ ๋ณํ์ ๊ฐํ๋ค.
์ฐ๋ฆฌ๊ฐ ์๊ณ ์ถ์ดํ๋ ํจ์๋ ์ธ์์ ์ง๋ฆฌ ํน์ ์๋ฆฌ๋ผ๊ณ ์๊ฐํ ์ ์๋ค. ํ์ง๋ง ์ค์ ์ธ์์ ๊ฐ๋ณ ๋ฐ์ดํฐ๋ ์ฐ๋ฆฌ๊ฐ ์ ์ ์๋ ์ด๋ค ์์๋ค๋ก ์ธํด์ ์ง์ค์ ์ ์ ์๊ฒ ๋์ด์๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค. ์ด๋ฅผ ๋ ธ์ด์ฆ๊ฐ ๋ค์ด๊ฐ ๋ฐ์ดํฐ๋ก ํํํ ๊ฒ์ด๋ค. ์ฆ, ์ฝ๊ฒ ๋งํ๋ฉด ์ ๋ ฅ๋ฐ์ดํฐ ๋ ํจ์๋ฅผ ํต๊ณผํด์ ์ ๋ต์ด ๊ฐ ๋์์ผํ๋๋ฐ, ๊ด์ธก๋๋ ๋ฐ์ดํฐ๋ ํญ์ ์ด๋ค ์์์ ์ํด์ ์กฐ๊ธ์ฉ ๋ฌ๋ผ์ ธ์ ๊ด์ธก๋๋ค๋ ๋ง์ด๋ค.
์ด์ ์ฐ๋ฆฌ์ ๋ชฉํ๋ ๊ด์ธก๋ ํ๋ จ๋ฐ์ดํฐ๋ค์ ์ฌ์ฉํด์ ์๋ก์ด ์ ๋ ฅ๋ณ์ ๊ฐ ๋ค์ด์์ ๋ ํ๊ฒ๋ณ์ ๋ฅผ ์์ธกํ๋ ๊ฒ์ด๋ค.
import numpy as np
import matplotlib.pylab as plt# making data
seed = 62
np.random.seed(seed)
N = 10
x = np.random.rand(N)
t = np.sin(2*np.pi*x) + np.random.randn(N) * 0.1
x_sin = np.linspace(0, 1)
t_sin = np.sin(2*np.pi*x_sin)
plt.plot(x_sin, t_sin, c='green')
plt.scatter(x, t)
plt.xlabel('x', fontsize=16)
plt.ylabel('t', rotation=0, fontsize=16)
plt.show()ํ๋ฅ ์ด๋ก (Probability theory) ์ ๋ถํ์ค์ฑ์ ์ ํํ๊ณ ์์ ์ธ ๋ฐฉ์์ผ๋ก ์ธก์ ํ ์ ์๋ ํ๋์ ํ๋ ์์ํฌ๋ค. ๊ฒฐ์ ์ด๋ก (Decision theory) ์ ํ๋ฅ ๋ก ํํ ๊ฒ๋ค์ ์์ธก(๊ฒฐ์ ) ํ ๋, ์ ์ ํ ์ฒ๋๋ฅผ ๊ฐ์ง๊ณ ์ด๋ค์ ํฉ๋ฆฌ์ ์ผ๋ก ์ต์ ํ ํ๋ ์ด๋ก ์ด๋ค.
๋คํญ ํจ์
๋ถํ์ค์ฑ์ ํด์ํ๊ธฐ์ํด ์ ๋ ๊ฐ์ง ์ด๋ก ์ ์ฌ์ฉํ ์ ์์ง๋ง, ์ฌ๊ธฐ์๋ ์ฐ์ ๋คํญ ํจ์(polynomial fucntion) ๋ฅผ ํตํด ์ ๊ทผํด๋ณผ ์ ์๋ค. ์ ๋คํญ ํจ์์ ์ฐจ์(degree) ๋ผ๊ณ ํ๋ฉฐ, ๊ทธ ์์์ ์ต๊ณ ์ ์ฐจ์๋ฅผ ๊ฐ๋ฅดํจ๋ค. ์ ์ดํด๋ณด๋ฉด, ๋คํญ ํจ์()์ ๊ณ์ ์ ์ฐ๊ด๋ ์ ํ ํ๊ท ์ด๋ค.
๋ค์ ์ ๊น ์ ๋ฆฌํด์, ์ง๊ธ ํ๋ ๊ฒ์ ์ธ์์ ์ง๋ฆฌ() ๋ฅผ ๋ชจ๋ฅธ๋ค๊ณ ์๊ฐํ๊ณ ๋คํญ ํจ์๋ฅผ ํตํด์ ์ด๊ฒ์ด ๊ด์ธก๋ ๋ฐ์ดํฐ์ ์ง๋ฆฌ๊ฐ ์๋๊น ํ๊ณ ์์ธกํด๋ณด๋ ๊ฒ์ด๋ผ๊ณ ํ ์ ์๋ค. ์ด ์ ํ ํ๊ท์ ๊ณ์ ๋ ๊ด์ธก๋ ํ๋ จ๋ฐ์ดํฐ๋ก ๋ถํฐ ๋์ถํ ๊ฒ์ด๋ค. ๊ทธ๋ฌ๋ฉด ์ด๋ป๊ฒ ๋์ถํ ๊ฒ์ธ๊ฐ?
์ฐ๋ฆฌ๋ ๋คํญ ํจ์๋ฅผ ํตํด ์์ธก๋ ํ๊ฒ๊ณผ ์ค์ ํ๊ฒ๋ณ์์ ์ฐจ์ด๋ฅผ ๊ตฌํด, ์ผ๋งํผ ํ๋ ธ๋์ง(misfit)๋ฅผ ์ธก์ ํด๋ณผ ์ ์๋ค. ์ด๋ฅผ ๋ชฉ์ ํจ์(object function) / ์์ค ํจ์(error/loss function) ๋ผ๊ณ ํ๋ฉฐ, ์ด ์์คํจ์๋ฅผ ์ค์์ผ๋ก์จ ๊ณ์๋ฅผ ๊ตฌํ๋ ๋ชฉ์ ์ ๋ฌ์ฑํ ์ ์๋ค.
์ฌ๊ธฐ์๋ ๋ณดํต ๋ง์ด ์ฐ์ด๋ ๋ชฉ์ ํจ์๋ก MSE(Mean Square Error) ๋ฅผ ์ฌ์ฉํ๋ค.
def error_function(pred, target):
"""MSE function"""
return (1/2)*((pred-target)**2).sum()Python Code Solution for Polynomial
์ฐ์ ํฌ๊ธฐ์ ๋ฐฉ๋ฐ๋ฅด๋ชฝ๋ ํ๋ ฌ(Vandermode matrix) ๋ฅผ ์ ์ํ๊ณ ์ด๋ฅผ ๋ผ๊ณ ํ๋ค. ์์์๋ ๋งํ๋ฏ์ด ์ ๋คํญ ํจ์์ ์ฐจ์(degree) ๋ค.
def vandermonde_matrix(x, m):
"""vandermonde matrix"""
return np.array([x**i for i in range(m+1)]).T3 ์ฐจ ๋คํญ ํจ์์ ๋ฐฉ๋ฐ๋ฅด๋ชฝ๋ ํ๋ ฌ์ ์ดํด๋ณด์.
M = 3
V = vandermonde_matrix(x, M)
print(V.round(3))
# -----print result-----
# [[1. 0.034 0.001 0. ]
# [1. 0.489 0.239 0.117]
# [1. 0.846 0.716 0.606]
# [1. 0.411 0.169 0.07 ]
# [1. 0.631 0.399 0.252]
# [1. 0.291 0.085 0.025]
# [1. 0.543 0.295 0.16 ]
# [1. 0.228 0.052 0.012]
# [1. 0.24 0.058 0.014]
# [1. 0.953 0.909 0.867]]์ด์ ํ๋ ฌ๋ก ๋คํญํจ์ ์ (1) ์ ํํํ ์ ์๊ฒ ๋๋๋ฐ, ์๋์ ๊ฐ๋ค.
def polynomial_function(x, w, m):
assert w.size == m+1, "coefficients number must same as M+1"
V = vandermonde_matrix(x, m) # shape (x.size, M+1)
return np.dot(V, w)์์์ ๊ณ์๋ฅผ ์ด๊ธฐํ ์ํค๊ณ ๋คํญ ํจ์ ๊ฐ์ ์ดํด๋ณธ๋ค.
np.random.seed(seed)
w = np.random.randn(M+1)
t_hat = polynomial_function(x, w, M)
print(t_hat.round(3))
# -----print result-----
# [-0.03 -0.208 0.016 -0.2 -0.177 -0.162 -0.204 -0.134 -0.14 0.197]๊ทธ๋ฆฌ๊ณ ์์์ ์ ์ํ ์์ค ํจ์๋ฅผ ๋ค์ ํ๋ ฌ์ ๋ง๊ฒ ๋ฐ๊ฟ๋ณด๊ณ , ์กฐ๊ธ ๋ ๊ฐํธํ๊ฒ ํ๊ธฐ ์ํด์ ์์ฐจ(residual) ๋ฅผ ์ ์ํด์ ๋ค์ ๋ฐ๊ฟ๋ณธ๋ค.
์ฐ๋ฆฌ ๋ชฉ์ ์ ์์ค ํจ์์ ์ต๋ํ ์ค์ฌ์, ์ฆ ์ต์๊ฐ์ ๊ตฌํด์ ๊ณ์๋ฅผ ๊ตฌํ ๊ฒ์ด๋ค. . ๋ํ, ์์ค ํจ์๋ 2์ฐจ ํจ์์ด๊ธฐ ๋๋ฌธ์ 1์ฐจ ๋ฏธ๋ถ์ด 0์ผ ๋, ์ ์ผํ ํด๊ฐ ์กด์ฌํ๋ค. ๋ฐ๋ผ์ ๋ฏธ๋ถ์ ์ฐ์๋ฒ์น(chain rule)์ผ๋ก ์๋ ์ฒ๋ผ ๋ฏธ๋ถ์ ์งํ ํ ์ ์๋ค.
(3) ๋ฒ์ ์์์ ์์ชฝ์ ํ๋ ฌ์ ๋ฏธ๋ถํด ๋ณด๋ฉด ๋ฐฉ๋ฐ๋ฅด๋ชฝ๋์ ์ ์น ํ๋ ฌ์์ ์ ์ ์๋ค.
์ต์ข ์ ์ผ๋ก ํด๋ฅผ ๊ตฌํ ์ ์๋๋ฐ, ์๋์ ๊ฐ๋ค.
def poly_solution(x, t, m):
V = vandermonde_matrix(x, m)
return np.linalg.inv(np.dot(V.T, V)).dot(V.T).dot(t)์ด์ ๊ณ์๋ฅผ ๊ตฌํด๋ณธ๋ค.
print(f"Solution of coefficients are {poly_solution(x, t, M).round(3)}")
# -----print result-----
# Solution of coefficients are [ -0.245 11.722 -33.194 21.798]์ฌ์ค numpy ์์๋ ๋ ๊ฐํธํ ๊ธฐ๋ฅ์ ์ ๊ณตํ๊ณ ์๋ค.
from numpy.polynomial import polynomial as P
print(P.polyfit(x, t, M).round(3))
# -----print result-----
# [ -0.245 11.722 -33.194 21.798]์ต์ ์ ์ฐจ์(degree) ์ฐพ๊ธฐ
์ต์ ์ ๊ณ์๋ฅผ ์ฐพ๋ ๋ฌธ์ ๋ ํด๊ฒฐ๋์์ผ๋, ์ด์ ์ฐ๋ฆฌ์๊ฒ ๋จ์ ๋ฌธ์ ๋ ์ต์ ์ ์ฐจ์๋ฅผ ์ฐพ๋ ๊ฒ์ด๋ค. ๋คํญ ํจ์์ ์ฐจ์๋ ์ฐ๋ฆฌ๊ฐ ๋ง์๋๋ก ์ ํ ์ ์๋ค. ํ์ง๋ง ์ง๋ฆฌ()์ ๊ฐ์ฅ ๊ฐ๊น๊ฒ ๋ง๋๋ ์ต์ ์ ๊ณ์๋ ๋ฌด์์ธ๊ฐ? ์ด๋ฅผ ์ฐพ๋ ๊ณผ์ ์ ๋ชจ๋ธ ๋น๊ต(model comparison) ํน์ ๋ชจ๋ธ ์ ํ(model selection) ์ด๋ผ๊ณ ํ๋ค. ๋ํ ์ฐจ์ ์ฒ๋ผ์ฌ๋์ด ์์์ ์ผ๋ก ์กฐ์ ํ ์ ์๋ ๋ณ์๋ฅผ ํ์ดํผํ๋ผ๋ฏธํฐ(hyperparameter) ๋ผ๊ณ ํ๋ค.
์ด์ ์ฅ(Introduction์์ ์ฐ๋ฆฌ๋ ์ผ๋ฐํ(generalization) ์ด ํจํด์ธ์์ ์ฃผ์ ๋ชฉ์ ์ด๋ผ๊ณ ํ๋ค. ์ข์ ์ผ๋ฐํ๋ ์ผ๋งํผ ์ง๋ฆฌ์ ๊ฐ๊น์ด ํํ๋ ฅ์ ๋ณด์ด๋๊ฐ๋ก ์ธก์ ํ ์ ์๋ค. ์ฆ, ์ฌ๊ธฐ์๋ ์๋ก์ด ๋ฐ์ดํฐ๊ฐ ๋ค์ด์์ ๋, ์ผ๋งํผ ์ ์ ๊ทผํ ๊ฐ? ๋ฅผ ๋ณด๋ฉด ๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์ด๋ฅผ ์ด๋ป๊ฒ ์ธก์ ํ ๊ฒ์ธ๊ฐ?
์ธก์ ์ ์ํด์ 100๊ฐ์ ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐ๋ก ์ํ๋งํด์ ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ๋ง๋ค์ด ํ ์คํธ ์ธํธ๋ก ๊ตฌ์ฑํ๋ค.
np.random.seed(seed)
N_test = 100
x_test = np.random.rand(N_test)
t_test = np.sin(2*np.pi*x_test) + np.random.randn(N_test) * 0.1
plt.plot(x_sin, t_sin, c='green')
plt.scatter(x_test, t_test, c='red')
plt.xlabel('x', fontsize=16)
plt.ylabel('t', rotation=0, fontsize=16)
plt.show()๊ทธ๋ฆฌ๊ณ ๋งค ๋ฒ ์ฐจ์()๋ฅผ ์ ํํ ๋ ๋ง๋ค, ํ๋ จ ์ธํธ์์ ์ต์ ํ๋ ๊ณ์๋ฅผ ๊ตฌํ๊ณ , ์ด ๊ณ์๋ฅผ ์ฌ์ฉํ์ฌ ์์ค ๊ฐ์ ์์ฐจ๋ฅผ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ์ ๊ฐ๊ฐ ์ ์ฉํด์ ๊ตฌํ๋ค. ๊ทธ ๋ฐฉ๋ฒ ์ค์ ํ๋๋ RMS error(root-mean-sqruare error) ๋ผ๋ ๋ฐฉ๋ฒ์ผ๋ก ๊ตฌํ๋๋ฐ, ์์ ์๋์ ๊ฐ๋ค.
def root_mean_square_error(error, n_samples):
return np.sqrt(2*error/n_samples)๋ ์ฐจ์์์ ์ต์ ์ ๊ณ์, ์ ๋ฐ์ดํฐ์ ๊ฐฏ์๋ค. ์ ๋๋ ์ค ์ด์ ๋ ๋น๊ต๊ฐ๋ฅ๋๋ก ํฌ๊ธฐ๊ฐ ๋ค๋ฅธ ๋ฐ์ดํฐ ์ ์ ๋๋ฑํ ํฌ๊ธฐ๋ก ์ค์ผ์ผ๋ง ํ ๊ฒ์ด๋ค. ์์คํจ์๊ฐ ์ ๊ณฑ์ ์ทจํ๊ธฐ ๋๋ฌธ์ ์์ธกํ ๋ณ์์ ํ๊ฒ๋ณ์์ ์ฐจ์ด๊ฐ ํด์๋ก ๊ฐ์ด ๋ ์ปค์ง๋ ํ์์ด ์๋๋ฐ, ๋ฃจํธ ์ฐ์ฐ์ ์ทจํด์ค์ผ๋ก์จ, ํ๊ฒ ๋ณ์์ ๊ฐ์ ํฌ๊ธฐ์ ์ค์ผ์ผ๋ก ๋ค์ ๋ง์ถฐ์ง๋ค.
์ค๋ฒํผํ
(over-fitting)
์ด์ ์ธก์ ํ ๋ฐฉ๋ฒ๊น์ง ์๊ฒผ์ผ๋ ์ต์ ์ ์ฐจ์๋ฅผ ๊ณจ๋ผ๋ณด์. ์ฐจ์๊ฐ 0 ๋ถํฐ 9๊น์ง ๋ฃจํ๋ฌธ์ ๋๋ฉด์ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ์ RMS error ๋ฅผ ์ธก์ ํด๋ณธ๋ค.
def get_rms_error(t_hat, t, n_sample, m):
error = error_function(t_hat, t)
rms = root_mean_square_error(error, n_sample)
return rms
all_w = []
all_rms_train = []
all_rms_test = []
for m in range(10):
optimal_w = poly_solution(x, t, m)
t_hat = polynomial_function(x, optimal_w, m)
t_hat_test = polynomial_function(x_test, optimal_w, m)
rms_train = get_rms_error(t_hat, t, N, m) # N=10
rms_test = get_rms_error(t_hat_test, t_test, N_test, m) # N_test = 100
print(f"M={m} | rms_train: {rms_train:.4f} rms_test: {rms_test:.4f}")
# Plot predicted line
plt.plot(x_sin, t_sin, c="green", label="sin function")
plt.plot(x_sin, polynomial_function(x_sin, optimal_w, m), c="red", label=f"model M={m}")
plt.scatter(x, t)
plt.xlim((0, 1))
plt.ylim((-1.25, 1.25))
plt.xlabel('x', fontsize=16)
plt.ylabel('t', rotation=0, fontsize=16)
plt.legend()
plt.show()
all_w.append(optimal_w)
all_rms_train.append(rms_train)
all_rms_test.append(rms_test)๊ทธ์ค ์ฐจ์๊ฐ 0, 1, 3, 9 ์ธ ๊ฒฝ์ฐ๋ฅผ ์ดํด๋ณธ๋ค.
# M=0 | rms_train: 0.6353 rms_test: 0.7221# M=1 | rms_train: 0.4227 rms_test: 0.4508# M=3 | rms_train: 0.0930 rms_test: 0.1238# M=9 | rms_train: 0.0872 rms_test: 19.2855์ฐ๋ฆฌ์ ์ง๋ฆฌ์ธ ์ ๊ฐ์ฅ ๊ฐ๊น์ด ๊ณก์ ์ ์ผ๋์ ๊ณก์ ์ด๋ค. ๋ค๋ฅธ ์ฐจ์์์๋ ์ข์ง ์์ ํํ๋ ฅ(์ง๋ฆฌ ํจ์์ ๋ชจ์์ด ๋น์ทํ์ง ์์)์ ๊ฐ์ง๊ณ ์๋๋ฐ, ํนํ ์ผ ๋๋ฅผ ์ดํด๋ณด๋ฉด ์ฐ๋ฆฌ์ ํ๋ จ๋ฐ์ดํฐ๋ฅผ ๊ดํตํ๋ ์์ฃผ ์ ํํ ์ผ์น์ฑ์ ๋ณด์ธ๋ค. ํ์ง๋ง ๊ณก์ ์ ๊ทธ๋ ค๋ณด๋ฉด ์ฒ์ฅ๊ณผ ๋ฐ๋ฐ๋ฅ์ ๋ซ๋ ๊ฒฝ์ฐ๊ฐ ๋ฐ์ํ๋๋ฐ, ์์ฃผ ๋์ ํํ๋ ฅ์ ๊ฐ์ง๊ณ ์๋ค๋ ๋ป์ด๋ค. ๋ณดํต ์ด๋ฌํ ํ์์ ๊ณผ๋์ ํฉ(over-fitting) ์ด๋ผ๊ณ ํ๋ค. ํน์ ์ RMS error ๊ฐ ์๋์ ์ผ๋ก ์์ฒญ ํฌ์ง ์์ง๋ง, ํด๋น ํจ์๋ก๋ ํจ์๋ฅผ ํํํด๋ด๊ธฐ์๋ ์ญ๋ถ์กฑํด ๋ณด์ธ๋ค. ์ด๋ฌํ ํ์์ ๊ณผ์์ ํฉ(under-fitting) ์ด๋ผ๊ณ ํ๋ค.
์ฐจ์์ ์ ํ์ ๋ฐ๋ฅธ RMS error ์ ๊ทธ๋ ค๋ณด๋ฉด ์๋์ ๊ฐ๋ค.
plt.scatter(np.arange(10), all_rms_train, facecolors='none', edgecolors='b')
plt.plot(np.arange(10), all_rms_train, c='b', label='Training')
plt.scatter(np.arange(len(all_rms_test)), all_rms_test, facecolors='none', edgecolors='r')
plt.plot(np.arange(len(all_rms_test)), all_rms_test, c='r', label='Test')
plt.legend()
plt.xlim((-0.1, 10))
plt.ylim((-0.1, 1.2))
plt.ylabel("root-mean-squared Error", fontsize=16)
plt.xlabel("M", fontsize=16)
plt.show()ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ ๊ฐ์ RMS error ์ฐจ์ด๊ฐ ์ฒ์์๋ ์ค์ด๋ค๋ค๊ฐ ๋์ค์๋ ์ปค์ง๋ ๊ฒ์ ์ ์ ์๋ค. ์ฆ ์ต์ ์ ์ฐจ์๋ ํ๋ จ ์ธํธ์ ํ ์คํธ ์ธํธ ๊ฐ์ ์ธก์ ์ฒ๋๊ฐ ์์ผ๋ฉฐ, ํ ์คํธ ์ธํธ์์ ์ด๋์ ๋ ๋ฎ์ ์ธก์ ์ฒ๋๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผ ํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
์ด์ ๊ฐ๊ฐ์ ๊ณ์๋ฅผ ์ถ๋ ฅํด๋ณธ๋ค. ์ฐจ์ ๊ฐ ์ปค์ง ์๋ก ๊ณ์๊ฐ ์ปค์ง๋ ๊ฒ์ ํ์ธ ํ ์ ์๋ค.
np.set_printoptions(precision=3)
for i in [0, 1, 3, 9]:
print(f"coefficients at M={i} is {all_w[i]}")
# -----print result-----
# coefficients at M=0 is [0.149]
# coefficients at M=1 is [ 0.961 -1.739]
# coefficients at M=3 is [ -0.245 11.722 -33.194 21.798]
# coefficients at M=9 is [-5.400e+01 2.606e+03 -3.763e+04 2.686e+05 -1.111e+06 2.839e+06 -4.546e+06 4.438e+06 -2.411e+06 5.575e+05]์ค๋ฒํผํ
์ ํผํ๋ ๋ฐฉ๋ฒ
๊ณผ๋์ ํฉ(over-fitting)์ ํผํ๋ ๋ฐฉ๋ฒ์ ๋ฌด์์ผ๊น? ๋ณต์กํ ๋ชจ๋ธ์ผ ์๋ก ๋ฐ์ดํฐ๊ฐ ๋ง์ผ๋ฉด ์ค๋ฒํผํ ์ ํผํด๊ฐ ์ ์๋ค. ์๋์ ์์๋ฅผ ๋ณด์. ๊ฐ์ ์ฐจ์(M=9)์ ๋ชจ๋ธ๋ก ํ๋ จ ๋ฐ์ดํฐ 15 ๊ฐ์ 100๊ฐ์ ์ฐจ์ด๋ก ํ์ต๋ ๊ณก์ ์ด ๋ฌ๋ผ์ก์์ ์ ์ ์๋ค.
np.random.seed(seed)
N1 = 15
N2 = 100
x1, x2 = np.random.rand(N1), np.random.rand(N2)
t1 = np.sin(2*np.pi*x1) + np.random.randn(N1) * 0.1
t2 = np.sin(2*np.pi*x2) + np.random.randn(N2) * 0.1
optimal_w1 = poly_solution(x1, t1, m=9)
optimal_w2 = poly_solution(x2, t2, m=9)
# Plot predicted line
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
def plot(x, t, x_sin, t_sin, optimal_w, m, ax):
ax.plot(x_sin, t_sin, c="green", label="sin function")
ax.plot(x_sin, polynomial_function(x_sin, optimal_w, m), c="red", label=f"model N={len(x)}")
ax.scatter(x, t)
ax.set_xlim((0, 1))
ax.set_ylim((-1.25, 1.25))
ax.set_xlabel('x', fontsize=16)
ax.set_ylabel('t', rotation=0, fontsize=16)
ax.legend()
plot(x1, t1, x_sin, t_sin, optimal_w1, m=9, ax=ax1)
plot(x2, t2, x_sin, t_sin, optimal_w2, m=9, ax=ax2)
plt.show()์ฆ, ๋ฐ์ดํฐ๊ฐ ๋ง์ ์ง ์๋ก ์ค๋ฒํผํ ๋ฌธ์ ๋ ์ ์ด์ง๋ค. ๋ ๋ค๋ฅธ ๋ง๋ก ํด์ํ๋ฉด, ํฐ ๋ฐ์ดํฐ ์ธํธ์ผ ์๋ก ๋ ๋ณต์กํ(์ ์ฐํ) ๋ชจ๋ธ์ ๋ง๋ค ์ ์๋ค.
๋ณต์กํ ๋ฌธ์ ๋ฅผ ํ๋ ค๋ฉด, ๋ ๋ณต์กํ ๋ชจ๋ธ์ ๋ง๋ค์ด์ผ ํ๋ค๋ ์๊ฐ์ด ์ด์ฏค๋๋ฉด ์๊ธธ ๊ฒ์ด๋ค. ์ฐจํ์ ์ฐ๋ฆฌ๋ ์ต๋ ๊ฐ๋ฅ๋(mamximum likelihood) ๋ฅผ ํตํด์ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ(๊ณ์)๋ฅผ ์ฐพ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์ธ ๊ฒ์ด๊ณ , ์ค๋ฒํผํ ๋ฌธ์ ๋ํ ์ต๋ ๊ฐ๋ฅ๋์ ํ ์ผ๋ฐ์ ์ธ ํน์ฑ์ผ๋ก ์ดํดํ ๊ฒ์ด๋ค. ๊ทธ๋ฆฌ๊ณ ๋ฒ ์ด์ง์(Bayesian) ์ ๊ทผ๋ฒ ์ ํตํด ์ค๋ฒํผํ ์ ํด์ํ ์ ์๋ค๋ ๊ฒ๋ ๋ฐฐ์ธ ๊ฒ์ด๋ค.
์ ๊ทํ(Regularization)
์์ ๋ฐฉ๋ฒ์ ๋ฐฐ์ฐ๊ธฐ ์ ์ ์ฐ์ ์ ๊ทํ(regularization) ์ด๋ผ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์์๋ณธ๋ค. ์ ๊ทํ๋ ๊ณ์๊ฐ ๋ ์ปค์ง์ง ์๋๋ก ์์ค ํจ์์ ํจ๋ํฐ(penalty)๋ฅผ ๋ํ๋ ๋ฐฉ๋ฒ์ด๋ค. (2) ๋ฒ์์ ์๋์ ๊ฐ์ด ๊ณ ์ณ๋ณธ๋ค.
์ฌ๊ธฐ์ ๋ค. ์ ๊ทํ ๊ณ์ ๋ ์ถ๊ฐ์ ์ผ๋ก ์ ์ฝ์กฐ๊ฑด์ ๋น์ค์ ์กฐ์ ํ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ค.
์ฌ๋ฌ๊ฐ์ง ์ ๊ทํ ๋ฐฉ๋ฒ์ด ์์ผ๋, ์ฌ๊ธฐ์๋ ์ ์ผ ๊ฐ๋จํ ๊ณ์์ ์ ๊ณฑ์ ์์คํจ์์ ๋ํด์ฃผ๋ ํ์์ผ๋ก ํจ๋ํฐ๋ฅผ ๋ํ๋ค. (4) ์์ ํด๋ฅผ ๊ตฌํ๋ ๊ฒ์ ๊ฐ๋จํ๋ค.
def ridge_solution(x, t, m, alpha=0):
V = vandermonde_matrix(x, m)
return np.linalg.inv(np.dot(V.T, V) - alpha * np.eye(m+1)).dot(V.T).dot(t)์ ๊ทํ ๊ณ์์ ํจ๊ณผ๋ ์๋์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ๊ทน๋ช ํ๋ค. ์ ๊ทํ ๊ณ์๊ฐ ํด์๋ก ๊ณ์ ๋ฅผ ๊ฐ๋ ฅํ๊ฒ ๊ท์ ํ๋ฉฐ ๋์ด์ ์ปค์ง์ง ๋ชปํ๊ฒ ํ๋ค. ๋ํ, ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ฏ์ด ๋ชจ๋ธ์ ๋ณต์ก์ฑ์ ์ค์ฌ์ฃผ๊ณ ๊ณผ๋์ ํฉ์ ๋ง์์ค๋ค.
M=9
optimal_w1 = ridge_solution(x, t, m=M, alpha=1e-8)
optimal_w2 = ridge_solution(x, t, m=M, alpha=1.0)
# Plot predicted line
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
def plot_ridge(x, t, x_sin, t_sin, optimal_w, m, text, ax):
ax.plot(x_sin, t_sin, c="green", label="sin function")
ax.plot(x_sin, polynomial_function(x_sin, optimal_w, m), c="red", label=f"model M={m}")
ax.scatter(x, t)
ax.set_xlim((0, 1))
ax.set_ylim((-1.25, 1.25))
ax.set_xlabel('x', fontsize=16)
ax.set_ylabel('t', rotation=0, fontsize=16)
ax.legend()
ax.annotate(text, (0.6, 0.5), fontsize=14)
plot_ridge(x, t, x_sin, t_sin, optimal_w1, m=M, text='lambda = 1e-8', ax=ax1)
plot_ridge(x, t, x_sin, t_sin, optimal_w2, m=M, text='lambda = 1.0', ax=ax2)
plt.show()print(f"coefficients at lambda=1e-8 is {optimal_w1.round(3)}")
print(f"coefficients at lambda=1.0 is {optimal_w2.round(3)}")
# -----print result-----
# coefficients at lambda=1e-8 is [ 0.104 0.223 33.063 -91.357 3.467 90.523 39.05 -58.172 -72.941 55.889]
# coefficients at lambda=1.0 is [ 0.364 0.321 0.074 -0.155 -0.312 -0.409 -0.465 -0.495 -0.507 -0.51 ]ํต๊ณํ์์ ์ด๋ฌํ ํ ํฌ๋์ ์์ถ ๋ฐฉ๋ฒ(shrinkage method) ์ด๋ผ๊ณ ํ๋๋ฐ, ๊ทธ ์ด์ ๋ ๊ณ์์ ๊ฐ์ ์ค์ฌ์ฃผ๊ธฐ ๋๋ฌธ์ด๋ค. ํนํ, ์์์์ ๋์จ ๋ฐฉ๋ฒ์ ridge regression ์ด๋ค. ํฅํ์ ์ด์ผ๊ธฐํ ์ ๊ฒฝ๋ง์์๋ weight decay ๋ผ๊ณ ๋ ํ๋ค. ๊ทธ๋ ๋ค๊ณ ํด์ ์์ฃผ ๋์ ์ ๊ทํ๋ฅผ ํญ์ ๊ฐํ๊ฒ ๊ฐ์ ธ๊ฐ์ผํ๋ ๊ฒ์ ์๋๋ค. ์์ ๊ทธ๋ฆผ์์ ์ ๊ทํ ๊ณ์๊ฐ 1์ธ ๋ชจ๋ธ์ ๊ณผ์์ ํฉ์ ์ผ๊ธฐํ๊ธฐ ๋๋ฌธ์ด๋ค.
์๋ ๊ทธ๋ฆผ์ ์ ๊ทํ ๊ณ์๊ฐ ์ปค์ง์ ๋ฐ๋ผ RMS error ๋ฅผ ๊ตฌํ ๊ฒ์ด๋ค. ์ฐจ์๊ฐ 9 ์์๋ ๋ถ๊ตฌํ๊ณ ๋ฎ์ RMS error ๋ฅผ ์ ์งํ๊ณ ์๋ค.
(์ฑ ์์ ๋์จ ๊ทธ๋ํ๋ ์์ด ํ ์๋ ์๋๋ฐ, ์ด๋ seed ๊ฐ ๋ฌ๋ผ์ ์ฑ ์ ์๋ ๋ฐ์ดํฐ์ ์์ ํ ๊ฐ์ ์๊ฐ ์๊ธฐ ๋๋ฌธ์ด๋ค.)
Last updated
Was this helpful?