Regularized linear regression

 

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import Lasso, Ridge, ElasticNet, LinearRegression 

X = np.random.randn(40)
y = np.cos(X) + np.random.rand(40)/3
plt.scatter(X, y)
plt.show()

output_2_0

from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
def plot_comparison(degree):
    plt.figure(figsize=(9,7))
    models = [("Linear (non-regularized)", LinearRegression()), ("Lasso", Lasso(alpha=0.01)),
              ("Ridge", Ridge(alpha=0.01)), ("Elastic Net", ElasticNet(alpha=0.01, l1_ratio=0.5))]
    for i, (model_name, model) in enumerate(models):
        plt.subplot(2,2,i+1)
        linear_model = make_pipeline(PolynomialFeatures(degree), model)
        linear_model.fit(X.reshape(-1, 1), y)
        plt.scatter(X, y, color="w", edgecolor="k")

        xx = np.linspace(-3, 4, 10000)
        plt.plot(xx, linear_model.predict(xx.reshape(-1, 1)), color="b", linewidth=1)
        plt.xlim(-4, 4)
        plt.ylim(-5, 2)
        plt.title("{} Regression".format(model_name))
    plt.tight_layout()
    plt.show()

plot_comparison(5)

output_4_0

plot_comparison(15)

output_5_0

plot_comparison(30)

output_6_0

XX = np.c_[np.random.randn(40), np.random.randn(40)**2, np.random.randn(40)**3, np.random.rand(40)*3, np.random.rand(40)]
yy = np.cos(X[0]) + np.random.rand(40)/3
alphas, coefs, dual_gaps = Lasso().path(XX, yy, alphas=np.logspace(-4, 1, 8))
plt.figure(figsize=(7,5))
for i in range(5):
    plt.plot(alphas, coefs[i], label=r"X_{}".format(i+1))
plt.axhline(0, color="k", linestyle="--")
plt.semilogx()
plt.title("Lasso path")
plt.legend()
plt.show();

output_7_0

 

답글 남기기

댓글을 게시하려면 다음의 방법 중 하나를 사용하여 로그인 하세요:

WordPress.com 로고

WordPress.com의 계정을 사용하여 댓글을 남깁니다. 로그아웃 /  변경 )

Google photo

Google의 계정을 사용하여 댓글을 남깁니다. 로그아웃 /  변경 )

Twitter 사진

Twitter의 계정을 사용하여 댓글을 남깁니다. 로그아웃 /  변경 )

Facebook 사진

Facebook의 계정을 사용하여 댓글을 남깁니다. 로그아웃 /  변경 )

%s에 연결하는 중

This site uses Akismet to reduce spam. Learn how your comment data is processed.