我想用JAX在CPU上加速我的numpy代码,然后在GPU上加速。这是我在本地计算机上运行的示例代码(只有CPU)。
import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time
size = 3000
key = random.PRNGKey(0)
x = random.normal(key, (size,size), dtype=jnp.float64)
start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))
x2=np.random.normal((size,size))
start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))
我得到了一个警告,时间成本如下。
/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130:
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s
我在这里做错了什么吗?JAX是否也应该在CPU上加速numpy代码?