머신러닝/직접 구현

[Python] Numpy 로 LinearRegression을 구현해보자!!

easysheep 2023. 6. 13. 17:21

1. 코드

import numpy as np
In [122]:
# Make Dummy Data

data_num = 1000

# feature 값 생성
X1 = np.random.randn(data_num)
X2 = np.random.randn(data_num)
X3 = np.random.randn(data_num)

# 노이즈
noise = np.random.normal(0,5,data_num)

# target 값 생성
y = 30 * X1 - 20 * X2 + 12 *X3 +10+ noise


X = np.array([X1,X2,X3]).T

 

이렇게 되는 이유를 자세히 알고 싶으면 아래 사이트를 참고

gradient_weights =  -(1/self.data_num) * np.dot(X.T , error)
gradient_bias =  -(1/self.data_num) * np.sum(error)

출처 :https://computer-nerd.tistory.com/5

In [127]:
# cost function = MSE

class LinearRegressor:
    def __init__(self , epochs = 100 , lr = 0.01):
        # 편향을 0으로 설정
        self.bias = 0

        # loss값을 저장할 리스트
        self.loss_log = []

        self.epochs = epochs
        self.lr = lr

    def fit(self, X , y):
        self.data_num, self.feat_num = X.shape

        # weight 값 초기화
        self.weights = np.ones(self.feat_num)
        
        for n in range(self.epochs):
            # y 값 예측
            y_hat = np.dot(X , self.weights) + self.bias
            
            # error 값 구하기
            error = y - y_hat

            # Mean square Error 구하기 
            mse = np.square(error).mean()
            self.loss_log.append(mse)

            # Gradient Descent 를 이용하여 weight , bias 업데이트 
            
            ## weight , bias 별 gradient 값 구하기
            
            gradient_weights =  -(1/self.data_num) * np.dot(X.T , error)
            gradient_bias =  -(1/self.data_num) * np.sum(error)

            ## weight 값 update
            self.weights = self.weights - self.lr * gradient_weights 
            self.bias = self.bias - self.lr *  gradient_bias

            # 주석 풀고 
            # print(f'epoch : {n+1} \n MSE : {mse}')

    def predict(self , X):
        return np.dot(self.weights , X) + self.bias
    
    def show_loss_log(self):
        for loss in self.loss_log:
            print (loss)
        return self.loss_log
In [128]:
lr=LinearRegressor()
In [129]:
lr.fit(X=X,y=y)
epoch : 1 
 MSE : 1575.316463147643
epoch : 2 
 MSE : 1543.2913120040535
epoch : 3 
 MSE : 1511.9294656241652
epoch : 4 
 MSE : 1481.2171577261122
epoch : 5 
 MSE : 1451.1409082496561
epoch : 6 
 MSE : 1421.6875173957083
epoch : 7 
 MSE : 1392.8440597901556
epoch : 8 
 MSE : 1364.5978787693875
epoch : 9 
 MSE : 1336.9365807849902
epoch : 10 
 MSE : 1309.8480299251128
epoch : 11 
 MSE : 1283.3203425500765
epoch : 12 
 MSE : 1257.3418820398363
epoch : 13 
 MSE : 1231.9012536509636
epoch : 14 
 MSE : 1206.9872994808595
epoch : 15 
 MSE : 1182.5890935369648
epoch : 16 
 MSE : 1158.69593690877
epoch : 17 
 MSE : 1135.2973530404836
epoch : 18 
 MSE : 1112.3830831022537
epoch : 19 
 MSE : 1089.9430814578882
epoch : 20 
 MSE : 1067.9675112270602
epoch : 21 
 MSE : 1046.4467399400194
epoch : 22 
 MSE : 1025.37133528289
epoch : 23 
 MSE : 1004.7320609316547
epoch : 24 
 MSE : 984.5198724729764
epoch : 25 
 MSE : 964.7259134100499
epoch : 26 
 MSE : 945.3415112516994
epoch : 27 
 MSE : 926.3581736829947
epoch : 28 
 MSE : 907.7675848156783
epoch : 29 
 MSE : 889.5616015167384
epoch : 30 
 MSE : 871.7322498135003
epoch : 31 
 MSE : 854.2717213736357
epoch : 32 
 MSE : 837.172370058528
epoch : 33 
 MSE : 820.4267085484644
epoch : 34 
 MSE : 804.0274050381516
epoch : 35 
 MSE : 787.9672800010943
epoch : 36 
 MSE : 772.2393030213922
epoch : 37 
 MSE : 756.8365896915545
epoch : 38 
 MSE : 741.7523985749527
epoch : 39 
 MSE : 726.9801282315606
epoch : 40 
 MSE : 712.5133143056661
epoch : 41 
 MSE : 698.3456266742568
epoch : 42 
 MSE : 684.4708666548166
epoch : 43 
 MSE : 670.8829642712936
epoch : 44 
 MSE : 657.5759755770237
epoch : 45 
 MSE : 644.5440800334255
epoch : 46 
 MSE : 631.7815779432981
epoch : 47 
 MSE : 619.2828879375879
epoch : 48 
 MSE : 607.0425445145055
epoch : 49 
 MSE : 595.0551956299017
epoch : 50 
 MSE : 583.3156003378343
epoch : 51 
 MSE : 571.8186264802778
epoch : 52 
 MSE : 560.559248424951
epoch : 53 
 MSE : 549.5325448502604
epoch : 54 
 MSE : 538.7336965763737
epoch : 55 
 MSE : 528.1579844414657
epoch : 56 
 MSE : 517.8007872221888
epoch : 57 
 MSE : 507.6575795974531
epoch : 58 
 MSE : 497.72393015460506
epoch : 59 
 MSE : 487.9954994371268
epoch : 60 
 MSE : 478.4680380329866
epoch : 61 
 MSE : 469.1373847027952
epoch : 62 
 MSE : 459.99946454693514
epoch : 63 
 MSE : 451.05028721085387
epoch : 64 
 MSE : 442.28594512772236
epoch : 65 
 MSE : 433.7026117976806
epoch : 66 
 MSE : 425.2965401029085
epoch : 67 
 MSE : 417.06406065777315
epoch : 68 
 MSE : 409.00158019332343
epoch : 69 
 MSE : 401.1055799754135
epoch : 70 
 MSE : 393.3726142557565
epoch : 71 
 MSE : 385.7993087552201
epoch : 72 
 MSE : 378.38235917869287
epoch : 73 
 MSE : 371.11852976086266
epoch : 74 
 MSE : 364.0046518422633
epoch : 75 
 MSE : 357.0376224749582
epoch : 76 
 MSE : 350.21440305724286
epoch : 77 
 MSE : 343.5320179967629
epoch : 78 
 MSE : 336.9875534014527
epoch : 79 
 MSE : 330.57815579771807
epoch : 80 
 MSE : 324.3010308752924
epoch : 81 
 MSE : 318.1534422582114
epoch : 82 
 MSE : 312.13271030136275
epoch : 83 
 MSE : 306.236210912077
epoch : 84 
 MSE : 300.4613743962373
epoch : 85 
 MSE : 294.8056843283988
epoch : 86 
 MSE : 289.266676445415
epoch : 87 
 MSE : 283.8419375630837
epoch : 88 
 MSE : 278.5291045153312
epoch : 89 
 MSE : 273.32586311546544
epoch : 90 
 MSE : 268.2299471390386
epoch : 91 
 MSE : 263.23913732786923
epoch : 92 
 MSE : 258.3512604147824
epoch : 93 
 MSE : 253.56418816863592
epoch : 94 
 MSE : 248.8758364592113
epoch : 95 
 MSE : 244.28416434155443
epoch : 96 
 MSE : 239.7871731593616
epoch : 97 
 MSE : 235.3829056670132
epoch : 98 
 MSE : 231.06944516986772
epoch : 99 
 MSE : 226.84491468243488
epoch : 100 
 MSE : 222.70747610405616
In [130]:
## loss값 변화량

import matplotlib.pyplot as plt
plt.plot(range(1,101) ,lr.loss_log)
Out[130]:
[<matplotlib.lines.Line2D at 0x168de8940>]