登录

对大X值不正确的枕curve_fit

内容来源于 Stack Overflow,遵循 CC BY-SA 4.0 许可协议进行翻译与使用。IT领域专用引擎提供翻译支持

腾讯云小微IT领域专用引擎提供翻译支持

原文
Stack Overflow用户 修改于2022-09-23
  • 该问题已被编辑
  • 提问者: Stack Overflow用户
  • 提问时间: 2022-09-22 12:08

为了确定一段时间的趋势,我使用 scipy curve_fit 和来自 time.time() 的X值,例如 1663847528.7147126 (16亿)。做线性插值有时会产生错误的结果,并且提供近似的初始 p0 值也没有帮助。我发现X的大小是造成这个错误的关键因素,我想知道为什么?

下面是一个简单的片段,它显示了工作和不工作的X偏移量:

import scipy.optimize
def fit_func(x, a, b):
    return a + b * x
y = list(range(5))
x = [1e8 + a for a in range(5)]
print(scipy.optimize.curve_fit(fit_func, x, y, p0=[-x[0], 0]))
# Result is correct:
#   (array([-1.e+08,  1.e+00]), array([[ 0., -0.],
#          [-0.,  0.]]))
x = [1e9 + a for a in range(5)]
print(scipy.optimize.curve_fit(fit_func, x, y, p0=[-x[0], 0.0]))
# Result is not correct:
#   OptimizeWarning: Covariance of the parameters could not be estimated
#   warnings.warn('Covariance of the parameters could not be estimated',
#   (array([-4.53788811e+08,  4.53788812e-01]), array([[inf, inf],
#          [inf, inf]]))
Almost perfect p0 for b removes the warning but still curve_fit doesn't work
print(scipy.optimize.curve_fit(fit_func, x, y, p0=[-x[0], 0.99]))
# Result is not correct:
#   (array([-7.60846335e+10,  7.60846334e+01]), array([[-1.97051972e+19,  1.97051970e+10],
#          [ 1.97051970e+10, -1.97051968e+01]]))
# ...but perfect p0 works
print(scipy.optimize.curve_fit(fit_func, x, y, p0=[-x[0], 1.0]))
#(array([-1.e+09,  1.e+00]), array([[inf, inf],
#       [inf, inf]]))

作为一个附带的问题,也许有一个更有效的方法来进行线性拟合?不过,有时我想要找到二阶多项式拟合。

在Windows10下用Python3.9.6和SciPy 1.7.1进行了测试。

浏览 37 关注 0 得票数 1
  • 得票数为Stack Overflow原文数据
原文
修改于2022-09-22
  • 该回答已被编辑
  • 回答者: Stack Overflow用户
  • 回答时间: 2022-09-22 13:15
得票数 1

如果只需要计算线性拟合,我认为 curve_fit 是不必要的,我也可以使用 linregress 函数代替SciPy:

>>> from scipy import stats
>>> y = list(range(5))
>>> x = [1e8 + a for a in range(5)]
>>> stats.linregress(x, y)
LinregressResult(slope=1.0, intercept=-100000000.0, rvalue=1.0, pvalue=1.2004217548761408e-30, stderr=0.0, intercept_stderr=0.0)
>>> x2 = [1e9 + a for a in range(5)]
>>> stats.linregress(x2, y)
LinregressResult(slope=1.0, intercept=-1000000000.0, rvalue=1.0, pvalue=1.2004217548761408e-30, stderr=0.0, intercept_stderr=0.0)

通常,如果需要多项式拟合,我将使用NumPy 多相配合

修改于2022-09-23
  • 该回答已被编辑
  • 回答者: Stack Overflow用户
  • 回答时间: 2022-09-22 15:26
得票数 1

根本原因

你面临着两个问题:

  • 拟合过程对尺度敏感。它指的是在一个特定变量上选择的单位(例如。A(而不是kA)可以人为地阻止算法正确收敛(例如。一个变量比另一个变量大几个数量级,是回归的主导变量);
  • 浮点算术误差当从 1e8 切换到 1e9 时,当这种错误变得突出时,就会达到这个程度。

第二个目标是非常重要的。假设您被限制为8个有效数字表示,那么 1 000 000 000 1 000 000 001 是相同的数字,因为它们都被限制在本文的 1.0000000e9 中,而且我们不能准确地表示需要多一个数字( _ )的 1.0000000_e9 。这就是为什么第二个例子失败的原因。

此外,您正在使用非线性最小二乘算法来解决线性最小二乘问题,这也与您的问题有某种关系。

你有三个解决方案:

  • 正常化;
  • 规范和改变方法/算法;
  • 提高机器精度。

我将选择第一个,因为它更通用,第二个是由 @blunova 提出的,完全有意义,后者可能是一个固有的限制。

归一化

为了缓解这两个问题,一个共同的解决方案是正常化。在您的例子中,一个简单的标准化就足够了:

import numpy as np
import scipy.optimize
y = np.arange(5)
x = 1e9 + y
def fit_func(x, a, b):
    return a + b * x
xm = np.mean(x)         # 1000000002.0
xs = np.std(x)          # 1.4142135623730951
result = scipy.optimize.curve_fit(fit_func, (x - xm)/xs, y)
# (array([2.        , 1.41421356]),
# array([[0., 0.],
#        [0., 0.]]))
# Back transformation:
a = result[0][1]/xs                    # 1.0
b = result[0][0] - xm*result[0][1]/xs  # -1000000000.0

或者使用 sklearn 接口获得相同的结果:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.linear_model import LinearRegression
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("regressor", LinearRegression())
pipe.fit(x.reshape(-1, 1), y)
pipe.named_steps["scaler"].mean_          # array([1.e+09])
pipe.named_steps["scaler"].scale_         # array([1.41421356])
pipe.named_steps["regressor"].coef_       # array([1.41421356])
pipe.named_steps["regressor"].intercept_  # 2.0

反变换

实际上,当对拟合结果进行规范化处理时,可以用归一化变量来表示。要获得所需的拟合参数,只需做一些数学运算,就可以将回归的参数转换为原始的变量尺度。

只需写下并解决转换:

 y = x'*a' + b'
x' = (x - m)/s
 y = x*a + b

它给出了以下解决方案:

a = a'/s
b = b' - m/s*a'

精密增编

Numpy默认浮动精度与您预期的一样是 float64 ,大约有15个有效位数:

x.dtype                            # dtype('float64')