Steer clear of this JIT library: numba
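Computational performance is a well-known weakness of Python. Just-in-time (JIT) compilation is therefore a promising way to speed up Python code, and many JIT toolkits exist.
One of them is constantly promoted by all kinds of WeChat official accounts and tech influencers: numba.
Today I'm here to complain about it. In practice, using it runs into a lot of problems.
First, the accounts and influencers who introduce numba all gloss over one thing: numba can only accelerate numerical code, meaning plain Python loops over numbers and a supported subset of NumPy arrays and functions.
If your code looks like this: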
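```python
from numba import jit
import pandas as pd

x = {'a': [1, 2, 3], 'b': [20, 30, 40]}

@jit  # depending on the numba version: slow object-mode fallback or a TypingError
def use_pandas(a):  # Function will not benefit from Numba jit
    df = pd.DataFrame.from_dict(a)  # Numba does not understand pd.DataFrame
    df += 1                         # Numba does not understand what this operation means
    return df.cov()                 # or this!

print(use_pandas(x))
```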
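then numba cannot accelerate it. That confines numba-based acceleration to a fairly small subset of Python, so its range of applications is quite limited.
As it happened, I needed to speed up a CPU-bound numpy computation, so I tried numba on it. The function is: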
```python
import numpy as np
from numba import jit

@jit(nopython=True)
def get_log_sum(sequence_feature_list, node_weight, dtype=np.float16):
    node_num = len(sequence_feature_list)
    tag_num = node_weight.shape[1]
    # score of each node
    node_score = np.empty((node_num, tag_num), dtype=dtype)
    for i in range(node_num):
        # sequence_feature_list[i] is a Python list of feature ids
        node_score[i] = np.sum(node_weight[sequence_feature_list[i]], axis=0)
    return node_score
```
Running it, I got the following error:
```
In definition 12:
TypeError: unsupported array index type reflected list(int64) in [reflected list(int64)]
raised from /home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
In definition 13:
TypeError: unsupported array index type reflected list(int64) in [reflected list(int64)]
raised from /home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at /home/cuichengyu/github/jiojio/jiojio/inference.py (342)
File "jiojio/inference.py", line 342:
def get_log_Y_YY(sequence_feature_list, node_weight, dtype=np.float32):
<source elided>
# method 2:
node_score[i] = np.sum(node_weight[sequence_feature_list[i]], axis=0)
File "/home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/typeof.py", line 30, in typeof
ty = typeof_impl(val, c)
File "/home/cuichengyu/anaconda3/lib/python3.7/functools.py", line 840, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/typeof.py", line 94, in typeof_type
```