14. JAX#

本讲座简要介绍 Google JAX

GPU

本讲座使用了配备 GPU 的机器进行构建——但没有 GPU 也可以运行。

Google Colab 提供了免费的 GPU 访问, 您可以按照以下步骤使用:

  1. 点击右上角的”play”图标

  2. 选择 Colab

  3. 将运行时环境设置为包含 GPU

JAX 是一个高性能科学计算库,提供以下功能:

  • 类似 NumPy 的接口,可以在 CPU 和 GPU 上自动并行化,

  • 一个即时编译器,用于加速大量数值运算,以及

  • 自动微分

JAX 也在日益维护和提供 更多专业化的科学计算例程,例如那些最初在 SciPy 中找到的例程。

除了 Anaconda 中已有的内容外,本讲座还需要以下库:

!pip install jax quantecon

Hide code cell output

Requirement already satisfied: jax in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (0.10.0)
Collecting quantecon
  Downloading quantecon-0.11.2-py3-none-any.whl.metadata (5.3 kB)
Requirement already satisfied: jaxlib<=0.10.0,>=0.10.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from jax) (0.10.0)
Requirement already satisfied: ml_dtypes>=0.5.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from jax) (0.5.4)
Requirement already satisfied: numpy>=2.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from jax) (2.3.5)
Requirement already satisfied: opt_einsum in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from jax) (3.4.0)
Requirement already satisfied: scipy>=1.14 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from jax) (1.16.3)
Requirement already satisfied: numba>=0.49.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (0.62.1)
Requirement already satisfied: requests in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (2.32.5)
Requirement already satisfied: sympy in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from quantecon) (1.14.0)
Requirement already satisfied: llvmlite<0.46,>=0.45.0dev0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from numba>=0.49.0->quantecon) (0.45.1)
Requirement already satisfied: charset_normalizer<4,>=2 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from requests->quantecon) (2025.11.12)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages (from sympy->quantecon) (1.3.0)
Downloading quantecon-0.11.2-py3-none-any.whl (330 kB)
Installing collected packages: quantecon
Successfully installed quantecon-0.11.2

我们将使用以下导入:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import quantecon as qe

注意我们导入了 jax.numpy as jnp,它提供了类似 NumPy 的接口。

14.1. JAX 作为 NumPy 的替代品#

JAX 的一个吸引人之处在于,它的数组处理操作在尽可能的情况下遵循 NumPy API。

这意味着在许多情况下,我们可以将 JAX 作为 NumPy 的直接替代品使用。

让我们来看看 JAX 和 NumPy 之间的异同。

14.1.1. 相似之处#

以下是使用 jnp 进行的一些标准数组操作:

a = jnp.asarray((1.0, 3.2, -1.5))
print(a)
[ 1.   3.2 -1.5]
print(jnp.sum(a))
2.6999998
print(jnp.dot(a, a))
13.490001

然而,数组对象 a 并不是 NumPy 数组:

a
Array([ 1. ,  3.2, -1.5], dtype=float32)
type(a)
jaxlib._jax.ArrayImpl

即使是数组上的标量值映射也会返回 JAX 数组,而不是标量!

jnp.sum(a)
Array(2.6999998, dtype=float32)

14.1.2. 差异#

现在让我们来看看 JAX 和 NumPy 数组操作之间的一些差异。

14.1.2.1. 速度!#

假设我们想在许多点上计算余弦函数。

n = 50_000_000
x = np.linspace(0, 10, n)
14.1.2.1.1. 使用 NumPy#

让我们先用 NumPy 试试:

with qe.Timer():
    # First NumPy timing
    y = np.cos(x)
0.6347 seconds elapsed

再来一次。

with qe.Timer():
    # Second NumPy timing
    y = np.cos(x)
0.6293 seconds elapsed

这里

  • NumPy 使用预编译的二进制文件对浮点数数组应用余弦函数

  • 该二进制文件在本地机器的 CPU 上运行

14.1.2.1.2. 使用 JAX#

现在让我们用 JAX 试试。

x = jnp.linspace(0, 10, n)

让我们对相同的过程计时。

with qe.Timer():
    # First run
    y = jnp.cos(x)
    # Hold the interpreter until the array operation finishes
    jax.block_until_ready(y);
0.1197 seconds elapsed

Note

这里,为了测量实际速度,我们使用 block_until_ready 方法来阻塞解释器,直到计算结果返回。

这是必要的,因为 JAX 使用异步调度,允许 Python 解释器在数值计算之前运行。

对于非计时代码,可以删除包含 block_until_ready 的那一行。

再来计时一次。

with qe.Timer():
    # Second run
    y = jnp.cos(x)
    # Hold interpreter 
    jax.block_until_ready(y);
0.0835 seconds elapsed

在 GPU 上,此代码的运行速度远快于其 NumPy 等效代码。

此外,通常第二次运行比第一次更快,这是由于 JIT 编译的缘故。

这是因为即使是像 jnp.cos 这样的内置函数也是经过 JIT 编译的——第一次运行包含了编译时间。

为什么 JAX 要对像 jnp.cos 这样的内置函数进行 JIT 编译,而不是像 NumPy 那样直接提供预编译版本?

原因是 JIT 编译器希望针对所使用数组的大小(以及数据类型)进行专门优化。

大小对于生成优化代码很重要,因为高效的并行化需要将任务大小与可用硬件相匹配。

14.1.2.2. 大小实验#

我们可以通过更改输入大小并观察运行时间来验证 JAX 针对数组大小进行专门化的说法。

x = jnp.linspace(0, 10, n + 1)
with qe.Timer():
    # First run
    y = jnp.cos(x)
    # Hold interpreter
    jax.block_until_ready(y);
0.1345 seconds elapsed
with qe.Timer():
    # Second run
    y = jnp.cos(x)
    # Hold interpreter
    jax.block_until_ready(y);
0.0859 seconds elapsed

运行时间先增加后减少(这在 GPU 上会更明显)。

这与上面的讨论一致——更改数组大小后的第一次运行显示了编译开销。

关于 JIT 编译的进一步讨论见下文。

14.1.2.3. 精度#

NumPy 和 JAX 之间的另一个差异是 JAX 默认使用 32 位浮点数。

这是因为 JAX 经常用于 GPU 计算,而大多数 GPU 计算使用 32 位浮点数。

使用 32 位浮点数可以在精度损失很小的情况下带来显著的速度提升。

然而,对于某些计算,精度至关重要。

在这些情况下,可以通过以下命令强制使用 64 位浮点数:

jax.config.update("jax_enable_x64", True)

让我们验证这是否有效:

jnp.ones(3)
Array([1., 1., 1.], dtype=float64)

14.1.2.4. 不可变性#

作为 NumPy 的替代品,一个更显著的差异是数组被视为不可变的

例如,在 NumPy 中我们可以这样写:

a = np.linspace(0, 1, 3)
a
array([0. , 0.5, 1. ])

然后在内存中修改数据:

a[0] = 1
a
array([1. , 0.5, 1. ])

在 JAX 中,这会失败!

a = jnp.linspace(0, 1, 3)
a
Array([0. , 0.5, 1. ], dtype=float64)
try:
    a[0] = 1
except Exception as e:
    print(e)
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX 的设计者选择将数组设为不可变的,因为 JAX 使用函数式编程风格,我们将在下面讨论这一点。

14.1.2.5. 变通方法#

我们注意到 JAX 确实提供了一种替代原地数组修改的方式,使用 at 方法

a = jnp.linspace(0, 1, 3)

应用 at[0].set(1) 会返回一个新的 a 的副本,其中第一个元素被设置为 1:

a = a.at[0].set(1)
a
Array([1. , 0.5, 1. ], dtype=float64)

显然,使用 at 有一些缺点:

  • 语法繁琐,且

  • 每次更改单个值时,我们都希望避免在内存中创建新数组!

因此,在大多数情况下,我们尽量避免使用这种语法。

(尽管它在 JIT 编译的函数中实际上可以很高效——但现在先把这个放在一边。)

14.2. 函数式编程#

来自 JAX 的文档:

当在意大利乡间漫步时,当地人会毫不犹豫地告诉你 JAX 有”una anima di pura programmazione funzionale”(纯函数式编程的灵魂)。

换句话说,JAX 假设采用 函数式编程 风格。

14.2.1. 纯函数#

最主要的含义是 JAX 函数应该是纯函数。

纯函数具有以下特征:

  1. 确定性

  2. 无副作用

确定性意味着:

  • 相同输入 \(\implies\) 相同输出

  • 输出不依赖于全局状态

特别地,纯函数在使用相同输入调用时将始终返回相同的结果。

无副作用意味着函数:

  • 不会改变全局状态

  • 不会修改传递给函数的数据(不可变数据)

14.2.2. 示例——纯函数与非纯函数#

以下是一个非纯函数的示例:

tax_rate = 0.1

def add_tax(prices):
    for i, price in enumerate(prices):
        prices[i] = price * (1 + tax_rate)

prices = [10.0, 20.0]
add_tax(prices)
prices
[11.0, 22.0]

这个函数不是纯函数,因为:

  • 副作用——它修改了全局变量 prices

  • 非确定性——对全局变量 tax_rate 的更改会修改函数输出,即使使用相同的输入数组 prices

以下是一个版本:

def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices

tax_rate = 0.1
prices = (10.0, 20.0)
after_tax_prices = add_tax_pure(prices, tax_rate)
after_tax_prices
[11.0, 22.0]

这是纯函数,因为:

  • 所有依赖关系通过函数参数显式传递

  • 并且不修改任何外部状态

14.2.3. 为什么要函数式编程?#

在 QuantEcon,我们热爱纯函数,因为它们:

  • 有助于测试:每个函数可以独立运行

  • 促进确定性行为,从而提高可重复性

  • 防止由于修改共享状态而产生的错误

JAX 编译器热爱纯函数和函数式编程,因为:

  • 数据依赖关系是显式的,有助于优化复杂计算

  • 纯函数更易于微分(自动微分)

  • 纯函数更易于并行化和优化(不依赖于共享的可变状态)

另一种理解方式如下:

JAX 将函数表示为计算图,然后对其进行编译或变换(例如,微分)。

这些计算图描述了给定的一组输入如何被转换为输出。

JAX 的计算图在构造上是纯粹的。

JAX 使用函数式编程风格,以便用户构建的函数能够直接映射到 JAX 所支持的图论表示中。

14.3. 随机数#

JAX 中的随机数生成与 NumPy 或 MATLAB 中的模式有很大不同。

14.3.1. NumPy / MATLAB 方法#

在 NumPy / MATLAB 中,生成通过维护隐藏的全局状态来工作。

np.random.seed(42)
print(np.random.randn(2))   
[ 0.49671415 -0.1382643 ]

每次我们调用随机函数时,隐藏状态都会被更新:

print(np.random.randn(2)) 
[0.64768854 1.52302986]

这个函数不是纯函数,因为:

  • 它是非确定性的:相同的输入,不同的输出

  • 它有副作用:它修改了全局随机数生成器状态

在并行化下很危险——必须仔细控制每个线程中发生的事情。

14.3.2. JAX#

在 JAX 中,随机数生成器的状态被显式控制。

首先我们生成一个密钥,它为随机数生成器提供种子。

seed = 1234
key = jax.random.key(seed)

现在我们可以使用密钥生成一些随机数:

x = jax.random.normal(key, (3, 3))
x
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

如果我们再次使用相同的密钥,我们会以相同的种子初始化,因此随机数是相同的:

jax.random.normal(key, (3, 3))
Array([[-0.54019824,  0.43957585, -0.01978102],
       [ 0.90665474, -0.90831359,  1.32846635],
       [ 0.20408174,  0.93096529,  3.30373914]], dtype=float64)

要生成(准)独立的抽取,一种选择是”分裂”现有密钥:

key, subkey = jax.random.split(key)
jax.random.normal(key, (3, 3))
Array([[ 1.24104247,  0.12018902, -2.23990047],
       [ 0.70507261, -0.85702845, -1.24582014],
       [ 0.38454486,  1.32117717,  0.56866901]], dtype=float64)
jax.random.normal(subkey, (3, 3))
Array([[ 0.07627173, -1.30349831,  0.86524323],
       [-0.75550773,  0.63958052,  0.47052126],
       [-1.72866044, -1.14696564, -1.23328892]], dtype=float64)

下图说明了 split 如何从单个根密钥生成密钥树,每个密钥生成独立的随机抽取。

Hide code cell source

fig, ax = plt.subplots(figsize=(8, 4))
ax.set_xlim(-0.5, 6.5)
ax.set_ylim(-0.5, 3.5)
ax.set_aspect('equal')
ax.axis('off')

box_style = dict(boxstyle="round,pad=0.3", facecolor="white",
                 edgecolor="black", linewidth=1.5)
box_used = dict(boxstyle="round,pad=0.3", facecolor="#d4edda",
                edgecolor="black", linewidth=1.5)

# Root key
ax.text(3, 3, "key₀", ha='center', va='center', fontsize=11,
        bbox=box_style)

# Level 1
ax.annotate("", xy=(1.5, 2), xytext=(3, 2.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(4.5, 2), xytext=(3, 2.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(1.5, 2, "key₁", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(4.5, 2, "subkey₁", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(5.7, 2, "→ draw", ha='left', va='center', fontsize=10,
        color='green')

# Label the split
ax.text(2, 2.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

# Level 2
ax.annotate("", xy=(0.5, 1), xytext=(1.5, 1.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(2.5, 1), xytext=(1.5, 1.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(0.5, 1, "key₂", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(2.5, 1, "subkey₂", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(3.7, 1, "→ draw", ha='left', va='center', fontsize=10,
        color='green')

ax.text(0.7, 1.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

# Level 3
ax.annotate("", xy=(0, 0), xytext=(0.5, 0.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.annotate("", xy=(1.5, 0), xytext=(0.5, 0.7),
            arrowprops=dict(arrowstyle="->", lw=1.5))
ax.text(0, 0, "key₃", ha='center', va='center', fontsize=11,
        bbox=box_style)
ax.text(1.5, 0, "subkey₃", ha='center', va='center', fontsize=11,
        bbox=box_used)
ax.text(2.7, 0, "→ draw", ha='left', va='center', fontsize=10,
        color='green')
ax.text(0, 0.65, "split", ha='center', va='center', fontsize=9,
        fontstyle='italic', color='gray')

ax.text(3, -0.5, "⋮", ha='center', va='center', fontsize=14)

ax.set_title("PRNG 密钥拆分树", fontsize=13, pad=10)
plt.tight_layout()
plt.show()
/tmp/ipykernel_2677/3497588807.py:64: UserWarning: Glyph 23494 (\N{CJK UNIFIED IDEOGRAPH-5BC6}) missing from font(s) DejaVu Sans.
  plt.tight_layout()
/tmp/ipykernel_2677/3497588807.py:64: UserWarning: Glyph 38053 (\N{CJK UNIFIED IDEOGRAPH-94A5}) missing from font(s) DejaVu Sans.
  plt.tight_layout()
/tmp/ipykernel_2677/3497588807.py:64: UserWarning: Glyph 25286 (\N{CJK UNIFIED IDEOGRAPH-62C6}) missing from font(s) DejaVu Sans.
  plt.tight_layout()
/tmp/ipykernel_2677/3497588807.py:64: UserWarning: Glyph 20998 (\N{CJK UNIFIED IDEOGRAPH-5206}) missing from font(s) DejaVu Sans.
  plt.tight_layout()
/tmp/ipykernel_2677/3497588807.py:64: UserWarning: Glyph 26641 (\N{CJK UNIFIED IDEOGRAPH-6811}) missing from font(s) DejaVu Sans.
  plt.tight_layout()
/home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 23494 (\N{CJK UNIFIED IDEOGRAPH-5BC6}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
/home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 38053 (\N{CJK UNIFIED IDEOGRAPH-94A5}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
/home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 25286 (\N{CJK UNIFIED IDEOGRAPH-62C6}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
/home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 20998 (\N{CJK UNIFIED IDEOGRAPH-5206}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
/home/runner/miniconda3/envs/quantecon/lib/python3.13/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 26641 (\N{CJK UNIFIED IDEOGRAPH-6811}) missing from font(s) DejaVu Sans.
  fig.canvas.print_figure(bytes_io, **kw)
_images/7b315a643ff853da32b2e575302f40974d307752effe36641f4f1093fd5c1fe7.png

对于 NumPy 或 Matlab 用户来说,这种语法看起来很不寻常——但当我们进入并行编程时,就会很有意义。

下面的函数使用 split 生成 k 个(准)独立的随机 n x n 矩阵。

def gen_random_matrices(
        key,   # JAX key for random numbers
        n=2,   # Matrices will be n x n
        k=3    # Number of matrices to generate
    ):
    matrices = []
    for _ in range(k):
        key, subkey = jax.random.split(key)
        A = jax.random.uniform(subkey, (n, n))
        matrices.append(A)
    return matrices
seed = 42
key = jax.random.key(seed)
gen_random_matrices(key)
[Array([[0.74211901, 0.54715578],
        [0.05988742, 0.32206803]], dtype=float64),
 Array([[0.65877976, 0.57087415],
        [0.97301903, 0.10138266]], dtype=float64),
 Array([[0.68745522, 0.25974132],
        [0.06595873, 0.83589118]], dtype=float64)]

这个函数是纯函数

  • 确定性的:相同的输入,相同的输出

  • 无副作用:没有隐藏状态被修改

14.3.3. 好处#

如上所述,这种显式性是很有价值的:

  • 可复现性:通过重用密钥轻松重现结果

  • 并行化:控制各个线程上发生的事情

  • 调试:没有隐藏状态使代码更容易测试

  • JIT 兼容性:编译器可以更积极地优化纯函数

14.4. JIT 编译#

JAX 的即时(JIT)编译器通过生成随任务大小和硬件变化的高效机器码来加速执行。

我们在 上文 中已经看到了 JAX 的 JIT 编译器结合并行硬件的强大之处,当时我们对一个大数组应用了 cos 函数。

这里我们研究更复杂函数的 JIT 编译。

14.4.1. 使用 NumPy#

我们先用 NumPy 试试,使用:

def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
    return y

用较大的 x 运行:

n = 50_000_000
x = np.linspace(0, 10, n)
with qe.Timer():
    # Time NumPy code
    y = f(x)
2.2300 seconds elapsed

急切(Eager)执行模型

  • 每个操作在遇到时立即执行,在下一个操作开始之前将其结果实体化。

缺点

  • 并行化程度最低

  • 较大的内存占用——产生许多中间数组

  • 大量内存读写

14.4.2. 使用 JAX#

作为第一步,我们将整个代码中的 np 替换为 jnp

def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
    return y


x = jnp.linspace(0, 10, n)

现在让我们计时。

with qe.Timer():
    # First call
    y = f(x)
    # Hold interpreter
    jax.block_until_ready(y);
1.1239 seconds elapsed
with qe.Timer():
    # Second call
    y = f(x)
    # Hold interpreter
    jax.block_until_ready(y);
0.9014 seconds elapsed

结果与 cos 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。

这是因为单个数组操作在 GPU 上并行化了。

但我们仍在使用即时执行:

  • 由于中间数组导致大量内存占用

  • 大量内存读写

此外,GPU 上还会启动许多独立的内核。

14.4.3. 编译整个函数#

幸运的是,使用 JAX,我们还有另一个技巧——我们可以对整个函数进行 JIT 编译,而不仅仅是单个操作。

编译器将所有数组运算融合到单个优化内核中。

让我们用函数 f 来试试这个:

f_jax = jax.jit(f)
with qe.Timer():
    # First run
    y = f_jax(x)
    # Hold interpreter
    jax.block_until_ready(y);
0.6331 seconds elapsed
with qe.Timer():
    # Second run
    y = f_jax(x)
    # Hold interpreter
    jax.block_until_ready(y);
0.6049 seconds elapsed

运行时间再次改善——现在是因为我们融合了所有操作:

  • 基于整个计算序列的积极优化

  • 消除对硬件加速器的多次调用

内存占用也大大降低——不再创建中间数组。

顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:

@jax.jit
def f(x):
    pass # put function body here

14.4.4. JIT 编译的工作原理#

当我们对一个函数应用 jax.jit 时,JAX 会对其进行追踪:它不会立即执行操作,而是将操作序列记录为计算图,并将该图交给 XLA 编译器。

XLA 随后将这些操作融合并优化为针对可用硬件(CPU、GPU 或 TPU)定制的单个编译内核。

对 JIT 编译函数的第一次调用会产生编译开销,但对于具有相同输入形状和类型的后续调用,将重用缓存的编译代码并以全速运行。

14.4.5. 编译非纯函数#

虽然 JAX 在编译非纯函数时通常不会抛出错误,但执行会变得不可预测!

以下是一个例子:

a = 1  # global

@jax.jit
def f(x):
    return a + x
x = jnp.ones(2)
f(x)
Array([2., 2.], dtype=float64)

在上面的代码中,全局值 a=1 被融入了 JIT 编译的函数中。

即使我们更改 a,只要调用的是相同的编译版本,f 的输出也不会受到影响。

a = 42
f(x)
Array([2., 2.], dtype=float64)

更改输入的维度会触发函数的重新编译,此时 a 值的变化才会生效:

x = jnp.ones(3)
f(x)
Array([43., 43., 43.], dtype=float64)

这个故事的寓意:使用 JAX 时请编写纯函数!

14.5. 使用 vmap 进行向量化#

JAX 的另一个强大变换是 jax.vmap,它能够自动将针对单个输入编写的函数向量化,使其可以对批量数据进行操作。

这样就无需手动编写向量化代码或使用显式循环。

14.5.1. 一个简单的示例#

假设我们有一个函数,用于计算一组数字的均值与中位数之差。

def mm_diff(x):
    return jnp.mean(x) - jnp.median(x)

我们可以将其应用于单个向量:

x = jnp.array([1.0, 2.0, 5.0])
mm_diff(x)
Array(0.66666667, dtype=float64)

现在假设我们有一个矩阵,希望对每一行计算这些统计量。

不使用 vmap 时,我们需要显式循环:

X = jnp.array([[1.0, 2.0, 5.0],
               [4.0, 5.0, 6.0],
               [1.0, 8.0, 9.0]])

for row in X:
    print(mm_diff(row))
0.6666666666666665
0.0
-2.0

然而,Python 循环速度较慢,无法被 JAX 高效编译或并行化。

使用 vmap,我们可以避免循环,并将计算保留在加速器上:

batch_mm_diff = jax.vmap(mm_diff)    # Create a new "vectorized" version
batch_mm_diff(X)                     # Apply to each row of X
Array([ 0.66666667,  0.        , -2.        ], dtype=float64)

14.5.2. 组合变换#

JAX 的优势之一在于各种变换可以自然地组合使用。

例如,我们可以对向量化函数进行 JIT 编译:

fast_batch_mm_diff = jax.jit(jax.vmap(mm_diff))
fast_batch_mm_diff(X)
Array([ 6.66666667e-01, -2.77555756e-16, -2.00000000e+00], dtype=float64)

jitvmap 以及(我们接下来将看到的)grad 的这种组合是 JAX 设计的核心,使其在科学计算和机器学习领域尤为强大。

14.6. 练习#

Exercise 14.1

在关于 Numba 的讲座的练习部分,我们使用蒙特卡洛方法为欧式看涨期权定价

该代码通过基于 Numba 的多线程进行了加速。

尝试使用所有相同的参数为 JAX 编写此操作的版本。