因果森林总结:基于树模型的异质因果效应估计
本文中各类forest-based methods主要从split和predict两个角度展开,忽略渐进高斯性等理论推导。
一、Random Forest
传统随机森林由多棵决策树构成,每棵决策树在第i次split的时候,分裂准则如下(这里关注回归树):
\arg\min_{\prod_i} \;MSE_i=\frac{1}{n}\sum_{i=1}^n(Y_i-\bar{Y}_{j:x_j\in l(x_i|\prod_i)}) \tag{1}
其中 l(x_i|\prod_i) 表示在 \prod_i 的划分情况下, x_i 所在的叶子结点。
随机森林构建完成后,给定测试数据 x ,预测值为:
\begin{align} \hat{\mu}(x)=\frac{1}{B}\sum_{b=1}^{B}\hat{\mu}_b(x) \tag{2}\\ \hat{\mu}_b(x)=\bar{Y}_{j:x_j\in l(x_i|\prod_b)} \tag{3} \end{align}
二、Causal Forest
类似地,因果森林由多棵因果树构成,由于需要Honest estimation(用互不重合的数据 S^{tr},S^{est} 分别进行split和estimate),因此相较于决策树,每棵因果树split的分裂准则修改如下:
\begin{align} \arg\max_{\prod} -\hat{EMSE}_{\tau}(S^{tr},S^{est},\prod)=\frac{1}{N^{tr}}\sum_{i\in S^{tr}}\hat{\tau}^2(X_i;S^{tr},\prod)-(\frac{1}{N^{tr}}+\frac{1}{N^{est}})\sum_{l\in \prod}(\frac{S^2_{S^{tr}_{treat}}(l)}{p}+\frac{S^2_{S^{tr}_{control}}(l)}{1-p}) \tag{4} \end{align}
其中
\begin{align} \hat{\tau}(x;S^{tr},\prod)&=\hat{\mu}(1,x;S^{tr},\prod)-\hat{\mu}(0,x;S^{tr},\prod) \tag{5}\\ \hat{\mu}(w,x;S^{tr},\prod)&=\bar{Y}_{i:i\in S^{tr}_w,X_i\in l(x|\prod)} \tag{6} \end{align}
在叶子结点内可以认为所有样本同质,所以因果森林构建完成后,给定测试数据 x ,其预测值为:
\begin{align} \hat{\tau}(x)&=\frac{1}{B}\sum_{b=1}^{B}\hat{\tau}_b(x) \tag{7}\\ \hat{\tau}_b(x)&=\bar{Y}_{i:i\in S^{est}_1,X_i\in l(x|\prod_b)}-\bar{Y}_{i:i\in S^{est}_0,X_i\in l(x|\prod_b)} \tag{8} \end{align}
三、Generalized Random Forest
广义随机森林可以看作是对随机森林进行了推广:原来随机森林只能估计观测目标值 Y ,现在广义随机森林可以估计任何感兴趣的指标 \theta(\cdot) 。
3.1 predict
先假设我们在已经有一棵训练好的广义随机森林,现在关注给定测试数据 x ,如何预测我们感兴趣的指标?
通过公式(2)和(3),传统随机森林预测的做法是:
- 在单棵树中,将测试数据 x 所在叶子结点的观测目标值取平均作为该树对 x 的预测;
- 在多棵树中,将单棵树的不同预测结果取平均作为最终的预测结果。
而在广义随机森林中,首先基于因果森林得到各数据 x_i 相对于测试数据 x 的权重 \alpha_i ,之后加权求解局部估计等式,具体地:
-
权重估计阶段:将数据
x_i
与测试数据
x
在同一叶子结点中的""共现频率""作为其权重,如下:
\begin{align} \alpha_{bi}(x)&=\frac{1\{x_i\in l_b(x)\}}{|l_b(x)|} \tag{9}\\ \alpha_i(x)&=\frac{1}{B}\sum_{b=1}^B\alpha_{bi}(x) \tag{10} \end{align} -
加权求解局部估计等式阶段:下式中
\theta(x)
表示我们感兴趣的参数,
v(x)
表示我们不感兴趣但必须估计的参数,
O
表示观测到的与我们感兴趣的参数相关的值。
\begin{align} \hat{\theta}(x),\hat{v}(x)\in \arg\min_{\theta,v} ||\sum_{i=1}^n\alpha_i(x)\psi_{\theta,v}(O_i)||_2 \tag{11} \end{align}
在predict阶段,我们可以证明,随机森林恰好是广义随机森林的一个特例,证明如下:
-
首先,在随机森林的setting下,
O_i=Y_i
,我们感兴趣的参数恰好是
\theta(x)=\mu(x)=E[Y_i|X_i=x]
;
-
极大似然函数为
\mu(x)=\arg\min_{\mu} E[(Y_i-\mu)^2|X_i=x]
,其score function为
\psi_{\mu(x)}(Y_i)=Y_i-\mu(x)
;
-
因此公式(11)为:
\hat{\mu}(x)=\arg\min_{\mu}||\sum_{i=1}^n\alpha_i(x)\psi_{\mu}(Y_i)||_2=\arg\min_{\mu}(\sum_{i=1}^n\alpha_i(x)(Y_i-\mu))^2
;
-
因此有:
\sum_{i=1}^n\alpha_i(x)(Y_i-\hat{\mu}(x))=0
,可得:
\begin{align} \hat{\mu}(x)&=\sum_{i=1}^n\alpha_i(x)Y_i \\ &=\sum_{i=1}^n\sum_{b=1}^B\frac{1}{B}\frac{1\{x_i\in l_b(x)\}}{|l_b(x)|}Y_i\\ &=\frac{1}{B}\sum_{b=1}^{B}\sum_{i=1}^n\frac{Y_i1\{x_i\in l_b(x)\}}{|l_b(x)|}\\ &=\frac{1}{B}\sum_{b=1}^{B}\hat{\mu}_b(x) \end{align} \\
3.2 split
首先,由于广义随机森林的目标是准确估计感兴趣的参数 \theta(x) ,因此针对单一节点 P 与一组样数据 J ,估计参数 \theta,v 的方法是:
(\hat{\theta}_P,\hat{v}_P)(J)\in \arg\min_{\theta,v} ||\sum_{\{i\in J:X_i\in P\}}\psi_{\theta,v}(O_i)||_2 \tag{12}
接着,我们要将节点P分裂为两个子节点 C_1,C_2 ,分裂的目标是极小化感兴趣的参数的误差:
err(C_1,C_2)=\sum_{i=1}^2P(X\in C_j|X\in P)E[(\hat{\theta}_{C_j}(J)-\theta(X))^2|X\in C_j] \\
但是实际上 \theta(X) 是不可见的,经过一番推导,最终可以发现最小化 err(C_1,C_2) 等价于最大化下面的公式:
\triangle(C_1,C_2)=\frac{n_{C_1}n_{C_2}}{n^2_{C_P}}(\hat{\theta}_{C_1}(J)-\hat{\theta}_{C_2}(J)) \tag{13}
也就是说,最小化感兴趣的参数的误差等价于最大化两个子节点的异质性。
如果每个 \hat{\theta}_C 都通过求解式(12)获得,那算法的计算复杂度非常高,因此可以通过gradient-based的方法去得到 \hat{\theta}_C 的近似解:
\begin{align} \tilde{\theta}_C&=\hat{\theta}_P-\frac{1}{|\{i:X_i\in C\}|}\sum_{\{i:X_i\in C\}}\xi^TA_P^{-1}\psi_{\hat{\theta}_P,\hat{v}_P}(O_i) \\ A_P&=\frac{1}{|\{i:X_i\in C\}|}\sum_{\{i:X_i\in C\}}\nabla \psi_{\hat{\theta}_P,\hat{v}_P}(O_i) \end{align} \\
至此,我们可以将split分成两个阶段:
-
标记阶段:计算父节点的
\hat{\theta}_P,\hat{v}_P,A_P^{-1}
,之后针对每个样本计算虚拟的目标值:
\rho_i=-\xi^TA_P^{-1}\psi_{\hat{\theta}_P,\hat{v}_P}(O_i) \tag{14} -
回归阶段:分裂准则为最大化式(14):
\tilde{\triangle}(C_1,C_2)=\sum_{j=1}^2\frac{1}{{|\{i:X_i\in C_j\}|}}(\sum_{\{i:X_i\in C_j\}}\rho_i)^2 \tag{15}
在split阶段,也可以证明随机森林是广义随机森林的一个特例:
- 首先,在随机森林的setting下,score function为 \psi_{\mu(x)}(Y_i)=Y_i-\mu(x) ;
- 此时 A_P=\frac{1}{n_P}\sum_{\{i:X_i\in P\}}(-1)=-1 , \rho_i=-1^T(-1)^{-1}[Y_i-\hat{\mu}_P(x)]=Y_i-\bar{Y}_P 。
3.3 局部估计等式
在广义随机森林中,假设下列的数据产生过程:
\begin{align} Y&=\theta(x)\cdot W+v(x)+\epsilon,\;\;E[\epsilon|X]=0\\ W&=f(X)+\eta,\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;\;E[\eta|X]=0,\;E[\epsilon,\eta|X]=0 \end{align} \\
这里 O_i=(Y_i,W_i) ,有:
\begin{align} \psi_{\theta(x),v(x)}(O_i)&=\psi_{\theta(x),v(x)}(Y_i,W_i)\\ &=Y_i-\theta(x)\cdot W_i-v(x) \tag{16} \end{align}
此时 \arg\min_{\theta}||E[\psi_{\theta(x),v(x)}(O_i)]||_2 相当于:
\theta(x)=\xi^T\frac{Cov(W_i,Y_i|X_i=x)}{Var(W_i|X_i=x)} \tag{17}
带上权重 \alpha_i(x) 的时候类似。
3.4 other
causal forest和generalized random forest的分裂准则其实是等价的,只不过式(4)考虑了下式的b和c两部分,式(13)/(15)只考虑了b部分:
err(C_1,C_2)=\underbrace{K(P)}_{a}-\underbrace{E[\triangle(C_1,C_2)]}_{b}+\underbrace{o(r^2)}_c \\
四、Orthogonal Random Forest
orthogonal random forest只是在generalized random forest的基础上进行了两个改动:
-
加了DML:在一开始先拟合
E[Y|X],E[W|X]
,得到残差(first stage);再对残差跑generalized random forest(second stage)。与广义随机森林的score function(16)相比,正交随机森林的score function的定义如下,
\begin{align} \psi_{\theta(x),v(x)}(O_i)&=\psi_{\theta(x),v(x)}(Y_i,W_i)\\ &=Y_i-E[Y_i|X]-\theta(x)\cdot (W_i-E[W_i|X])-v(x) \tag{19} \end{align}
此时
\arg\min_{\theta}||E[\psi_{\theta(x),v(x)}(O_i)]||_2
相当于:
\theta(x)=\xi^T\frac{Cov(W_i-E[W_i|X],Y_i-E[Y_i|X]|X_i=x)}{Var(W_i-E[W_i|X]|X_i=x)} \tag{20}
带上权重
\alpha_i(x)
的时候类似。
-
在predict阶段强调locally,即拟合
E[Y|X],E[W|X]
的时候(DML的first stage)使用上权重
\alpha_i(x)
。
五、TODO
记录一个还没想明白的问题,路过的大佬有懂的欢迎讨论~
到这里我们可以发现一个节点内的数据的HTE有两种计算方式:
- 一种是如式(8)所示,直接计算不同treatment组的期望相减,即 E[Y|W=1] - E[Y|W=0] ;
- 另外一种是求解式(12)的局部估计等式。
在随机森林假设的线性treatment effect的情况下,这两种计算本质上是等价的。那为什么式(13)中的 \hat{\theta}_{C_2},\hat{\theta}_{C_2} 不能直接用第一种方式求,而是要大费周章地用梯度去近似呢?
目前的结论:上述等价性成立的前提是线性effect和二元treatments假设,第二种计算方式可以推广到多元甚至连续treatments。
参考资料
[1] Athey S, Imbens G. Recursive partitioning for heterogeneous causal effects[J]. Proceedings of the National Academy of Sciences, 2016, 113(27): 7353-7360.
[2] Wager S, Athey S. Estimation and inference of heterogeneous treatment effects using random forests[J]. Journal of the American Statistical Association, 2018, 113(523): 1228-1242.
[3] Athey S, Tibshirani J, Wager S. Generalized random forests[J]. The Annals of Statistics, 2019, 47(2): 1148-1178.
[4] Oprescu M, Syrgkanis V, Wu Z S. Orthogonal random forest for causal inference[C]//International Conference on Machine Learning. PMLR, 2019: 4932-4941.