这里有一个多元正态分布概率密度函数的实现,它能像 scipy 的 `multivariate_normal.pdf` 那样处理只有一行或有多行的输入 `x`。
import numpy as np
from scipy.stats import multivariate_normal as mvn
# covariance matrix (4x4, diagonal)
sigma = np.array([[2.3, 0, 0, 0],
                  [0, 1.5, 0, 0],
                  [0, 0, 1.7, 0],
                  [0, 0, 0, 2]])
# mean vector
mu = np.array([2, 3, 8, 10])
# inputs: a single point of shape (4,) and two stacked points of shape (2, 4)
x1 = np.array([2.1, 3.5, 8., 9.5])
x2 = np.array([[2.1, 3.5, 8., 9.5], [2.2, 3.6, 8.1, 9.6]])
def multivariate_normal_pdf(x, mu, cov):
    """Evaluate the multivariate normal density at one or more points.

    Parameters
    ----------
    x : array_like
        A single point of shape ``(d,)``, or a stack of points whose last
        axis has length ``d`` (e.g. ``(n, d)`` or ``(a, b, d)``).
    mu : array_like
        Mean vector of length ``d``.
    cov : array_like
        Covariance matrix of shape ``(d, d)``. It must be invertible;
        ``np.linalg.inv`` raises ``LinAlgError`` for a singular matrix.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Density value(s) with shape ``x.shape[:-1]``. That is a scalar for
        a single point and ``(n,)`` for an ``(n, d)`` input.
    """
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    cov = np.asarray(cov, dtype=float)
    x_m = x - mu
    d = mu.shape[-1]
    # Normalizing constant: 1 / sqrt((2*pi)^d * det(cov)).
    A = 1 / (((2 * np.pi) ** (d / 2)) * (np.linalg.det(cov) ** (1 / 2)))
    # Quadratic form (x - mu)^T cov^-1 (x - mu) per point. Only the last axis
    # is contracted, so any number of leading batch axes is supported.
    B = -0.5 * np.einsum("...i,ij,...j->...", x_m, np.linalg.inv(cov), x_m)
    return A * np.exp(B)
# Compare the hand-written density with SciPy's, for a single row and for
# multiple rows.
print(mvn.pdf(x1, mu, sigma))
print(multivariate_normal_pdf(x1, mu, sigma))
print(mvn.pdf(x2, mu, sigma))
print(multivariate_normal_pdf(x2, mu, sigma))