% JAX: Autograd and XLA
% Jiaao Ho
% July 9, 2020
Features and Components
- Autograd: automatic differentiation of numpy code
- JIT: compiles numpy operations to run on GPUs and TPUs
- vmap: automatic vectorization
- pmap: automatic SPMD parallelism
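These transformations compose freely. A minimal sketch of composing them (the function `f` and the batch shape below are made up for illustration):

```python
import jax.numpy as np
from jax import grad, jit, vmap

def f(x):
    return np.sum(np.tanh(x) ** 2)

# Differentiate f, vectorize the gradient over a batch, and JIT-compile the result
batched_grad_f = jit(vmap(grad(f)))
print(batched_grad_f(np.ones((10, 3))).shape)  # (10, 3)
```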
Automatic Differentiation and Training
- Implemented with the `grad` function
```python
from jax import grad
import jax.numpy as np

def tanh(x):  # Define a function
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```
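Because `grad` returns an ordinary Python function, it can be nested to take higher-order derivatives of the `tanh` defined above:

```python
# Third derivative of tanh at x = 1.0
print(grad(grad(grad(tanh)))(1.0))
# prints roughly 0.6216
```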
Using JIT to Make a Function Faster
- Used directly as a function
```python
import jax.numpy as np
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = np.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```
- Used as a Python `@` decorator
```python
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)
for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
```
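The snippet above assumes an optimizer triple (`opt_init`, `opt_update`, `get_params`) and a `loss` defined elsewhere. A minimal sketch of where they might come from, assuming the `jax.experimental.optimizers` module that shipped with JAX at the time and a made-up least-squares loss:

```python
import jax.numpy as np
from jax.experimental import optimizers

def loss(params, batch):
    # Hypothetical least-squares loss over an (inputs, targets) batch
    inputs, targets = batch
    preds = np.dot(inputs, params)
    return np.mean((preds - targets) ** 2)

# A JAX optimizer is a triple of pure functions
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
init_params = np.zeros(3)
```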
Automatic Vectorization and SPMD
- Vectorization -> automatically adds a batch dimension
```python
from functools import partial
from jax import vmap

# predict(params, inputs) and input_batch are assumed to be defined elsewhere
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```
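A common use of `in_axes` is computing per-example gradients: map over the batch axis of the data while broadcasting the parameters. A self-contained sketch with a toy loss (all names and shapes here are made up for illustration):

```python
import jax.numpy as np
from jax import grad, vmap

def example_loss(params, x, y):   # squared error for a single example
    return (np.dot(x, params) - y) ** 2

params = np.ones(3)
xs, ys = np.ones((8, 3)), np.zeros(8)   # a hypothetical batch of 8 examples

# Broadcast params (None), map over the leading axis (0) of xs and ys
per_example_grads = vmap(grad(example_loss), in_axes=(None, 0, 0))(params, xs, ys)
print(per_example_grads.shape)  # (8, 3)
```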
- Automatically runs computation across multiple GPUs (a mapped axis must be specified).
```python
from jax import random, pmap
import jax.numpy as np

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: np.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(np.mean)(result))
```
- Collective communication is handled automatically
```python
from functools import partial
import jax.numpy as np
from jax import lax, pmap

@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

print(normalize(np.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
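The same `axis_name` mechanism gives data-parallel training: each device computes a gradient on its local shard, and the gradients are averaged with a collective mean. A sketch, with a toy loss and made-up shapes:

```python
from functools import partial
import jax
import jax.numpy as np
from jax import grad, lax, pmap, random

def loss(params, x):                      # toy loss, for illustration only
    return np.sum(np.dot(x, params) ** 2)

@partial(pmap, axis_name='devices')
def data_parallel_grad(params, x):
    g = grad(loss)(params, x)             # gradient on this device's shard
    return lax.pmean(g, 'devices')        # average gradients across devices

n = jax.local_device_count()
params = np.stack([np.ones(3)] * n)               # replicate params on each device
x = random.normal(random.PRNGKey(0), (n, 16, 3))  # one data shard per device
print(data_parallel_grad(params, x).shape)        # (n, 3)
```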
What’s New Compared with TensorFlow?
- Simpler logic: everything is plain, intuitive numerical code
- Emphasizes the use of JIT to keep computation efficient