Numeric Analysis 2 - Least Squares Method
from math import *
import matplotlib.pyplot as plt
%matplotlib inline
Least Squares Method
In curve fitting we are given $n$ points $(x_1, y_1), (x_2, y_2), \dots, (x_n, y_n)$ and want to determine a function $f(x)$ (for example, a polynomial) such that
$f(x_1) \approx y_1,\ f(x_2) \approx y_2,\ \dots,\ f(x_n) \approx y_n$
Method of Least Squares
The straight line
$y = a + bx$
should be fitted through the given points $(x_1, y_1), (x_2, y_2), \dots, (x_n, y_n)$ so that the sum of the squares of the distances of those points from the straight line is a minimum, where each distance is measured in the vertical direction (the $y$-direction).
The sum of squares is
$q = \sum_{j=1}^{n} (y_j - a - b x_j)^2$
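To make the quantity being minimized concrete, here is a minimal sketch that evaluates q for a given pair (a, b). The helper name `sum_of_squares` and the two candidate lines are assumptions chosen for illustration; the data are the four points from the example further down.

```python
# Sum of squared vertical distances between the points and the line y = a + b*x.
def sum_of_squares(a, b, xs, ys):
    return sum((y - a - b * x) ** 2 for x, y in zip(xs, ys))

# The four points from the example further down.
xs = [-1.3, -0.1, 0.2, 1.3]
ys = [0.103, 1.099, 0.808, 1.897]

# Two arbitrary candidate lines, chosen only for illustration.
q1 = sum_of_squares(1.0, 0.5, xs, ys)
q2 = sum_of_squares(0.9, 0.7, xs, ys)
print("q(a=1.0, b=0.5) = %.5f" % q1)
print("q(a=0.9, b=0.7) = %.5f" % q2)
```

The second line gives the smaller q, so by the least squares criterion it is the better of the two candidates. The method of least squares looks for the pair (a, b) that makes q as small as possible.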
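A necessary condition for $q$ to be a minimum is
$\frac{\partial q}{\partial a} = -2 \sum (y_j - a - b x_j) = 0, \qquad \frac{\partial q}{\partial b} = -2 \sum x_j (y_j - a - b x_j) = 0$
where all sums run over $j = 1, \dots, n$. Dividing by $-2$ and rearranging, we obtain the normal equations
$a n + b \sum x_j = \sum y_j$
$a \sum x_j + b \sum x_j^2 = \sum x_j y_j$
Solving for $a$ and $b$ (by elimination):
$a = \dfrac{\sum x_j^2 \sum y_j - \sum x_j \sum x_j y_j}{n \sum x_j^2 - \left(\sum x_j\right)^2}$ (intercept)
$b = \dfrac{n \sum x_j y_j - \sum x_j \sum y_j}{n \sum x_j^2 - \left(\sum x_j\right)^2}$ (slope)
[ Example ] Find a straight line for the four points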
(-1.3, 0.103), (-0.1, 1.099), (0.2, 0.808), (1.3, 1.897)
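As a hand check before running any code, the sums for these four points can be worked out directly. The final values are rounded to four decimals.

$$
n = 4,\quad \sum x_j = 0.1,\quad \sum x_j^2 = 3.43,\quad \sum y_j = 3.907,\quad \sum x_j y_j = 2.3839
$$

$$
4a + 0.1\,b = 3.907, \qquad 0.1\,a + 3.43\,b = 2.3839
$$

$$
a = \frac{3.43 \cdot 3.907 - 0.1 \cdot 2.3839}{4 \cdot 3.43 - 0.1^2} = \frac{13.16262}{13.71} \approx 0.9601,
\qquad
b = \frac{4 \cdot 2.3839 - 0.1 \cdot 3.907}{13.71} = \frac{9.1449}{13.71} \approx 0.6670
$$

So the fitted line should come out close to $y = 0.9601 + 0.6670\,x$.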
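def leastsquare( x, y ) :
    """Return intercept a and slope b of the least-squares line y = a + b*x."""
    sumx = 0.0     # sum of x_j
    sumx2= 0.0     # sum of x_j**2
    sumy = 0.0     # sum of y_j
    sumxy= 0.0     # sum of x_j*y_j
    n = len(x)
    for j in range(n) :
        sumx += x[j]
        sumx2 += x[j]**2
        sumy += y[j]
        sumxy += x[j]*y[j]
    # closed-form solution of the two normal equations
    a = ( sumx2*sumy - sumx*sumxy )/( n*sumx2 - sumx**2 )
    b = ( n*sumxy - sumx*sumy )/( n*sumx2 - sumx**2 )
    return a, b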
Data = [ (-1.3, 0.103), (-0.1, 1.099), (0.2, 0.808), (1.3, 1.897) ]
XY = list( zip( *Data ) )
x = XY[0]
y = XY[1]
print( x, y )
a, b = leastsquare( x, y )
print('a = %9.5f' % a )
print('b = %9.5f' % b )
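As a cross-check, the same line can be obtained from NumPy's `np.polyfit`, which fits a least-squares polynomial of a given degree. This is a sketch for comparison only, not part of the original notebook; the names `x_chk`, `y_chk`, `a_np` and `b_np` are introduced here for illustration.

```python
import numpy as np

# The four points from the example above.
x_chk = np.array([-1.3, -0.1, 0.2, 1.3])
y_chk = np.array([0.103, 1.099, 0.808, 1.897])

# np.polyfit returns coefficients from the highest power down,
# so for degree 1 the order is (slope, intercept).
b_np, a_np = np.polyfit(x_chk, y_chk, 1)

print('a = %9.5f' % a_np)
print('b = %9.5f' % b_np)
```

The two printed values should agree with the output of `leastsquare` to the digits shown.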
Plot the data and the fitted line
import numpy as np
x = np.array( x )
yfit = a + b * x
plt.plot(x, y, 'o', label='data')
plt.plot(x, yfit, 'r', label='fitted line')
plt.legend()
plt.xlabel('x')
plt.ylabel('y')
plt.show()
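Curve Fitting by Polynomials of Degree m
Curve fitting can be generalized from a straight line to a polynomial of degree $m$,
$p(x) = b_0 + b_1 x + \dots + b_m x^m$
The sum of the squares of the errors, $q$, then takes the form
$q = \sum_{j=1}^{n} (y_j - p(x_j))^2$
The conditions that determine the $m+1$ parameters $b_0, \dots, b_m$ are
$\frac{\partial q}{\partial b_0} = 0, \quad \dots, \quad \frac{\partial q}{\partial b_m} = 0$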
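For a general degree m, these conditions lead to a linear system whose matrix has entries $\sum_j x_j^{k+l}$ and whose right-hand side has entries $\sum_j x_j^k y_j$, for $k, l = 0, \dots, m$. The following is a minimal sketch of that construction. The function name `polyfit_normal` and the test data are assumptions made for illustration.

```python
import numpy as np

def polyfit_normal(x, y, m):
    """Least-squares polynomial of degree m via the normal equations.

    Returns the coefficients [b0, b1, ..., bm] (lowest power first).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Columns of V are x**0, x**1, ..., x**m evaluated at the data points.
    V = np.vander(x, m + 1, increasing=True)
    # (V^T V)[k, l] = sum_j x_j**(k+l),  (V^T y)[k] = sum_j x_j**k * y_j
    A = V.T @ V
    r = V.T @ y
    return np.linalg.solve(A, r)

# Illustrative data lying exactly on 1 + 2x + 3x^2.
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [1 + 2 * t + 3 * t**2 for t in xs]
print(polyfit_normal(xs, ys, 2))
```

Because the illustrative data lie exactly on a parabola, the degree-2 fit recovers its coefficients up to rounding. For large m the normal-equation matrix tends to become ill-conditioned, so this direct approach is best kept to low degrees.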
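In the case of a quadratic polynomial
$p(x) = b_0 + b_1 x + b_2 x^2$
the equations to determine $b_0, b_1, b_2$ are
$b_0 n + b_1 \sum x_j + b_2 \sum x_j^2 = \sum y_j$
$b_0 \sum x_j + b_1 \sum x_j^2 + b_2 \sum x_j^3 = \sum x_j y_j$
$b_0 \sum x_j^2 + b_1 \sum x_j^3 + b_2 \sum x_j^4 = \sum x_j^2 y_j$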
[ Example ] Quadratic Parabola by Least Squares
Fit a parabola through the data (0, 5), (2, 4), (4, 1), (6, 6), (8, 7).
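import numpy as np
from scipy.linalg import solve

def leastsquare2( x, y ) :
    """Return [b0, b1, b2] of the least-squares parabola b0 + b1*x + b2*x**2."""
    sx = 0.0
    sx2 = 0.0
    sx3 = 0.0
    sx4 = 0.0
    sy = 0.0
    sxy = 0.0
    sx2y = 0.0
    n = len(x)
    # accumulate the power sums that appear in the normal equations
    for j in range(n) :
        sx += x[j]
        sx2 += x[j]**2
        sx3 += x[j]**3
        sx4 += x[j]**4
        sy += y[j]
        sxy += x[j]*y[j]
        sx2y += x[j]**2 *y[j]
    # coefficient matrix and right-hand side of the 3x3 linear system
    A = np.array( [[n, sx, sx2], [sx, sx2, sx3], [sx2, sx3, sx4]] )
    r = np.array( [sy, sxy, sx2y] ) # right-hand side (column vector)
    b = solve(A, r)
    return b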
Data = [ (0, 5), (2, 4), (4, 1), (6, 6), (8, 7) ]
XY = list( zip( *Data ) )
x = XY[0]
y = XY[1]
print( x, y )
b = leastsquare2( x, y )
print( b )
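The quadratic fit can be cross-checked against `np.polyfit` with degree 2. This sketch is for comparison only and is not part of the original notebook; the names `x_chk`, `y_chk` and `b_chk` are introduced here for illustration.

```python
import numpy as np

# The five points from the example above.
x_chk = np.array([0, 2, 4, 6, 8], dtype=float)
y_chk = np.array([5, 4, 1, 6, 7], dtype=float)

# np.polyfit orders coefficients from x**2 down to the constant term;
# reverse them to match the order (b0, b1, b2).
b_chk = np.polyfit(x_chk, y_chk, 2)[::-1]
print(b_chk)
```

The printed coefficients should match the output of `leastsquare2` up to rounding.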
Plot the data and the fitted curve
plt.plot(x, y, 'o', label='data')
# fitted curve
xf = np.linspace( x[0], x[-1], 101 )
yf = b[0] + b[1]*xf + b[2]*xf**2
plt.plot(xf, yf, 'r', label='fitted curve')
plt.legend()
plt.xlabel('x')
plt.ylabel('y')
plt.show()