% JAX: Autograd and XLA
% Jiaao Ho
% July 9, 2020

Features and Components

  • Autograd: automatic differentiation for numpy code
  • JIT: compiles numpy operations to run on GPUs and TPUs
  • vmap: automatic vectorization
  • pmap: automatic SPMD parallelization
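A key property is that these transformations compose freely. A minimal sketch (the function `f` is illustrative, not from the slides):

```python
import jax
import jax.numpy as jnp

f = lambda x: x ** 2          # a simple scalar function

# grad gives df/dx = 2x, vmap maps it over a batch, jit compiles the result
batched_grad_f = jax.jit(jax.vmap(jax.grad(f)))

print(batched_grad_f(jnp.arange(3.0)))  # [0. 2. 4.]
```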

Automatic Differentiation and Training

  • Implemented via the grad function
from jax import grad
import jax.numpy as np

def tanh(x):  # Define a function
  y = np.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743 
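Because grad returns an ordinary Python function, it can be nested to obtain higher-order derivatives. A small sketch building on the tanh example above:

```python
from jax import grad
import jax.numpy as np

def tanh(x):
  y = np.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

# grad(tanh) is itself differentiable, so nesting yields the second derivative
d2_tanh = grad(grad(tanh))
print(d2_tanh(1.0))
```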

Using JIT to Speed Up a Function

  • Use it directly as a function
import jax.numpy as np
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = np.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
  • Use Python's @ decorator syntax
  # get_params, opt_init, opt_update, and loss are assumed to come from
  # an optimizer library such as jax.experimental.optimizers
  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  opt_state = opt_init(init_params)
  for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))

Automatic Vectorization and SPMD

  • Vectorization -> automatically adds a batch dimension
from functools import partial
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
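A self-contained sketch of the in_axes form above, with a hypothetical predict (a single-example dot product, not from the slides):

```python
from jax import vmap
import jax.numpy as np

# Hypothetical single-example model: a dot product
def predict(params, x):
  return np.dot(params, x)

params = np.ones(3)
input_batch = np.arange(12.0).reshape(4, 3)  # batch of 4 examples

# in_axes=(None, 0): broadcast params, map over axis 0 of the batch
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions.shape)  # (4,)
```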
  • Automatically computes across multiple GPUs (an axis must be specified).
from jax import random, pmap
import jax.numpy as np

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: np.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(np.mean)(result))
  • 自动使用集合通信
from functools import partial
from jax import lax
import jax.numpy as np

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(np.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

What’s New Compared with TensorFlow?

  • Simpler logic: everything is plain, readable numerical code
  • Emphasizes JIT compilation to keep computation efficient
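As an illustration of the first point, an entire training loop can be written as ordinary numerical code; the linear model below is a hypothetical example, not from the slides:

```python
from jax import grad
import jax.numpy as np

# Hypothetical linear model and squared-error loss, written as plain functions
def predict(params, x):
  w, b = params
  return w * x + b

def loss(params, x, y):
  return np.mean((predict(params, x) - y) ** 2)

params = (0.0, 0.0)
x = np.arange(4.0)
y = 2.0 * x + 1.0  # data generated by w = 2, b = 1

# Plain gradient descent: grad(loss) returns gradients with the same
# structure as params, so the update is ordinary Python
for _ in range(200):
  dw, db = grad(loss)(params, x, y)
  params = (params[0] - 0.1 * dw, params[1] - 0.1 * db)

print(params)  # approaches (2.0, 1.0)
```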