相关文章推荐
博学的双杠  ·  c++ - ...·  1 年前    · 
潇洒的香瓜  ·  node.js - ...·  1 年前    · 
避坑一个JIT库numba

避坑一个JIT库numba

Python 语言计算性能是一个很大的问题。所以即时编译是个加速 Python 语言计算性能的有利研究方向,也包含了很多的 JIT 工具包。

其中一个,经常看见各路公众号和自媒体都在推一个 JIT 库: numba

今天我是来吐槽它的。它的使用的确存在非常多的问题。

首先,各种自媒体和公众号在介绍 numba 的时候,都忽略了一个问题,那就是 numba 它只能加速数值计算函数,只能在numpy 包基础上进行加速。

如果你的代码形如如下:

from numba import jit
import pandas as pd
x = {'a': [1, 2, 3], 'b': [20, 30, 40]}
def use_pandas(a): # Function will not benefit from Numba jit
    df = pd.DataFrame.from_dict(a) # Numba 并不理解什么是 pd.DataFrame
    df += 1                        # Numba 并不理解这种操作的意义
    return df.cov()                # or this!
print(use_pandas(x))

numba 是无法处理加速的。这基本上就把 Python 的加速计算限制在了一个比较小的范围之内。应用场景十分有限。

我恰好需要对一个基于 CPU 的 numpy 数值计算进行一个加速,因此尝试用 numba 进行操作,函数如下:

import numpy as np
from numba import jit
@jit(nopython=True)
def get_log_sum(sequence_feature_list, node_weight, dtype=np.float16):
    node_num = len(sequence_feature_list)
    tag_num = node_weight.shape[1]
    # 每个节点的得分
    node_score = np.empty((node_num, tag_num), dtype=dtype)
    for i in range(node_num):
        node_score[i] = np.sum(node_weight[sequence_feature_list[i]], axis=0)
    return node_score

然后我在执行时,就遇到了如下报错:

In definition 12:
    TypeError: unsupported array index type reflected list(int64) in [reflected list(int64)]
    raised from /home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
In definition 13:
    TypeError: unsupported array index type reflected list(int64) in [reflected list(int64)]
    raised from /home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at /home/cuichengyu/github/jiojio/jiojio/inference.py (342)
File "jiojio/inference.py", line 342:
def get_log_Y_YY(sequence_feature_list, node_weight, dtype=np.float32):
    <source elided>
        # method 2:
        node_score[i] = np.sum(node_weight[sequence_feature_list[i]], axis=0)
  File "/home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/typeof.py", line 30, in typeof
    ty = typeof_impl(val, c)
  File "/home/cuichengyu/anaconda3/lib/python3.7/functools.py", line 840, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/cuichengyu/anaconda3/lib/python3.7/site-packages/numba/typing/typeof.py", line 94, in typeof_type