import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from mpl_toolkits.mplot3d import axes3d  # noqa: F401 -- not referenced in this file
import seaborn as sns


class MyRegression:
    """Univariate linear regression fitted with batch gradient descent.

    Loads a tab-separated data file (first column: feature, second column:
    target), fits ``y ~ theta0 + theta1 * x`` by gradient descent, and plots
    the fit next to scikit-learn's ``LinearRegression`` for comparison.
    """

    def __init__(self, path="testSet.txt"):
        """Configure display libraries and load the training data.

        Args:
            path: Tab-separated text file; column 0 is used as the feature
                and column 1 as the target. Defaults to ``"testSet.txt"``,
                the file name that used to be hard-coded.
        """
        # Process-wide pandas display options (global side effect).
        pd.set_option("display.notebook_repr_html", False)
        pd.set_option("display.max_columns", None)
        pd.set_option("display.max_rows", None)
        pd.set_option("display.max_seq_items", None)
        # Process-wide seaborn/matplotlib styling (global side effect).
        sns.set_context("notebook")
        sns.set_style("white")
        # Warm-up exercise: a 5x5 identity matrix (not used by the methods below).
        self.warmUpExercise = np.identity(5)
        # ndmin=2 keeps the result 2-D even when the file has a single row,
        # so the column indexing below still works.
        self.data = np.loadtxt(path, delimiter="\t", ndmin=2)
        # Design matrix, shape (m, 2): a column of ones (intercept term)
        # followed by the feature column.
        self.x = np.c_[np.ones(self.data.shape[0]), self.data[:, 0]]
        # Target as a column vector, shape (m, 1).
        self.y = np.c_[self.data[:, 1]]

    def data_view(self):
        """Scatter-plot the raw (feature, target) data."""
        plt.scatter(self.x[:, 1], self.y, s=30, c="r", marker="x", linewidths=1)
        # NOTE: these CJK labels may render as empty boxes unless a
        # CJK-capable font is configured in matplotlib.
        plt.xlabel("x軸")
        plt.ylabel("y軸")
        plt.show()

    def compute_cost(self, theta=None):
        """Return the cost J(theta) = 1/(2m) * sum((X.theta - y)^2).

        Args:
            theta: Parameter column vector with one row per column of
                ``self.x``. ``None`` (the default) means all zeros, which
                matches the former ``[[0], [0]]`` default.

        Returns:
            The scalar cost.
        """
        if theta is None:
            theta = np.zeros((self.x.shape[1], 1))
        m = self.y.size
        h = self.x.dot(theta)
        J = 1.0 / (2 * m) * (np.sum(np.square(h - self.y)))
        return J

    def gradient_descent(self, theta=None, alpha=0.01, num_iters=100):
        """Fit theta with batch gradient descent.

        Each iteration applies ``theta := theta - alpha/m * X^T (X.theta - y)``
        and records the cost of the updated theta.

        Args:
            theta: Initial parameter column vector; ``None`` (the default)
                starts from all zeros.
            alpha: Learning rate.
            num_iters: Number of iterations to run.

        Returns:
            ``(theta, J_history)``: the final parameters as an ndarray, and an
            array of length ``num_iters`` holding the cost after each update.
        """
        if theta is None:
            theta = np.zeros((self.x.shape[1], 1))
        else:
            # Accept lists as before, but always return an ndarray.
            theta = np.asarray(theta, dtype=float)
        m = self.y.size
        J_history = np.zeros(num_iters)
        for i in range(num_iters):
            h = self.x.dot(theta)
            # Simultaneous update of every theta component.
            theta = theta - alpha * (1.0 / m) * (self.x.T.dot(h - self.y))
            J_history[i] = self.compute_cost(theta)
        return theta, J_history

    def result_view1(self):
        """Run gradient descent and plot the cost at each iteration."""
        theta, J_history = self.gradient_descent()
        plt.plot(J_history)
        plt.ylabel("Cost J")
        plt.xlabel("Iterations")
        plt.show()

    def result_view2(self):
        """Plot the gradient-descent fit against scikit-learn's fit."""
        theta, J_history = self.gradient_descent()
        xx = np.arange(-5, 10)
        yy = theta[0] + theta[1] * xx
        # Plot the data and the line found by our own gradient descent.
        plt.scatter(self.x[:, 1], self.y, s=30, c="g", marker="x", linewidths=1)
        plt.plot(xx, yy, label="Linear Regression (Gradient descent)")
        # Compare with scikit-learn's linear regression on the same data.
        regr = LinearRegression()
        regr.fit(self.x[:, 1].reshape(-1, 1), self.y.ravel())
        plt.plot(xx, regr.intercept_ + regr.coef_ * xx, label="Linear Regression (Scikit-learn GLM)")
        # NOTE: these CJK labels may render as empty boxes unless a
        # CJK-capable font is configured in matplotlib.
        plt.xlabel("x軸")
        plt.ylabel("y軸")
        plt.legend(loc=4)
        plt.show()


if __name__ == '__main__':
    my_regression = MyRegression()
    # Call my_regression.result_view1() instead to plot the cost history.
    my_regression.result_view2()