相关文章推荐
开朗的楼梯  ·  python - ...·  1 年前    · 
完美的牛肉面  ·  Delphi VCL ...·  1 年前    · 

为什么jax.numpy.dot()在CPU上的运行速度比numpy.dot()慢?

2 人关注

我想用JAX在CPU上加速我的numpy代码,然后在GPU上加速。这是我在本地计算机上运行的示例代码(只有CPU)。

import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time
size = 3000
key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)
start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))
x2=np.random.normal((size,size))
start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

我得到了一个警告,时间成本如下。

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

我在这里做错了什么吗?JAX是否也应该在CPU上加速numpy代码?

5 个评论
很有可能,numpy正在使用(Open-)。 BLAS 而对于 np.dot() 来说,没有太多需要优化的地方。
@sascha 但JAX比NumPy慢得多,这说不通。我还没有搞清楚原因。
This does not surprise me. Dot (vector or matmul; does not matter) is completely hand-coded for every kind of cpu-arch and you won't beat this with automatic-compilers. Without much knowledge about JAX, it's probably about scheduling, optimizing temporaries and some other stuff -> resulting in great code when multiple "kernels" are fused . But dot 是如此重要,以至于没有机会达到它。补充说明:我提到了openblas:但在一些地区,Intels MKL被使用:你期望一些自动编译器能在你的(也许)intel cpu上击败手工编码的matmul代码(由intel-devs)吗?
还请阅读。 this
jkr
也许numpy也更快,因为 x 的形状是 (3000, 3000) ,而 x2 的形状是 (2,) :)
python
numpy
f. c.
f. c.
发布于 2020-08-31
1 个回答
jkr
jkr
发布于 2020-09-01
已采纳
0 人赞同

Jax和Numpy之间可能存在性能差异,但在原帖中,时间差异主要归结为数组创建中的一个错误。Jax使用的数组的形状是3000x3000,而Numpy使用的数组是一个长度为2的一维数组。 numpy.random.normal loc (即要取样的高斯的平均值)。应使用关键字参数 size= 来表示阵列的形状。

numpy.random.normal(loc=0.0, scale=1.0, size=None)

一旦做了这个改变,Jax和Numpy之间的性能差异就会变小。

import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)
start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))
x2 = np.random.normal(size=(size, size)).astype(np.float64)
start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

The output of one run is

Time of jnp: 2.3315 s
Time of np: 2.8811 s

当测量计时性能时,应该收集多次运行,因为一个函数的性能是一个分散的次数,而不是一个单一的值。这可以通过Python标准库来完成timeit.timeit function or the %timeit在IPython和Jupyter Notebook中的魔法。

import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)