44. Simple Linear Regression Model

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

FONTPATH = "fonts/SourceHanSerifSC-SemiBold.otf"
mpl.font_manager.fontManager.addfont(FONTPATH)
plt.rcParams['font.family'] = ['Source Han Serif SC']

The simple regression model estimates the relationship between two variables \(x_i\) and \(y_i\)

\[ y_i = \alpha + \beta x_i + \epsilon_i, i = 1,2,...,N \]

where \(\epsilon_i\) represents the error between the line of best fit and the sample value of \(y_i\) given \(x_i\).

Our aim is to choose values for \(\alpha\) and \(\beta\) that give a line of "best" fit for the available data on the variables \(x_i\) and \(y_i\).

Let's consider a simple dataset of 10 observations for the variables \(x_i\) and \(y_i\):

      \(y_i\)   \(x_i\)
 1    2000      32
 2    1000      21
 3    1500      24
 4    2500      35
 5    500       10
 6    900       11
 7    1100      22
 8    1500      21
 9    1800      27
10    250       2

Let's think of \(y_i\) as the sales of an ice-cream cart, while \(x_i\) is a variable that records the day's temperature in degrees Celsius.

x = [32, 21, 24, 35, 10, 11, 22, 21, 27, 2]
y = [2000,1000,1500,2500,500,900,1100,1500,1800, 250]
df = pd.DataFrame([x,y]).T
df.columns = ['X', 'Y']
df
X Y
0 32 2000
1 21 1000
2 24 1500
3 35 2500
4 10 500
5 11 900
6 22 1100
7 21 1500
8 27 1800
9 2 250

We can use a scatter plot of the data to see the relationship between \(y_i\) (ice-cream sales in dollars ($)) and \(x_i\) (degrees Celsius).

ax = df.plot(
    x='X', 
    y='Y', 
    kind='scatter', 
    ylabel=r'Ice-cream sales (\$)',  # raw string so that \$ reaches matplotlib unchanged
    xlabel='Degrees Celsius'
)
_images/e5e93fcd510171f7fbd4950fc23c22909a71827cd0a26ea1bf481a3c8eecb4cd.png

Fig. 44.1 Scatter plot

As you can see, the data suggests that more ice-cream is typically sold on hotter days.

To build a linear model of the data we need to choose values for \(\alpha\) and \(\beta\) that represent a line of "best" fit such that

\[ \hat{y}_i = \hat{\alpha} + \hat{\beta} x_i \]

Let's start with \(\alpha = 5\) and \(\beta = 10\)

α = 5
β = 10
df['Y_hat'] = α + β * df['X']
fig, ax = plt.subplots()
ax = df.plot(x='X',y='Y', kind='scatter', ax=ax)
ax = df.plot(x='X',y='Y_hat', kind='line', ax=ax, label=r'$\hat Y$')
plt.show()
_images/468063cda355bbddd56f1b13ddc1a67f28c47f78cf8cc8dfb85b8029108263a7.png

Fig. 44.2 Scatter plot with a line of fit

We can see that this model does a poor job of estimating the relationship.

We can keep guessing and iterate towards a line of "best" fit by adjusting the parameters.

β = 100
df['Y_hat'] = α + β * df['X']
fig, ax = plt.subplots()
ax = df.plot(x='X',y='Y', kind='scatter', ax=ax)
ax = df.plot(x='X',y='Y_hat', kind='line', ax=ax, label=r'$\hat Y$')
plt.show()
_images/24a9ea5ecff2e64e5e4b2990e38578a7a2b9ec6050f665a73b5673d0e8da99cf.png

Fig. 44.3 Scatter plot with a line of fit #2

β = 65
df['Y_hat'] = α + β * df['X']
fig, ax = plt.subplots()
ax = df.plot(x='X',y='Y', kind='scatter', ax=ax)
yax = df.plot(x='X',y='Y_hat', kind='line', ax=ax, color='g', label=r'$\hat Y$')
plt.show()
_images/f63b5c3ece86f59b82c8b04503b9cfdb737e92d07c925a576570bbcc7004eef1.png

Fig. 44.4 Scatter plot with a line of fit #3

However, we need to formalize this guessing process by treating the problem as an optimization problem.

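Let's consider the error \(\epsilon_i\) and define the difference between the observed values \(y_i\) and the estimated values \(\hat{y}_i\), which we will call the residuals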

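\[\begin{split} \begin{aligned} \hat{e}_i &= y_i - \hat{y}_i \\ &= y_i - \hat{\alpha} - \hat{\beta} x_i \end{aligned} \end{split}\]

# Note: this column stores Y_hat - Y, the negative of the residual defined above;
# the sign does not affect the squared or absolute errors used later
df['error'] = df['Y_hat'] - df['Y']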
df
X Y Y_hat error
0 32 2000 2085 85
1 21 1000 1370 370
2 24 1500 1565 65
3 35 2500 2280 -220
4 10 500 655 155
5 11 900 720 -180
6 22 1100 1435 335
7 21 1500 1370 -130
8 27 1800 1760 -40
9 2 250 135 -115
fig, ax = plt.subplots()
ax = df.plot(x='X',y='Y', kind='scatter', ax=ax)
yax = df.plot(x='X',y='Y_hat', kind='line', ax=ax, color='g', label=r'$\hat Y$')
plt.vlines(df['X'], df['Y_hat'], df['Y'], color='r')
plt.show()
_images/9b40c1c9796b8cd275ca8bcee90a1dc046c31ef45605f0cd320d90a37b9122f0.png

Fig. 44.5 Plot of the residuals

The Ordinary Least Squares (OLS) method chooses \(\alpha\) and \(\beta\) so as to minimize the sum of the squared residuals (SSR)

\[ \min_{\alpha,\beta} \sum_{i=1}^{N}{\hat{e}_i^2} = \min_{\alpha,\beta} \sum_{i=1}^{N}{(y_i - \alpha - \beta x_i)^2} \]

Let's call this the cost function

\[ C = \sum_{i=1}^{N}{(y_i - \alpha - \beta x_i)^2} \]

We would like to minimize this cost function with respect to the parameters \(\alpha\) and \(\beta\).
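To make the cost function concrete, the following minimal sketch evaluates \(C\) for the three guesses tried earlier (\(\alpha = 5\) with \(\beta = 10, 100, 65\)). The helper name `ssr` is an illustrative choice, not something defined in this lecture.

```python
import numpy as np

# Data from the ice-cream example above
x = np.array([32, 21, 24, 35, 10, 11, 22, 21, 27, 2])
y = np.array([2000, 1000, 1500, 2500, 500, 900, 1100, 1500, 1800, 250])

def ssr(alpha, beta, x, y):
    """Sum of squared residuals for the line alpha + beta * x (illustrative helper)."""
    residuals = y - alpha - beta * x
    return float(np.sum(residuals**2))

# The three guesses tried earlier, all with alpha = 5
costs = {beta: ssr(5, beta, x, y) for beta in (10, 100, 65)}
for beta, cost in costs.items():
    print(f"beta = {beta:>3}: SSR = {cost:,.0f}")
```

For \(\beta = 65\) the cost is 397,125, which is the sum of the squared entries of the error column shown above, and it is smaller than the cost of the other two guesses.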

44.1. How does the error change with respect to \(\alpha\) and \(\beta\)

Let us first look at how the total error changes with respect to \(\beta\) (holding the intercept \(\alpha\) constant)

We know from the next section that the optimal values for \(\alpha\) and \(\beta\) are:

β_optimal = 64.38
α_optimal = -14.72

We can then calculate the total error, measured here as the sum of the absolute errors, for a range of \(\beta\) values

errors = {}
for β in np.arange(20,100,0.5):
    errors[β] = abs((α_optimal + β * df['X']) - df['Y']).sum()

Plotting the error

ax = pd.Series(errors).plot(xlabel='β', ylabel='error')
plt.axvline(β_optimal, color='r');
_images/233a641767a74e980b7239f1c74007309d7db6a2ef1ca9855e667efdeb374112.png

Fig. 44.6 Plotting the error

Now let us vary \(\alpha\) (holding \(\beta\) constant)

errors = {}
for α in np.arange(-500,500,5):
    errors[α] = abs((α + β_optimal * df['X']) - df['Y']).sum()

Plotting the error

ax = pd.Series(errors).plot(xlabel='α', ylabel='error')
plt.axvline(α_optimal, color='r');
_images/edd915afb3bf4d63fe48fe30a95a3935a7df289c89b10f6a0d385c282d367870.png

Fig. 44.7 Plotting the error (2)
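The two loops above sum absolute errors, whereas the cost function \(C\) sums squared errors. As a rough sketch (the variable names are illustrative), the same grid of \(\beta\) values can be scanned with both measures to see where each one is smallest. Only the squared-error version is the one that OLS minimizes, so the two minimizers need not coincide.

```python
import numpy as np

# Data from the ice-cream example above
x = np.array([32, 21, 24, 35, 10, 11, 22, 21, 27, 2])
y = np.array([2000, 1000, 1500, 2500, 500, 900, 1100, 1500, 1800, 250])

alpha_opt, beta_opt = -14.72, 64.38  # rounded optimal values quoted above

# Same grid of slopes as in the loop above, holding the intercept fixed
betas = np.arange(20, 100, 0.5)
ssr_by_beta = np.array([np.sum((y - alpha_opt - b * x) ** 2) for b in betas])
sae_by_beta = np.array([np.sum(np.abs(y - alpha_opt - b * x)) for b in betas])

best_beta_ssr = betas[ssr_by_beta.argmin()]  # grid point with the smallest SSR
best_beta_sae = betas[sae_by_beta.argmin()]  # grid point with the smallest absolute error
print(best_beta_ssr, best_beta_sae)
```

Because the grid moves in steps of 0.5, the SSR-minimizing grid point can only land within half a step of the optimal slope.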

44.2. Calculating optimal values

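Now let us use calculus to solve the optimization problem and compute the optimal values of \(\alpha\) and \(\beta\), which gives the ordinary least squares solution.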

First, take the partial derivative with respect to \(\alpha\)

\[ \frac{\partial C}{\partial \alpha} = \frac{\partial}{\partial \alpha}\left[\sum_{i=1}^{N}{(y_i - \alpha - \beta x_i)^2}\right] \]

and set it equal to \(0\)

\[ 0 = \sum_{i=1}^{N}{-2(y_i - \alpha - \beta x_i)} \]

We can remove the constant \(-2\) from the summation by dividing both sides by \(-2\)

\[ 0 = \sum_{i=1}^{N}{(y_i - \alpha - \beta x_i)} \]

Now we can split this equation into its components

\[ 0 = \sum_{i=1}^{N}{y_i} - \sum_{i=1}^{N}{\alpha} - \beta \sum_{i=1}^{N}{x_i} \]

The middle term is a straightforward sum of the constant \(\alpha\) over \(i=1,\ldots,N\)

\[ 0 = \sum_{i=1}^{N}{y_i} - N\alpha - \beta \sum_{i=1}^{N}{x_i} \]

and rearranging terms

\[ \alpha = \frac{\sum_{i=1}^{N}{y_i} - \beta \sum_{i=1}^{N}{x_i}}{N} \]

We observe that both fractions resolve to the means \(\bar{y_i}\) and \(\bar{x_i}\)

\[ \alpha = \bar{y_i} - \beta\bar{x_i} \tag{44.1} \]

Now let's take the partial derivative of the cost function \(C\) with respect to \(\beta\)

\[ \frac{\partial C}{\partial \beta} = \frac{\partial}{\partial \beta}\left[\sum_{i=1}^{N}{(y_i - \alpha - \beta x_i)^2}\right] \]

and set it equal to \(0\)

\[ 0 = \sum_{i=1}^{N}{-2 x_i (y_i - \alpha - \beta x_i)} \]

We can again take the constant outside the summation and divide both sides by \(-2\)

\[ 0 = \sum_{i=1}^{N}{x_i (y_i - \alpha - \beta x_i)} \]

which becomes

\[ 0 = \sum_{i=1}^{N}{(x_i y_i - \alpha x_i - \beta x_i^2)} \]

Now substituting for \(\alpha\)

\[ 0 = \sum_{i=1}^{N}{(x_i y_i - (\bar{y_i} - \beta \bar{x_i}) x_i - \beta x_i^2)} \]

and rearranging terms

\[ 0 = \sum_{i=1}^{N}{(x_i y_i - \bar{y_i} x_i + \beta \bar{x_i} x_i - \beta x_i^2)} \]

This can be split into two summations

\[ 0 = \sum_{i=1}^{N}(x_i y_i - \bar{y_i} x_i) + \beta \sum_{i=1}^{N}(\bar{x_i} x_i - x_i^2) \]

and solving for \(\beta\) yields

\[ \beta = \frac{\sum_{i=1}^{N}(x_i y_i - \bar{y_i} x_i)}{\sum_{i=1}^{N}(x_i^2 - \bar{x_i} x_i)} \tag{44.2} \]

We can now use (44.1) and (44.2) to calculate the optimal values of \(\alpha\) and \(\beta\)

Calculating \(\beta\)

df = df[['X','Y']].copy()  # Original data

# Calculate the sample means
x_bar = df['X'].mean()
y_bar = df['Y'].mean()

Now compute the numerator and denominator terms for each of the 10 observations and then sum them

# Compute the sums
df['num'] = df['X'] * df['Y'] - y_bar * df['X']
df['den'] = pow(df['X'],2) - x_bar * df['X']
β = df['num'].sum() / df['den'].sum()
print(β)
64.37665782493369
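The slope formula in (44.2) is often written in terms of deviations from the means. The two forms agree because \(\sum_{i}(x_i - \bar{x}) = 0\) and \(\sum_{i}(y_i - \bar{y}) = 0\). The sketch below, using illustrative variable names, checks this numerically on the ice-cream data.

```python
import numpy as np

# Data from the ice-cream example above
x = np.array([32, 21, 24, 35, 10, 11, 22, 21, 27, 2], dtype=float)
y = np.array([2000, 1000, 1500, 2500, 500, 900, 1100, 1500, 1800, 250], dtype=float)

x_bar, y_bar = x.mean(), y.mean()

# Form used in (44.2)
beta_raw = np.sum(x * y - y_bar * x) / np.sum(x**2 - x_bar * x)

# Equivalent deviations-from-means form
beta_dev = np.sum((x - x_bar) * (y - y_bar)) / np.sum((x - x_bar) ** 2)

print(beta_raw, beta_dev)
```

Both expressions should print the same slope as above, up to floating-point rounding.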

Calculating \(\alpha\)

α = y_bar - β * x_bar
print(α)
-14.72148541114052
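As a sanity check, the hand-computed estimates can be compared with a library fit. The sketch below uses `np.polyfit` with degree 1, which returns the slope first and the intercept second. It also checks the two first-order conditions from the derivation: the residuals sum to zero, and so do the residuals weighted by \(x_i\). The variable names are illustrative.

```python
import numpy as np

# Data from the ice-cream example above
x = np.array([32, 21, 24, 35, 10, 11, 22, 21, 27, 2], dtype=float)
y = np.array([2000, 1000, 1500, 2500, 500, 900, 1100, 1500, 1800, 250], dtype=float)

# Estimates from (44.2) and (44.1)
x_bar, y_bar = x.mean(), y.mean()
beta = np.sum(x * y - y_bar * x) / np.sum(x**2 - x_bar * x)
alpha = y_bar - beta * x_bar

# Least-squares fit of a degree-1 polynomial: coefficients are [slope, intercept]
slope, intercept = np.polyfit(x, y, deg=1)
print(np.isclose(beta, slope), np.isclose(alpha, intercept))

# First-order conditions: both sums should be zero up to rounding error
residuals = y - alpha - beta * x
print(residuals.sum(), (x * residuals).sum())
```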

Now we can plot the OLS solution

df['Y_hat'] = α + β * df['X']
df['error'] = df['Y_hat'] - df['Y']

fig, ax = plt.subplots()
ax = df.plot(x='X',y='Y', kind='scatter', ax=ax)
yax = df.plot(x='X',y='Y_hat', kind='line', ax=ax, color='g', label=r'$\hat Y$')
plt.vlines(df['X'], df['Y_hat'], df['Y'], color='r');
_images/b354f127d932a904b93d5c2bd44d6524ced7770ad67006ed938a340d65f0f721.png

Fig. 44.8 OLS line of best fit

Exercise 44.1

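Now that you know the equations that solve the simple linear regression model using OLS, you can start running your own regressions to build a model between \(y\) and \(x\).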

Let's consider two economic variables, GDP per capita and life expectancy.

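  1. What do you think their relationship would be?

  2. Gather some data from Our World in Data

  3. Use pandas to import the csv formatted data and plot a few different countries of interest

  4. Use (44.1) and (44.2) to compute optimal values for \(\alpha\) and \(\beta\)

  5. Plot the line of best fit found using OLS

  6. Interpret the coefficients and write a summary sentence of the relationship between GDP per capita and life expectancy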

Solution to Exercise 44.1

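Q2: Gather some data from Our World in Data

If you get stuck, you can download a copy of the data from the URL assigned to data_url below

Q3: Use pandas to import the csv formatted data and plot a few different countries of interest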

data_url = "https://github.com/QuantEcon/lecture-python-intro/raw/main/lectures/_static/lecture_specific/simple_linear_regression/life-expectancy-vs-gdp-per-capita.csv"
df = pd.read_csv(data_url, nrows=10)
df
Entity Code Year Life expectancy at birth (historical) GDP per capita 417485-annotations Population (historical estimates) Continent
0 Abkhazia OWID_ABK 2015 NaN NaN NaN NaN Asia
1 Afghanistan AFG 1950 27.7 1156.0 NaN 7480464.0 NaN
2 Afghanistan AFG 1951 28.0 1170.0 NaN 7571542.0 NaN
3 Afghanistan AFG 1952 28.4 1189.0 NaN 7667534.0 NaN
4 Afghanistan AFG 1953 28.9 1240.0 NaN 7764549.0 NaN
5 Afghanistan AFG 1954 29.2 1245.0 NaN 7864289.0 NaN
6 Afghanistan AFG 1955 29.9 1246.0 NaN 7971933.0 NaN
7 Afghanistan AFG 1956 30.4 1278.0 NaN 8087730.0 NaN
8 Afghanistan AFG 1957 30.9 1253.0 NaN 8210207.0 NaN
9 Afghanistan AFG 1958 31.5 1298.0 NaN 8333827.0 NaN

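You can see that the data downloaded from Our World in Data provides GDP per capita and life expectancy data for countries around the world.

It is often a good idea to first import a few lines of data from a csv file to understand its structure, so that you can then choose the columns you want to read into your DataFrame.

You can observe that there are several columns we won't need to import, such as Continent

So let's build a list of the columns we want to import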

cols = ['Code', 'Year', 'Life expectancy at birth (historical)', 'GDP per capita']
df = pd.read_csv(data_url, usecols=cols)
df
Code Year Life expectancy at birth (historical) GDP per capita
0 OWID_ABK 2015 NaN NaN
1 AFG 1950 27.7 1156.0
2 AFG 1951 28.0 1170.0
3 AFG 1952 28.4 1189.0
4 AFG 1953 28.9 1240.0
... ... ... ... ...
62151 ZWE 1946 NaN NaN
62152 ZWE 1947 NaN NaN
62153 ZWE 1948 NaN NaN
62154 ZWE 1949 NaN NaN
62155 ALA 2015 NaN NaN

62156 rows × 4 columns

Sometimes it can be useful to rename your columns to make them easier to work with in the DataFrame

df.columns = ["cntry", "year", "life_expectancy", "gdppc"]
df
cntry year life_expectancy gdppc
0 OWID_ABK 2015 NaN NaN
1 AFG 1950 27.7 1156.0
2 AFG 1951 28.0 1170.0
3 AFG 1952 28.4 1189.0
4 AFG 1953 28.9 1240.0
... ... ... ... ...
62151 ZWE 1946 NaN NaN
62152 ZWE 1947 NaN NaN
62153 ZWE 1948 NaN NaN
62154 ZWE 1949 NaN NaN
62155 ALA 2015 NaN NaN

62156 rows × 4 columns
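Assigning to `df.columns` renames by position, so it relies on the columns arriving in the expected order. `pd.read_csv` returns the selected columns in file order, regardless of the order in which they are listed in `usecols`. The sketch below illustrates this on a two-row stand-in for the real file (rows copied from the output above) and shows `rename` with an explicit mapping as a name-based alternative. The names `toy`, `csv_text` and `new_names` are illustrative.

```python
import io
import pandas as pd

# A tiny stand-in for the real csv file, with two rows copied from the output above
csv_text = """Entity,Code,Year,Life expectancy at birth (historical),GDP per capita
Afghanistan,AFG,1950,27.7,1156.0
Afghanistan,AFG,1951,28.0,1170.0
"""

# Listing the columns in a different order does not change the order pandas returns
cols = ['GDP per capita', 'Year', 'Code', 'Life expectancy at birth (historical)']
toy = pd.read_csv(io.StringIO(csv_text), usecols=cols)
file_order = list(toy.columns)
print(file_order)

# Renaming by name does not depend on column position
new_names = {
    'Code': 'cntry',
    'Year': 'year',
    'Life expectancy at birth (historical)': 'life_expectancy',
    'GDP per capita': 'gdppc',
}
toy = toy.rename(columns=new_names)
print(list(toy.columns))
```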

We can see there are NaN values, which represent missing data, so let us go ahead and drop those

df.dropna(inplace=True)
df
cntry year life_expectancy gdppc
1 AFG 1950 27.7 1156.0000
2 AFG 1951 28.0 1170.0000
3 AFG 1952 28.4 1189.0000
4 AFG 1953 28.9 1240.0000
5 AFG 1954 29.2 1245.0000
... ... ... ... ...
61960 ZWE 2014 58.8 1594.0000
61961 ZWE 2015 59.6 1560.0000
61962 ZWE 2016 60.3 1534.0000
61963 ZWE 2017 60.7 1582.3662
61964 ZWE 2018 61.4 1611.4052

12445 rows × 4 columns

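We have now reduced the number of rows in our DataFrame from 62156 to 12445 by removing the many rows with missing values.

Now we have a dataset containing life expectancy and GDP per capita for a range of years.

It is always a good idea to spend a bit of time understanding what data you actually have.

For example, you may want to explore this data to see if all countries report consistently across years.

Let's first look at the life expectancy data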

le_years = df[['cntry', 'year', 'life_expectancy']].set_index(['cntry', 'year']).unstack()['life_expectancy']
le_years
year 1543 1548 1553 1558 1563 1568 1573 1578 1583 1588 ... 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018
cntry
AFG NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 60.4 60.9 61.4 61.9 62.4 62.5 62.7 63.1 63.0 63.1
AGO NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 55.8 56.7 57.6 58.6 59.3 60.0 60.7 61.1 61.7 62.1
ALB NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 77.8 77.9 78.1 78.1 78.1 78.4 78.6 78.9 79.0 79.2
ARE NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 78.0 78.3 78.5 78.7 78.9 79.0 79.2 79.3 79.5 79.6
ARG NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 75.9 75.7 76.1 76.5 76.5 76.8 76.8 76.3 76.8 77.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
VNM NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 73.5 73.5 73.7 73.7 73.8 73.9 73.9 73.9 74.0 74.0
YEM NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 67.2 67.3 67.4 67.3 67.5 67.4 65.9 66.1 66.0 64.6
ZAF NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 57.4 58.9 60.7 61.8 62.5 63.4 63.9 64.7 65.4 65.7
ZMB NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 55.3 56.8 57.8 58.9 59.9 60.7 61.2 61.8 62.1 62.3
ZWE NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... 48.1 50.7 53.3 55.6 57.5 58.8 59.6 60.3 60.7 61.4

166 rows × 310 columns

As you can see, there are a lot of countries for which data is not available for the year 1543!

Which country does report this data?

le_years[~le_years[1543].isna()]
year 1543 1548 1553 1558 1563 1568 1573 1578 1583 1588 ... 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018
cntry
GBR 33.94 38.82 39.59 22.38 36.66 39.67 41.06 41.56 42.7 37.05 ... 80.2 80.4 80.8 80.9 80.9 81.2 80.9 81.1 81.2 81.1

1 rows × 310 columns

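You can see that Great Britain (GBR) is the only one available

You can also take a closer look at the time series to find that it is non-continuous, even for GBR.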

le_years.loc['GBR'].plot();
_images/e29efce71435159f40e3a8cb843b6fd3e515accb141acb5477d55270be444e57.png

In fact we can use pandas to quickly check how many countries are captured in each year

le_years.stack().unstack(level=0).count(axis=1).plot(xlabel="Year", ylabel="Number of countries");
_images/6c1fb4b4100bd37e47b56b574bc053c2574985cfe1679bed24a60825867f4445.png
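The `stack`/`unstack` chain above counts the non-missing entries for each year. As a sketch on made-up data (the country codes and values below are purely illustrative), the same counts can be obtained more directly with `count()` on the wide table, since `count` skips NaN values column by column.

```python
import pandas as pd

# Made-up long-format data: country CCC has no observation for 2017
toy = pd.DataFrame({
    'cntry': ['AAA', 'AAA', 'BBB', 'BBB', 'CCC'],
    'year': [2017, 2018, 2017, 2018, 2018],
    'life_expectancy': [70.0, 70.5, 60.0, 61.0, 80.0],
})

# Same reshaping as above: one row per country, one column per year
wide = toy.set_index(['cntry', 'year']).unstack()['life_expectancy']

# Chain used above versus a direct count of non-missing values per year
via_stack = wide.stack().unstack(level=0).count(axis=1)
via_count = wide.count()
print(via_count)
```

Here 2017 has two reporting countries and 2018 has three, and both approaches give the same counts.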

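So it is clear that if you are doing cross-sectional comparisons, more recent data will include a wider set of countries

Now let us consider the most recent year in the dataset, 2018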

df = df[df.year == 2018].reset_index(drop=True).copy()
df.plot(x='gdppc', y='life_expectancy', kind='scatter', xlabel="GDP per capita", ylabel="Life expectancy (years)",);
_images/38aab775680b5f5afcc9ef436ddae6b122b7850f664ed48ce88d6397f3606a30.png

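This data shows a couple of interesting relationships.

  1. There are a number of countries with similar GDP per capita levels but a wide range in life expectancy

  2. There appears to be a positive relationship between GDP per capita and life expectancy. Countries with higher GDP per capita tend to have higher life expectancy

Even though OLS fits a linear model, one option we have is to transform the variables, for example with a log transform, and then use OLS to estimate the relationship between the transformed variables.

By specifying logx you can plot the GDP per capita data on a log scale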

df.plot(x='gdppc', y='life_expectancy', kind='scatter',  xlabel="GDP per capita", ylabel="Life expectancy (years)", logx=True);
_images/2afe42f3eff5cd374569030663c50ded6aa924aa677b07cd6e36dd48913f3f49.png

As you can see from this transformation, a linear model fits the shape of the data more closely.

df['log_gdppc'] = df['gdppc'].apply(np.log10)
df
cntry year life_expectancy gdppc log_gdppc
0 AFG 2018 63.1 1934.5550 3.286581
1 ALB 2018 79.2 11104.1660 4.045486
2 DZA 2018 76.1 14228.0250 4.153145
3 AGO 2018 62.1 7771.4420 3.890502
4 ARG 2018 77.0 18556.3830 4.268493
... ... ... ... ... ...
161 VNM 2018 74.0 6814.1420 3.833411
162 OWID_WRL 2018 72.6 15212.4150 4.182198
163 YEM 2018 64.6 2284.8900 3.358865
164 ZMB 2018 62.3 3534.0337 3.548271
165 ZWE 2018 61.4 1611.4052 3.207205

166 rows × 5 columns

Q4: Use (44.1) and (44.2) to compute optimal values for \(\alpha\) and \(\beta\)

data = df[['log_gdppc', 'life_expectancy']].copy()  # Extract the data from the DataFrame

# Calculate the sample means
x_bar = data['log_gdppc'].mean()
y_bar = data['life_expectancy'].mean()
data
log_gdppc life_expectancy
0 3.286581 63.1
1 4.045486 79.2
2 4.153145 76.1
3 3.890502 62.1
4 4.268493 77.0
... ... ...
161 3.833411 74.0
162 4.182198 72.6
163 3.358865 64.6
164 3.548271 62.3
165 3.207205 61.4

166 rows × 2 columns

# Compute the sums
data['num'] = data['log_gdppc'] * data['life_expectancy'] - y_bar * data['log_gdppc']
data['den'] = pow(data['log_gdppc'],2) - x_bar * data['log_gdppc']
β = data['num'].sum() / data['den'].sum()
print(β)
12.643730292819708
α = y_bar - β * x_bar
print(α)
21.70209670138904

Q5: Plot the line of best fit found using OLS

data['life_expectancy_hat'] = α + β * data['log_gdppc']
data['error'] = data['life_expectancy_hat'] - data['life_expectancy']

fig, ax = plt.subplots()
data.plot(x='log_gdppc',y='life_expectancy', kind='scatter', ax=ax)
data.plot(x='log_gdppc',y='life_expectancy_hat', kind='line', ax=ax, color='g')
plt.vlines(data['log_gdppc'], data['life_expectancy_hat'], data['life_expectancy'], color='r');
_images/3ea9e13da91c98b81b8fcdd4e0e9c6a8be3cbe88f4e43f770429cda0ca753674.png
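The exercise also asks for an interpretation of the coefficients. One possible reading, sketched below under the assumption that the printed estimates are used as they are: because the regressor is the base-10 log of GDP per capita, \(\beta\) is the difference in predicted life expectancy between two countries whose GDP per capita differs by a factor of ten, and \(\beta \log_{10}(2)\) is the difference associated with a doubling. The function name is illustrative, and this describes an association in the 2018 cross-section, not a causal effect.

```python
import numpy as np

# Estimates printed above for the 2018 cross-section
alpha = 21.70209670138904
beta = 12.643730292819708

def predicted_life_expectancy(gdppc):
    """Fitted value from the log10 model (illustrative helper)."""
    return alpha + beta * np.log10(gdppc)

# A tenfold difference in GDP per capita shifts the prediction by beta years
tenfold = predicted_life_expectancy(10_000) - predicted_life_expectancy(1_000)

# A doubling shifts it by beta * log10(2) years
doubling = beta * np.log10(2)

print(f"tenfold: {tenfold:.2f} years, doubling: {doubling:.2f} years")
```

A summary sentence could then read: in 2018, countries with twice the GDP per capita had, on average, a life expectancy roughly 3.8 years higher.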

Exercise 44.2

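Minimizing the sum of squares is not the only way to generate the line of best fit.

For example, we could also consider minimizing the sum of the absolute values, which would give less weight to outliers.

Solve for \(\alpha\) and \(\beta\) using the least absolute values method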
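One possible approach, offered as a sketch and not as the intended solution: unlike OLS, the least absolute values problem has no closed-form solution, but an optimal line always exists that passes through at least two of the data points. For a small dataset we can therefore try every pair of points and keep the line with the smallest sum of absolute errors. The sketch uses the ice-cream data from earlier, and all names in it are illustrative.

```python
import itertools
import numpy as np

# Data from the ice-cream example above
x = np.array([32, 21, 24, 35, 10, 11, 22, 21, 27, 2], dtype=float)
y = np.array([2000, 1000, 1500, 2500, 500, 900, 1100, 1500, 1800, 250], dtype=float)

def sum_abs_errors(alpha, beta):
    """Sum of absolute residuals for the line alpha + beta * x."""
    return float(np.sum(np.abs(y - alpha - beta * x)))

# Try the line through every pair of points with distinct x values
best = None
for i, j in itertools.combinations(range(len(x)), 2):
    if x[i] == x[j]:
        continue  # a vertical line cannot be written as alpha + beta * x
    beta = (y[j] - y[i]) / (x[j] - x[i])
    alpha = y[i] - beta * x[i]
    cost = sum_abs_errors(alpha, beta)
    if best is None or cost < best[0]:
        best = (cost, alpha, beta)

lad_cost, lad_alpha, lad_beta = best

# Compare with the absolute-error cost of the OLS line computed earlier
ols_cost = sum_abs_errors(-14.72148541114052, 64.37665782493369)
print(lad_alpha, lad_beta, lad_cost, ols_cost)
```

The pairwise search grows quadratically with the number of observations, so for larger datasets a linear-programming or iterative solver would be a more practical choice.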