Wednesday, December 16, 2020

Numeric Analysis 2 - Least Squares Method


In [1]:
from math import *
import matplotlib.pyplot as plt
%matplotlib inline 

Least Squares Method

In curve fitting we are given $n$ points $(x_1,y_1), (x_2,y_2), \cdots , (x_n,y_n)$ and we want to determine a function $f(x)$ (for example, a polynomial) such that

$$ f(x_1) \approx y_1, \, f(x_2) \approx y_2 , \, \cdots \, , f(x_n) \approx y_n $$

Method of Least Squares.

The straight line

$$ y = a + b x $$

should be fitted through the given points $(x_1,y_1), (x_2,y_2), \cdots , (x_n,y_n)$ so that the sum of the squares of the distances of those points from the straight line is a minimum, where each distance is measured in the vertical direction (the $y$-direction).

The sum of squares is $$ q = \sum _{j=1} ^n \, ( y_j -a -bx_j ) ^2 $$
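To make the definition of $q$ concrete, here is a minimal sketch that evaluates it for a candidate line. The helper name `sum_of_squares` and the three data points are made up for illustration; they are not part of the original notebook.

```python
def sum_of_squares(x, y, a, b):
    """Return q = sum_j (y_j - a - b*x_j)**2 for the line y = a + b*x."""
    return sum((yj - a - b * xj) ** 2 for xj, yj in zip(x, y))

# Illustrative data (assumed for this sketch): points lying exactly on y = 1 + 2x
xs = [0.0, 1.0, 2.0]
ys = [1.0, 3.0, 5.0]

q_exact = sum_of_squares(xs, ys, 1.0, 2.0)   # the line passes through every point
q_off = sum_of_squares(xs, ys, 0.0, 2.0)     # same slope, intercept shifted down by 1
print(q_exact, q_off)
```

A line that passes through every point gives $q = 0$; any other choice of $a$ and $b$ gives a larger value, which is what the minimization exploits.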


A necessary condition for $q$ to be minimum is

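$$ \frac{\partial q}{\partial a} = -2 \sum ( y_j - a - b x_j ) = 0 \qquad \text{and} \qquad \frac{\partial q}{\partial b} = -2 \sum x_j \, ( y_j - a - b x_j ) = 0 $$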

We obtain the result

$$ a \, n + b \sum x_j = \sum y_j $$

$$ a \sum x_j + b \sum {x_j}^2 = \sum x_j y_j $$

Solving for $a$ and $b$ by elimination gives

$$ a = \frac {\sum {x_j}^2 \, \sum y_j - \sum x_j \, \sum x_j y_j } {n \sum {x_j}^2 - \left ({\sum x_j} \right)^2 } \qquad \text{(intercept)}$$

$$ b = \frac { n \sum x_j y_j - \sum x_j \, \sum y_j } {n \sum {x_j}^2 - \left ({\sum x_j} \right)^2 } \qquad \text{(slope)}$$
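The closed-form expressions for $a$ and $b$ are the solution of a $2 \times 2$ linear system in the sums $n$, $\sum x_j$, $\sum {x_j}^2$, $\sum y_j$ and $\sum x_j y_j$. As a sketch of an alternative to writing the formulas out by hand, the system can be passed to `np.linalg.solve`. The function name `fit_line_normal_equations` and the demo data are assumptions made for this illustration.

```python
import numpy as np

def fit_line_normal_equations(x, y):
    """Fit y = a + b*x by solving the 2x2 system for (a, b) numerically."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    # Row 1: a*n      + b*sum(x)    = sum(y)
    # Row 2: a*sum(x) + b*sum(x**2) = sum(x*y)
    A = np.array([[n, x.sum()], [x.sum(), (x**2).sum()]])
    r = np.array([y.sum(), (x * y).sum()])
    a, b = np.linalg.solve(A, r)
    return a, b

# Illustrative data (assumed): points lying exactly on y = 1 + 2x
a_demo, b_demo = fit_line_normal_equations([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])
print(a_demo, b_demo)
```

The solver raises `LinAlgError` when the matrix is singular, which happens when all $x_j$ are equal; in that case the denominator in the formulas for $a$ and $b$ is also zero.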

[ Example ] Fit a straight line to the four points

 (-1.3, 0.103), (-0.1, 1.099), (0.2, 0.808), (1.3, 1.897)
In [2]:
def leastsquare( x, y ) :
    sumx = 0.0
    sumx2= 0.0
    sumy = 0.0
    sumxy= 0.0
    n = len(x)
    
    for j in range(n) :
        sumx  += x[j]
        sumx2 += x[j]**2
        sumy  += y[j]
        sumxy += x[j]*y[j]
        
    # Closed-form solution of the 2x2 normal equations
    denom = n*sumx2 - sumx**2        # zero only if all x[j] are equal
    a = ( sumx2*sumy - sumx*sumxy )/denom    # intercept
    b = ( n*sumxy - sumx*sumy )/denom        # slope
    return a, b


Data = [ (-1.3, 0.103), (-0.1, 1.099), (0.2, 0.808), (1.3, 1.897) ]
XY = list( zip( *Data ) )
x = XY[0]
y = XY[1]
print( x, y )

a, b = leastsquare( x, y )

print('a = %9.5f' % a )
print('b = %9.5f' % b )
(-1.3, -0.1, 0.2, 1.3) (0.103, 1.099, 0.808, 1.897)
a =   0.96007
b =   0.66702
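
As a cross-check (not part of the original notebook), the same four points can be passed to NumPy's built-in `np.polyfit` with degree 1. The variable names ending in `_chk` are chosen here only to avoid overwriting the values computed above.

```python
import numpy as np

x_chk = [-1.3, -0.1, 0.2, 1.3]
y_chk = [0.103, 1.099, 0.808, 1.897]

# np.polyfit returns coefficients from the highest power down: [slope, intercept]
b_chk, a_chk = np.polyfit(x_chk, y_chk, 1)

print('a = %9.5f' % a_chk)
print('b = %9.5f' % b_chk)
```

Note the coefficient order: `np.polyfit` lists the slope first, whereas `leastsquare` returns the intercept first.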

Plot the data and the fitted line

In [3]:
import numpy as np
x = np.array( x )
yfit = a + b * x  

plt.plot(x, y, 'o', label='data')
plt.plot(x, yfit, 'r', label='fitted line')
plt.legend()
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Curve Fitting by Polynomials of Degree $m$

Curve fitting can be generalized from a straight line to a polynomial of degree $m$,

$$ p(x) = b_0 + b_1 x + \cdots + b_m x^m $$

The sum of the squares of the errors, $q$, then takes the form $$ q = \sum _{j=1} ^n \, ( y_j - p(x_j) ) ^2 $$

The conditions to determine $m+1$ parameters $b_0, \cdots, b_m$ are given by

$$ \frac{\partial q}{\partial b_0} = 0, \quad \cdots , \quad \frac{\partial q}{\partial b_m} = 0 $$
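
For a general degree $m$, one possible sketch is shown here. It swaps in a different technique from the one used in this post: instead of forming the $m+1$ equations explicitly, it builds the matrix of powers of $x_j$ with `np.vander` and lets `np.linalg.lstsq` minimize $q$ directly. The function name `polyfit_least_squares` and the demo data are assumptions for illustration.

```python
import numpy as np

def polyfit_least_squares(x, y, m):
    """Return [b_0, ..., b_m] minimizing sum_j (y_j - p(x_j))**2."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Column k of V holds x_j**k, so V @ b evaluates p at every x_j
    V = np.vander(x, m + 1, increasing=True)
    b, *_ = np.linalg.lstsq(V, y, rcond=None)
    return b

# Illustrative data (assumed): points lying exactly on p(x) = 1 - 2x + 3x^2
x_demo = np.array([-1.0, 0.0, 1.0, 2.0, 3.0])
y_demo = 1 - 2 * x_demo + 3 * x_demo**2

coef = polyfit_least_squares(x_demo, y_demo, 2)
print(coef)
```

Because the demo points lie exactly on a quadratic, the fit should recover its coefficients up to rounding error.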

In the case of a quadratic polynomial

$$ p(x) = b_0 + b_1 x + b_2 x^2 $$

The equations to determine $b_0, b_1, b_2$ are

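$$ b_0 \, n + b_1 \sum x_j + b_2 \sum {x_j}^2 = \sum y_j $$

$$ b_0 \sum x_j + b_1 \sum {x_j}^2 + b_2 \sum {x_j}^3 = \sum x_j y_j $$

$$ b_0 \sum {x_j}^2 + b_1 \sum {x_j}^3 + b_2 \sum {x_j}^4 = \sum {x_j}^2 y_j $$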

[ Example ] Quadratic Parabola by Least Squares

Fit a parabola through the data (0, 5), (2, 4), (4, 1), (6, 6), (8, 7).

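For these data $n = 5$, and the equations become

$$ 5 b_0 + 20 b_1 + 120 b_2 = 23 $$

$$ 20 b_0 + 120 b_1 + 800 b_2 = 104 $$

$$ 120 b_0 + 800 b_1 + 5664 b_2 = 696 $$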
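
The coefficient matrix of the quadratic system follows a pattern: the entry in row $i$, column $k$ is $\sum {x_j}^{i+k}$, and entry $i$ of the right-hand side is $\sum {x_j}^{i} y_j$. The sketch below assembles both for the five data points of this example. The helper name `normal_equations` is hypothetical and not part of the original notebook.

```python
import numpy as np

def normal_equations(x, y, m):
    """Return (A, r) with A[i][k] = sum(x**(i+k)) and r[i] = sum(x**i * y)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    A = np.array([[np.sum(x ** (i + k)) for k in range(m + 1)]
                  for i in range(m + 1)])
    r = np.array([np.sum(x ** i * y) for i in range(m + 1)])
    return A, r

A_demo, r_demo = normal_equations([0, 2, 4, 6, 8], [5, 4, 1, 6, 7], 2)
print(A_demo)
print(r_demo)
```

Solving `A_demo @ b = r_demo` for `b` then gives the parabola's coefficients; the next cell does the same with the sums written out one by one.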

In [4]:
import numpy as np
from scipy.linalg import solve   # only solve() is used below

def leastsquare2( x, y ) :
    sx  = 0.0
    sx2 = 0.0
    sx3 = 0.0    
    sx4 = 0.0    
    sy  = 0.0
    sxy = 0.0
    sx2y = 0.0
    n = len(x)
    
    for j in range(n) :
        sx  += x[j]
        sx2 += x[j]**2
        sx3 += x[j]**3
        sx4 += x[j]**4
        sy  += y[j]
        sxy += x[j]*y[j]
        sx2y += x[j]**2 *y[j]
        
    A = np.array( [[n, sx, sx2], [sx, sx2, sx3], [sx2, sx3, sx4]] )
    r = np.array( [sy, sxy, sx2y] )      # right-hand side (column vector)
    b = solve(A, r)
    return b
In [5]:
Data = [ (0, 5), (2, 4), (4, 1), (6, 6), (8, 7) ]
XY = list( zip( *Data ) )
x = XY[0]
y = XY[1]
print( x, y )

b = leastsquare2( x, y )

print( b )
(0, 2, 4, 6, 8) (5, 4, 1, 6, 7)
[ 5.11428571 -1.41428571  0.21428571]
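
As a sanity check (an addition, not part of the original notebook), the same coefficients can be obtained from `np.polyfit`, and the residual sum of squares $q$ can be evaluated for the fitted parabola. The names ending in `_par` are chosen here only to avoid overwriting `x`, `y` and `b`.

```python
import numpy as np

x_par = np.array([0, 2, 4, 6, 8], dtype=float)
y_par = np.array([5, 4, 1, 6, 7], dtype=float)

# np.polyfit lists the highest power first, so reverse to get [b0, b1, b2]
b_par = np.polyfit(x_par, y_par, 2)[::-1]

# Residual sum of squares q for the fitted parabola
resid = y_par - (b_par[0] + b_par[1] * x_par + b_par[2] * x_par**2)
q_par = float(np.sum(resid**2))

print(b_par)
print(q_par)
```

The residuals of a least-squares fit that includes a constant term sum to zero, which gives one more quick consistency test on `resid`.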

Plot the data and the fitted curve

In [6]:
plt.plot(x, y, 'o', label='data')

# fitted curve
xf = np.linspace( x[0], x[-1], 101 )
yf = b[0] + b[1]*xf + b[2]*xf**2  
plt.plot(xf, yf, 'r', label='fitted curve')

plt.legend()
plt.xlabel('x')
plt.ylabel('y')
plt.show()
