翻译|杨婷
最近,我在处理 PyTorch 分布式和 TorchRec 相关的工作,为此,我开始学习 PyTorch 2.0。在业余时间,我也在跟着Alpa作者学习JAX和XLA。如今回顾这些技术,我发现它们的关注点似乎都是如下两个问题:
包含自动求导和并行在内的函数转换,例如 vmap, pmap 和 pjit 等;
异构计算,CPU 负责控制流,GPU/TPU 负责张量计算和集合通信。
本
文档中
的所有例子
都支
持在 Colab 中运行
:
TensorFlow 1.x
https://colab.research.google.com/drive/1jc0ePg2AAXBihevtoZM_33mmhC70rzqz?usp=sharing
TensorFlow 2.x
https://colab.research.google.com/drive/1PbftzJ9E2_FyIiuozTpExMvlFky_G2nv
PyTorch 1.x
https://colab.research.google.com/drive/1v4hENL-IJ-C6VT5H9W1NC2te85D8VdJK
https://colab.research.google.com/drive/1PlFijLIzAttIBd3tBjiEbSgPXvq9lVlg
functorch/PyTorch 2.x
https://colab.research.google.com/drive/1o-yJ-5g1V084RDaiRw2PqfAjOG7Ty951
“函数转换”意为将一个程序转变成另一个程序,最常见的例子是自动求导(autograd)。自动求导采用用户编写的前向过程并创建后向过程,对于用户来说,编写自动求导通常都太过复杂。函数转换的主要难点在于:在编写函数转换算法时以何种方式表示输入和输出过程。
Theano:显式地构建 IR
Theano是最早的深度学习工具之一,也就是如今为人们所熟知的Aesara项目。Theano有一个允许用户在内存中将IR构建为数据结构的API,因此Theano可实现自动求导,并将结果输出为 Python 函数。
import aesara
from aesara import tensor as at
a = at.dscalar(
"a"
)
# Define placeholders, which have no values.
b = at.dscalar(
"b"
)
c = a * b
# c now contains the IR of an expression.TT
dc = aesara.grad(c, a)
# Convert the IR in c into another one, dc
f_dc = aesara.function([a, b], dc)
# Convert the IR into a Python function,
assert f_dc(1.5, 2.5) == 2.5
# so we can call it.
TensorFlow 1.x:
用于运行 IR 的虚拟机
TensorFlow 1.x明确保留了构建IR的想法。若在TensorFlow中运行上述示例,结果不会有什么差别;但倘若在TensorFlow 1.x中来运行,最大的差别在于:我们不会将后向 IR 转换为 Python 函数,并使用 Python 解释器来运行。相反,我们会在TensorFlow runtime中来运行。
import tensorflow.compat.v1 as tf
# TensorFlow 1.x API
import numpy as np
tf.disable_eager_execution()
a = tf.placeholder(tf.float32, shape=())
b = tf.placeholder(tf.float32, shape=())
c = a * b
dc = tf.gradients(c, [a], stop_gradients=[a, b])
with tf.compat.v1.Session() as sess:
# TensorFlow has a runtime to execute the IR,
x = np.single(2)
# so, no converting it into Python code.
y = np.single(3)
print(sess.run(dc, feed_dict={a:x, b:y}))
PyTorch 1.x:没有前向IR
PyTorch不会像Theano或TensorFlow那样将前向传播转换为IR。反之,PyTorch 使用 Python 解释器来运行前向传播。这样做的弊端在于会在运行期间生成表示后向传播的 IR,我们称之为Eager模式(动态图模式)。
import torch
a = torch.tensor(1.0, requires_grad=True)
# These are not placeholders, but values.
b = torch.tensor(2.0)
c = a * b
# Evaluates c and derives the IR of the backward in c.grad_fn_.
c.backward()
# Executes c.grad_fn_.
print(c.grad)
TensorFlow 2.x: 梯度带
TensorFlow 2.x增加了一个像PyTorch API的Eager模式API。此 API 追踪前向传播如何运行名为梯度带(GradientTape)的 IR 。TensorFlow 2.x可以从这个跟踪中找出后向传播。
import tensorflow as tf
a = tf.Variable(1.0)
# Like PyTorch, these are values, not placehodlers.
b = tf.Variable(2.0)
with tf.GradientTape() as tape:
c = a * b
dcda = tape.gradient(c, a)
print(dcda)
JAX 不会向用户公开诸如梯度带等方面的低级别细节。简单说来,JAX的思维方式为:将输入和输出都用Python函数来表示。
import jax
a = 2.0
b = 3.0
jax.grad(jax.lax.mul)(a, b)
# Compute c = a * b w.r.t. a. The result is b=3.
jax.jit(jax.grad(jax.lax.mul))(a,b)
jax.experimental.pjit(jax.grad(jax.lax.mul),
device_mesh(ntpus))(a,b)
对于想要自己编写的函数转换的高级用户,他们可以调用
make_jaxpr
等低级 API 来访问 IR,称为 JAXPR。
jax
.make_jaxpr
(
jax
.lax
.mul
)(2
.0
, 3
.0
) #
Returns
the
IR
representing
jax
.lax
.mul
(2,3)
jax
.make_jaxpr
(
jax
.grad
(
jax
.lax
.mul
))(2
.0
, 3
.0
) #
Returns
the
IR
of
grad
(
mul
)(2,3)
FuncTorch
FuncTorch
和JAX类似,都是基于PyTorch的函数转换。
import torch, functorch
a = torch.tensor([2.0])
b = torch.tensor([3.0])
functorch.grad(torch.dot)(a, b)
JAX的
make_jaxpr
类似于functorch的
make_fx
。
def f(a, b):
return
torch.dot(a, b) # Have to wrap the builtin
function
dot
into
f
. # 必须将内置函数
dot
转换成
f
.
print
(
functorch.make_fx(f)(a, b).code
)
print
(
functorch.make_fx(functorch.grad(f))(a, b).code
)
TensorFlow 2.x、JAX 和 functorch 都为前向传递构建了一个 IR,但 PyTorch Eager模式没有。IR 不仅可用于自动求导,还可用于其他类型的函数转换。在下列例子中,
functorch.compile.aot_function
调用了回调函数
print_compile_fn
两次,分别用于前向和后向传播。
from
functorch.compile
import
aot_function
import
torch.fx
as
fx
def
print_compile_fn
(fx_module, args)
:
print(fx_module)
return
fx_module
aot_fn = aot_function(torch.dot, print_compile_fn)
aot_fn(a, b)
2
高阶导数
PyTorch
import torch
from torch import autograd
x = torch.tensor(1., requires_grad = True)
y = 2*x**3 + 8
first_derivative = autograd.grad(y, x, create_graph=True)
print(first_derivative)
second_derivative = autograd.grad(first_derivative, x)
print(second_derivative)
TensorFlow 2.x
import
tensorflow
as
tf
x = tf.Variable(
1.0
)
with
tf.GradientTape()
as
outer_tape:
with
tf.GradientTape()
as
tape:
y =
2
*x**
3
+
8
dy_dx = tape.gradient(y, x)
print(dy_dx)
d2y_dx2 = outer_tape.gradient(dy_dx, x)
print(d2y_dx2)
def
f
(a)
:
return
2
*a**
3
+
8
print(jax.grad(f)(
1.0
))
print(jax.grad(jax.grad(f))(
1.0
))
3
动态控制流
动态控制流(dynamic control flows)有两个层级:在 CPU 上运行的粗粒度级别和在 GPU /TPU 上运行的细粒度级别。本部分主要介绍在 CPU 上运行的粗粒度级别的动态控制流。下面我们将用(if/else)条件语句作为例子检验深度学习工具。
TensorFlow 1.x
在 TensorFlow 1.x 中,我们需要将条件语句显式构建到 IR 中。此时条件语句是一个特殊的运算符
tf.cond
。
def
f1
()
:
return
tf.multiply(a,
17
)
def
f2
()
:
return
tf.add(b,
23
)
r = tf.cond(tf.less(a, b), f1, f2)
with
tf.compat.v1.Session()
as
sess:
# TensorFlow has a runtime to execute the IR,
print(sess.run(r, feed_dict={a:x, b:y}))
TensorFlow 2.x
TensorFlow 2.x 支持使用
tf.cond
和
tf.while_loop
显式构建控制流。此外,实验项目google/tangent中有AutoGraph功能,它可以将Python控制流转换为
tf.cond
或
tf.while_loop
。此功能利用了 Python 解释器支持的函数和函数源代码。例如下面的g函数调用了 Python 的标准库将源代码解析为 AST,然后调用 SSA 表单来理解控制流。
def
g
(x, y)
:
if
tf.reduce_any(x < y):
return
tf.multiply(x,
17
)
return
tf.add(y,
23
)
converted_g = tf.autograph.to_graph(g)
import
inspect
print(inspect.getsource(converted_g))
由于部分Python语法很复杂,所以通过解析源代码来理解控制流就显得很困难,这就导致AutoGraph经常出错。但如果这种方法很简单,那么Python开发者社区也不会在构建Python编译器时失败这么多次了。正是由于有这种挑战的存在,必须要明确地将控制流构建到 IR 中。为此,JAX 提供了
jax.lax.cond
和
jax.lax.for_loop
函数。
jax
.lax
.cond
(
a
<
b
,
lambda
:
a
*17,
lambda
:
b
+23)
考虑到这一点,你可能会觉得我们可以使用递归算法。但是下面用于计算阶乘的递归无法用JAX跟踪。
def
factorial
(r, x)
:
return
jax.lax.cond(x <=
1.0
,
lambda
: r,
lambda
: factorial(r*x, x
-1
))
factorial(
1.0
,
3.0
)
可能你还想调用
factorial
来计算
3!=6
。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。
PyTorch
PyTorch最初是Python-native。正如前文所说,由于多功能调度机制,
grad
和
vamp
的函数转换都是即时的。值得注意的是:
相比Theano 和 TensorFlow构建IR后的函数转换,即时函数转换效率更高。
在进行
grad
和
vmap
时,JAX也是即时函数转换。
然而像
pamp
和
pjit
等更复杂的函数转换需要对整个计算过程进行概述,在这个过程中IR是必不可少的。
由于IR在
pmap
和
pjit
中的必要性,PyTorch社区最近添加了
torch.cond
pytorch/pytorch#83154
4
分布式计算
根据执行代码或 IR 的不同方式,在使用 Python 解释器或runtime时,有两种分布式计算方法。
Python-Native
Theano和PyTorch采用了Python-native分布式计算方式。这种分布式训练工作包含多个Python解释器进程。这导致出现了以下结果。
打包和运行(Pack and run)。
由于这些 Python 进程在不同的host上运行,因此我们需要打包用户程序和依赖项,并将它们发送到这些host上去运行。
一直以来TorchX负责了这个打包过程。
它支持例如Docker和torch.package等各种打包格式,并且可以与各种集群管理器配合使用,如Kubernetes和SLURM。
单程序多数据(SPMD)。
由于将用户程序发送到各种host上要依赖于打包,与其他权重较轻的方式(如通过 RPC 发送代码)相比,这种方式不太灵活,因此,我们通常只发送一个程序。
当所有这些进程运行同一程序时,这个作业就变成了单程序多数据(SPMD)作业。
Python-native SPMD
下面是一个简单的SPMD PyTorch程序,我们可以在相同或不同的host上使用进程运行这个程序。在这个过程中,我们只需要调用
all_gather
。真正的分布式训练程序会调用更高级别的API,例如
torch.nn.parallel.DistributedDataParallel
和
torchrec.DistributedModelParallel
, 然后再调用低级 API,例如
all_gather
和
all_reduce
。
import
os
import
torch
from
torch
import
distributed
as
dist
def
main
()
:
use_gpu = torch.cuda.is_available()
local_rank = int(os.environ.get(
"LOCAL_RANK"
,
"0"
))
local_world_size = int(os.environ.get(
"LOCAL_WORLD_SIZE"
,
"0"
))
device = torch.device(
f"cuda:
{local_rank}
"
if
use_gpu
else
"cpu"
)
dist.init_distributed(backend=
"nccl"
)
lst = torch.tensor([local_rank +
100
]).to(device)
# placeholder
rlt_lst = [torch.zeros_like(lst)
for
_
in
range(local_world_size)]
dist.all_gather(rlt_lst, lst, async_op=
False
)
print(
"After broadcasting:"
, rlt_lst)
Python-native Non-SPMD
PyTorch 不仅限于 SPMD 式的分布式训练。它还通过
torch.distributed.pipeline.sync.Pipe
和PiPPy project提供流水并行,其中流水并行的各个阶段在不同的设备上运行不同的程序。这些阶段常通过
torch.rpc
包来沟通。
分布式运行时机制
分布式 TensorFlow 作业由运行 TensorFlow runtime 程序的进程组成,而不是由 Python 解释器组成。此分布式运行时作业执行 TensorFlow graph (IR),它是由执行用户程序的 Python 解释器生成。
用户程序可以使用低级API(如
tf.device
)去指定作业要运行什么操作、在哪台设备和主机上运行等等。因为API有runtime,所以可以做到这一点。
with
tf.device(
'/job:bar/task:0/device:gpu:2'
):
# ops created here have the fully specified device above
与PyTorch一样,TensorFlow也为分布式训练提供了高级API
tf.distributed.strategy
,Keras和DTensor。
strategy = tf.distribute.MirroredStrategy() \
if
tf.config.list_physical_devices(
'GPU'
) \
else
tf.distribute.get_strategy()
with
strategy.scope():
model = tf.keras.Sequential([tf.keras.layers.Dense(
1
, input_shape=(
1
,))])
model.compile(loss=
'mse'
, optimizer=
'sgd'
)
分布式运行时极大地方便了训练服务的维护,因为我们不再将用户程序打包到集群上运行。相反,我们打包运行时程序,因为相比用户程序,运行时程序更加统一。
JAX 支持 Python-native 和分布式运行时。
JAX 提供例如
vmap
、
pmap
和
pjit
的函数转换,这可以将 Python 函数转换为分布式程序。
(本文经授权后由OneFlow社区编译,译文转载请联系获得授权。原文:
https://quip.com/Y8qtAyV4EXRg)
其他人都在看
下载量突破10亿,MinIO的开源启示录
关于ChatGPT的一切;CUDA入门之矩阵乘
李白:你的模型权重很不错,可惜被我没收了
单RTX 3090训练YOLOv5s,时间减少11小时
OpenAI掌门Sam Altman:AI下一个发展阶段
比快更快,开源Stable Diffusion刷新作图速度
OneEmbedding:单卡
训练TB级推荐模型不是梦