static double CumDensity(double z)
double p = 0.3275911;
double a1 = 0.254829592;
double a2 = -0.284496736;
double a3 = 1.421413741;
double a4 = -1.453152027;
double a5 = 1.061405429;
int sign;
if (z < 0.0)
sign = -1;
sign = 1;
double x = Math.Abs(z) / Math.Sqrt(2.0);
double t = 1.0 / (1.0 + p * x);
double erf = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) *
t * Math.Exp(-x * x);
return 0.5 * (1.0 + (sign * erf));
编写概率分类代码的第二个挑战是确定权重的值,以便在显示训练数据时,计算出的输出值可以非常接近已知的输出值。研究这一问题的另一种途径是将已计算输出值和已知输出值之间的误差最小化作为目标。这就是所谓的使用数值优化训练模型。
没有简单的方法来训练大多 ML 分类,包括概率分类。您还可以使用大约十几个主要技术,每种技术都有几十种变体。常见的训练技术包括简单的梯度下降法、反向传播法、牛顿迭代法、粒子群优化算法、进化优化算法和 L-BFGS 算法。该演示程序使用最古老同时也是最简单的一种训练技术 - 单纯形优化法。
了解单纯形优化法
粗略地讲,单纯形是一个三角形。单纯形优化法背后的理念是先从三个可能的解决方案(因而为“单纯形”)开始入手。一种解决方案将是“最好的”(误差最小),一种将是“最差的”(误差最大),第三种被称为“其他”。接着,单纯形优化法创建三个称为“扩大”、“反射”和“收缩”的新潜在解决方案。它们之中的每一个都要与当前最差的解决方案进行比较,如果有新的更好(误差较小)的出现,那么当前最差的解决方案将被取代。
图 4 中说明了单纯形优化法。在解决方案包括两个值的简单情况中,如(1.23,4.56),您可以将解决方案认为是在 (x, y) 平面上的一个点。图 4 的左侧显示了如何从当前的“最好的”、“最差的”和“其他”解决方案中生成三个新的候选解决方案。
图 4 单纯形优化法
首先,计算质心。质心是“最好的”和“其他”解决方案取平均后的平均值。在两个维度中,这是介于“最好”点和“其他”点之间的中间位置。接着,创建假想线,它始于“最差”点,并通过质心延伸出去。“收缩”候选点介于“最差”点和质心点之间。“反射”候选点在假想线上,穿过质心。“扩大”候选点穿过“反射”点。
在单纯形优化法的每次迭代中,如果“扩大”、“反射”或“收缩”候选点比当前“最差的”解决方案还好,那么将使用该候选点来替换“最差”点。如果生成的三个候选点都没有“最差的”解决方案好,那么当前“最差的”和“其他”解决方案将向“最好的”解决方案方向移动,来到介于当前位置和“最好的”解决方案之间的某点,如图 4 的右手边所示。
每次迭代结束后,将形成新的虚拟的“最好的 - 其他 - 最差的”三角形,从而越来越接近最佳的解决方案。如果可以拍摄每个三角形的快照,那么在按顺序查看时,移动的三角形就像是一个尖尖的 blob 在以类似于单细胞变形虫的方式在平面上前进着。出于这个原因,单纯形优化法有时也被称为变形虫方法优化。
单纯形优化法有许多变体,但它们的“收缩”、“反射”和“扩大”候选解决方案距当前质心处的距离不同,并且会检查这些候选解决方案的顺序,以查看它们是否比当前“最差的”解决方案好。单纯形优化法的最常见形式即所谓的 Nelder-Mead 算法。该演示程序使用一个不具备特定名称的更简单变体。
对于概率分类,每一个潜在解决方案都是一组权重值。图 5 以伪代码的形式显示演示程序中所用的单纯形优化法变体。
图 5 演示程序中所用的单纯形优化法的伪代码
randomly initialize best, worst other solutions
loop maxEpochs times
create centroid from worst and other
create expanded
if expanded is better than worst, replace worst with expanded,
continue loop
create reflected
if reflected is better than worst, replace worst with reflected,
continue loop
create contracted
if contracted is better than worst, replace worst with contracted,
continue loop
create a random solution
if random solution is better than worst, replace worst,
continue loop
shrink worst and other toward best
end loop
return best solution found
就像所有其他 ML 优化算法一样,单纯形优化法既有优点又有缺点。然而,它可以简单地加以实现(相对而言),通常在实践中效果很好(尽管并非总是如此)。
为了创建演示程序,我启动了 Visual Studio 并选择了 C# 控制台应用程序模板,并把它命名为 ProbitClassification。该演示对 Microsoft .NET Framework 版本的依赖程度并不明显,因此,任何相对较新的 Visual Studio 版本都应该可以正常工作。加载模板代码后,在解决方案资源管理器窗口,我将文件 Program.cs 重命名为 ProbitProgram.cs,而 Visual Studio 会自动重命名类程序。
图 6 演示代码的开头部分
using System;
namespace ProbitClassification
class ProbitProgram
static void Main(string[] args)
Console.WriteLine("\nBegin Probit Binary Classification demo");
Console.WriteLine("Goal is to predict death (0 = false, 1 = true)");
double[][] data = new double[30][];
data[0] = new double[] { 48, +1, 4.40, 0 };
data[1] = new double[] { 60, -1, 7.89, 1 };
// Etc.
data[29] = new double[] { 68, -1, 8.38, 1 };
图 6 中显示了演示代码的开头部分。虚拟数据被硬编码到程序中。在非演示方案中,您的数据将被存储在一个文本文件中,您必须编写一个实用程序方法将数据加载到内存中。接下来,使用程序定义的 helper 方法“ShowData”显示源数据:
Console.WriteLine("\nRaw data: \n");
Console.WriteLine(" Age Sex Kidney Died");
Console.WriteLine("=======================================");
ShowData(data, 5, 2, true);
接着,将对 0 列和 2 列中的源数据进行规范化:
Console.WriteLine("Normalizing age and kidney data");
int[] columns = new int[] { 0, 2 };
double[][] means = Normalize(data, columns); // Normalize, save means & stdDevs
Console.WriteLine("Done");
Console.WriteLine("\nNormalized data: \n");
ShowData(data, 5, 2, true);
经过规范化的方法保存并返回所有列的平均值和标准偏差,使得在遇到新的数据时,可以使用用于训练模型的相同参数对该数据进行规范化。接着,规范化后的数据被分为训练集 (80%) 和测试集 (20%):
Console.WriteLine("Creating train (80%) and test (20%) matrices");
double[][] trainData;
double[][] testData;
MakeTrainTest(data, 0, out trainData, out testData);
Console.WriteLine("Done");
Console.WriteLine("\nNormalized training data: \n");
ShowData(trainData, 3, 2, true);
您可能想将方法“MakeTrainTest”参数化,以接受要放在训练集中的项目百分比。接下来,将程序定义的概率分类器对象实例化:
int numFeatures = 3; // Age, sex, kidney
Console.WriteLine("Creating probit binary classifier");
ProbitClassifier pc = new ProbitClassifier(numFeatures);
然后,通过使用单纯形优化法查找权重值来训练概率分类器,以便计算出的输出值能够非常接近已知输出值:
int maxEpochs = 100; // 100 gives a representative demo
Console.WriteLine("Setting maxEpochs = " + maxEpochs);
Console.WriteLine("Starting training");
double[] bestWeights = pc.Train(trainData, maxEpochs, 0);
Console.WriteLine("Training complete");
Console.WriteLine("\nBest weights found:");
ShowVector(bestWeights, 4, true);
最后,演示程序根据训练数据和测试数据计算模型的分类准确性:
double testAccuracy = pc.Accuracy(testData, bestWeights);
Console.WriteLine("Prediction accuracy on test data =
" + testAccuracy.ToString("F4"));
Console.WriteLine("\nEnd probit binary classification demo\n");
Console.ReadLine();
} // Main
演示并未对前所未见的数据做出预测。进行的预测类似如下:
// Slightly older, male, higher kidney score
double[] unknownNormalized = new double[] { 0.25, -1.0, 0.50 };
int died = pc.ComputeDependent(unknownNormalized, bestWeights);
if (died == 0)
Console.WriteLine("Predict survive");
else if (died == 1)
Console.WriteLine("Predict die");
此代码假定独立 x 变量(年龄、性别和肾脏检测分数)已经使用从训练数据规范化过程中得到的平均值和标准偏差进行了规范化。
ProbitClassifier 类
图 7 中显示的是 ProbitClassifier 类的整体结构。该 ProbitClassifier 定义中包含一个名为“Solution”的嵌套类。该子类从 IComparable 接口中派生出来,使包含三个“Solution”对象的数组可以自动进行排序,以提供“最好的”、“其他”和“最差的”解决方案。通常我并不喜欢花哨的编码技术,但在该情况中,好处略胜于增加的复杂性。
图 7 ProbitClassifier 类
public class ProbitClassifier
private int numFeatures; // Number of independent variables
private double[] weights; // b0 = constant
private Random rnd;
public ProbitClassifier(int numFeatures) { . . }
public double[] Train(double[][] trainData, int maxEpochs, int seed) { . . }
private double[] Expanded(double[] centroid, double[] worst) { . . }
private double[] Contracted(double[] centroid, double[] worst) { . . }
private double[] RandomSolution() { . . }
private double Error(double[][] trainData, double[] weights) { . . }
public void SetWeights(double[] weights) { . . }
public double[] GetWeights() { . . }
public double ComputeOutput(double[] dataItem, double[] weights) { . . }
private static double CumDensity(double z) { . . }
public int ComputeDependent(double[] dataItem, double[] weights) { . . }
public double Accuracy(double[][] trainData, double[] weights) { . . }
private class Solution : IComparable<Solution>
// Defined here
ProbitClassifier 有两种输出方法。方法“ComputeOutput”返回一个介于 0.0 和 1.0 之间的值,并在训练期间用于计算错误值。方法“ComputeDependent”是一个围绕 ComputeOutput 的包装,如果输出小于或等于 0.5 则返回 0,如果输出大于 0.5 则返回 1。这些返回值被用来计算准确性。
概率分类是最古老的 ML 技术之一。由于概率分类与逻辑回归分类颇为相似,普遍的做法是使用二者中任一种技术。由于 LR 比概率更容易实现一些,所以概率分类不常使用,随着时间的推移已经有点沦落为二等 ML 公民。不过,概率分类往往是非常有效的,可以为您的 ML 工具带来宝贵的价值。
James McCaffrey 博士 任职于华盛顿州雷德蒙德市的 Microsoft 研究中心。他长期从事多个 Microsoft 产品(包括 Internet Explorer 和 Bing)的研发工作。可以在 jammc@microsoft.com 上联系 McCaffrey 博士。
衷心感谢以下 Microsoft Research 技术专家对本文的审阅:Nathan Brown 和 Kirk Olynyk。