13. Numba#

除了 Anaconda 中已有的库之外,本讲座还需要以下库:

!pip install quantecon

Hide code cell output

Requirement already satisfied: quantecon in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (0.11.2)
Requirement already satisfied: numba>=0.49.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (0.62.1)
Requirement already satisfied: numpy>=1.17.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (2.3.5)
Requirement already satisfied: requests in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (2.32.5)
Requirement already satisfied: scipy>=1.5.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (1.16.3)
Requirement already satisfied: sympy in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (1.14.0)
Requirement already satisfied: llvmlite<0.46,>=0.45.0dev0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from numba>=0.49.0->quantecon) (0.45.1)
Requirement already satisfied: charset_normalizer<4,>=2 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (2025.11.12)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from sympy->quantecon) (1.3.0)

请同时确保您安装了最新版本的 Anaconda,因为旧版本是常见错误来源

让我们从一些导入开始:

import numpy as np
import quantecon as qe
import matplotlib.pyplot as plt
import matplotlib as mpl  # i18n
import matplotlib.font_manager  # i18n
FONTPATH = "_fonts/SourceHanSerifSC-SemiBold.otf"  # i18n
mpl.font_manager.fontManager.addfont(FONTPATH)  # i18n
mpl.rcParams['font.family'] = ['Source Han Serif SC']  # i18n

13.1. 概述#

之前的讲座 中,我们学习了向量化,这是一种通过将数组处理操作批量发送到高效底层代码来提高执行速度的方法。

然而,正如 之前所讨论的,传统的向量化方案有以下弱点:

  • 对于复合数组操作,内存消耗极大

  • 对于某些算法,向量化无效甚至不可能实现

绕过这些问题的一种方法是使用 Numba,这是一个面向 Python 的即时(JIT)编译器

Numba 在运行时将函数编译为本地机器码指令。

编译成功后,其性能可与编译后的 C 或 Fortran 媲美。

此外,Numba 还可以完成有用的技巧,例如 多线程

本讲座将介绍核心思路。

Note

一些读者可能对 Numba 与 Julia 之间的关系感到好奇,Julia 包含其自己的 JIT 编译器。虽然这两种编译器在许多方面相似,但 Numba 的目标更为有限,仅尝试编译 Python 语言的一个小子集。虽然这听起来像是一个缺陷,但也是一种优势:Numba 更具限制性的特性使其易于使用,并且非常擅长其所做的事情。

13.3. 注意事项#

Numba 相对容易使用,但并非总是无缝衔接的。

让我们来回顾一些用户常遇到的问题。

13.3.1. 类型推断#

成功的类型推断是 JIT 编译的关键。

在理想情况下,Numba 可以推断出所有必要的类型信息。

当 Numba 无法 推断所有类型信息时,它将抛出错误。

例如,在以下情况中,Numba 在编译 iterate 时无法确定函数 g 的类型:

@jit
def iterate(f, x0, n):
    x = x0
    for t in range(n):
        x = f(x)
    return x

# 未经 jit 编译
def g(x):
    return np.cos(x) - 2 * np.sin(x)

# 这段代码会抛出错误
try:
    iterate(g, 0.5, 100)
except Exception as e:
    print(e)
Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /tmp/ipykernel_2930/4185348615.py (1)

File "../../../../../../tmp/ipykernel_2930/4185348615.py", line 1:
<source missing, REPL/exec in use?>

During: Pass nopython_type_inference 

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'function'>

我们可以通过编译 g 来轻松修复这个错误。

@jit
def g(x):
    return np.cos(x) - 2 * np.sin(x)

iterate(g, 0.5, 100)
2.223875299559663

在其他情况下,例如当我们想使用来自外部库(如 SciPy)的函数时,可能没有简单的解决方法。

13.3.2. 全局变量#

使用 Numba 时另一个需要注意的问题是全局变量的处理。

例如,考虑以下代码:

a = 1

@jit
def add_a(x):
    return a + x

print(add_a(10))
11
a = 2

print(add_a(10))
11

注意,更改全局变量对函数返回的值没有任何影响 😱。

当 Numba 为函数编译机器码时,它将全局变量视为常量以确保类型稳定性。

为了避免这种情况,请将值作为函数参数传递,而不是依赖全局变量。

13.4. Numba 中的多线程循环#

除了 JIT 编译之外,Numba 还为 CPU 和 GPU 上的并行计算提供支持。

Numba 中 CPU 并行化的关键工具是 prange 函数,它告诉 Numba 在可用的 CPU 核心上并行执行循环迭代。

为了说明,让我们首先看一个简单的单线程(即非并行化)代码片段。

该代码通过以下规则模拟家庭财富 \(w_t\) 的更新

\[ w_{t+1} = R_{t+1} s w_t + y_{t+1} \]

其中

  • \(R\) 是资产的总回报率

  • \(s\) 是家庭的储蓄率,以及

  • \(y\) 是劳动收入。

我们将 \(R\)\(y\) 均建模为来自对数正态分布的独立抽样。

以下是代码:

@jit
def update(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
    " Updates household wealth. "
    # Draw shocks
    R = np.exp(v1 * np.random.randn()) * (1 + r)
    y = np.exp(v2 * np.random.randn())
    # Update wealth
    w = R * s * w + y
    return w

让我们看看在此规则下财富如何演变。

fig, ax = plt.subplots()

T = 100
w = np.empty(T)
w[0] = 5
for t in range(T-1):
    w[t+1] = update(w[t])

ax.plot(w)
ax.set_xlabel('$t$', fontsize=12)
ax.set_ylabel('$w_{t}$', fontsize=12)
plt.show()
_images/ea15328b1e60b574f08b466dfc3fdca32cd240d82a4a2d3ca26f6fcd2d55f955.png

现在,假设我们有一个庞大的家庭群体,并且想知道中位财富将是多少。

这个问题很难用纸笔求解,因此我们将使用模拟:

  1. 向前模拟大量家庭

  2. 计算中位财富

以下是代码:

@jit
def compute_long_run_median(w0=1, T=1000, num_reps=50_000):
    obs = np.empty(num_reps)
    # For each household
    for i in range(num_reps):
        # Set the initial condition and run forward in time
        w = w0
        for t in range(T):
            w = update(w)
        # Record the final value
        obs[i] = w
    # Take the median of all final values
    return np.median(obs)

让我们看看运行速度:

with qe.Timer():
    # Warm up
    compute_long_run_median()
5.7481 seconds elapsed
with qe.Timer():
    # Second run
    compute_long_run_median()
4.8730 seconds elapsed

为了加速这个过程,我们将通过多线程对其进行并行化。

为此,我们添加 parallel=True 标志并将 range 更改为 prange

from numba import prange

@jit(parallel=True)
def compute_long_run_median_parallel(
        w0=1, T=1000, num_reps=50_000
    ):
    obs = np.empty(num_reps)
    for i in prange(num_reps):  # Parallelize over households
        w = w0
        for t in range(T):
            w = update(w)
        obs[i] = w
    return np.median(obs)

让我们看看计时结果:

with qe.Timer():
    # Warm up
    compute_long_run_median_parallel()
1.3308 seconds elapsed
with qe.Timer():
    # Second run
    compute_long_run_median_parallel()
0.9289 seconds elapsed

速度提升非常显著。

注意,我们是跨家庭进行并行化,而非跨时间——单个家庭跨时期的更新本质上是顺序的。

关于基于 GPU 的并行化,请参阅我们关于 JAX 的讲座

13.5. 练习#

Exercise 13.1

之前 我们考虑了如何用蒙特卡洛方法近似 \(\pi\)

在这里使用相同的思路,但使用 Numba 使代码高效。

当样本量较大时,比较有无 Numba 的速度。

Exercise 13.2

Python 定量经济学入门 讲座系列中,您可以学习到关于有限状态马尔可夫链的所有知识。

现在,让我们专注于模拟一个非常简单的此类链的示例。

假设一种资产的回报波动率可以处于两种状态之一——高或低。

跨状态的转移概率如下所示

_images/nfs_ex1.png

例如,设周期长度为一天,假设当前状态为高。

从图中我们可以看出,明天的状态将是:

  • 以 0.8 的概率为高

  • 以 0.2 的概率为低

您的任务是根据此规则模拟每日波动率状态序列。

将序列长度设为 n = 1_000_000,并从高状态开始。

实现一个纯 Python 版本和一个 Numba 版本,并比较速度。

为了测试您的代码,评估链停留在低状态的时间比例。

如果您的代码正确,该比例应约为 2/3。

Exercise 13.3

之前的练习 中,我们使用 Numba 加速了通过蒙特卡洛方法计算常数 \(\pi\) 的工作。

现在尝试添加并行化,看看是否能获得进一步的速度提升。

这里您不应该期望获得巨大的提升,因为虽然有许多独立的任务(抽取点并测试是否在圆内),但每个任务的执行时间都很短。

一般来说,当要并行化的各个任务相对于总执行时间非常小时,并行化效果较差。

这是由于将所有这些小任务分散到多个 CPU 上所带来的开销。

尽管如此,使用合适的硬件,在本练习中仍然可以获得不可忽视的速度提升。

对于蒙特卡洛模拟的规模,请使用一个较大的值,例如 n = 100_000_000

Exercise 13.4

我们关于 SciPy 的讲座 中,我们讨论了在标的股票价格具有简单且众所周知的分布的情况下,如何为看涨期权定价。

这里我们讨论一个更现实的情境。

我们回顾一下,期权的价格满足

\[ P = \beta^n \mathbb E \max\{ S_n - K, 0 \} \]

其中

  1. \(\beta\) 是贴现因子,

  2. \(n\) 是到期日,

  3. \(K\) 是行权价,以及

  4. \(\{S_t\}\) 是标的资产在每个时刻 \(t\) 的价格。

假设 n, β, K = 20, 0.99, 100

假设股票价格满足

\[ \ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} \]

其中

\[ \sigma_t = \exp(h_t), \quad h_{t+1} = \rho h_t + \nu \eta_{t+1} \]

这里 \(\{\xi_t\}\)\(\{\eta_t\}\) 是独立同分布的标准正态随机变量。

(这是一个随机波动率模型,其中波动率 \(\sigma_t\) 随时间变化。)

使用默认值 μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0

(这里 S0\(S_0\)h0\(h_0\)。)

通过生成 \(M\) 条路径 \(s_0, \ldots, s_n\),计算蒙特卡洛估计值

\[ \hat P_M := \beta^n \mathbb E \max\{ S_n - K, 0 \} \approx \frac{1}{M} \sum_{m=1}^M \max \{S_n^m - K, 0 \} \]

即价格,应用 Numba 和并行化。