Seaborn Jointplot为每个类别添加颜色

7 人关注

我想用seaborn绘制2个变量的相关图 jointplot 。我已经尝试了很多不同的方法,但我无法根据类别为各点添加颜色。

Here is my code:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()
X = np.array([5.2945 , 3.6013 , 3.9675 , 5.1602 , 4.1903 , 4.4995 , 4.5234 ,
              4.6618 , 0.76131, 0.42036, 0.71092, 0.60899, 0.66451, 0.55388,
              0.63863, 0.62504, 0.     , 0.     , 0.49364, 0.44828, 0.43066,
              0.57368, 0.     , 0.     , 0.64824, 0.65166, 0.64968, 0.     ,
              0.     , 0.52522, 0.58259, 1.1309 , 0.     , 0.     , 1.0514 ,
              0.7519 , 0.78745, 0.94873, 1.0169 , 0.     , 0.     , 1.0416 ,
              0.     , 0.     , 0.93648, 0.92801, 0.     , 0.     , 0.89594,
              0.     , 0.80455, 1.0103 ])
y = np.array([ 93, 115, 107, 115, 110, 107, 102, 113,  95, 101, 116,  74, 102,
               102,  78,  85, 108, 110, 109,  80,  91,  88,  99, 110, 108,  96,
               105,  93, 107,  98,  88,  75, 106,  92,  82,  84,  84,  92, 115,
               107,  97, 115,  85, 133, 100,  65,  96, 105, 112, 107, 107, 105])
ax = sns.jointplot(X, y, kind='reg' )
ax.set_axis_labels(xlabel='Brain scores', ylabel='Cognitive scores')
plt.tight_layout()
plt.show()

现在,我想根据一个类变量classes为每个点添加颜色。

2 个评论
请不要在问题中回答自己的问题。如果你认为现有的答案没有回答问题,并认为你有更好的/不同的解决方案,请提供该解决方案作为你问题的答案。
我同意。我只是按照你的建议做了。
python
matplotlib
seaborn
scatter-plot
seralouk
seralouk
发布于 2018-07-06
4 个回答
ImportanceOfBeingErnest
ImportanceOfBeingErnest
发布于 2021-04-27
已采纳
0 人赞同

明显的解决方案是让 regplot 只画回归线,而不画点,并通过通常的散点图添加这些点,散点图有颜色 c 参数。

g = sns.jointplot(X, y, kind='reg', scatter = False )
g.ax_joint.scatter(X,y, c=classes)
    
根据你的想法,我成功地解决了我的问题。现在,在最后一步,有什么方法可以为颜色编码添加一个图例吗?
这是一个经常被问到的问题,因为目前还没有直接和简单的解决方案来增加散点的图例。可能的方法是,例如 this one , this one , or this one .Matplotlib的未来版本 可能包含一个更好的解决方案 I designed.
谢谢你提供的参考资料
seralouk
seralouk
发布于 2021-04-27
0 人赞同

我成功地找到了一个解决方案,正是我所需要的。感谢@ImportanceOfBeingErnest给了我一个想法,让 regplot 只画回归线。

Solution:

import pandas as pd
classes = np.array([1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
                    2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 
                    2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 
                    3., 3., 3., 3., 3., 3., 3.])
df = pd.DataFrame(map(list, zip(*[X.T, y.ravel().T])))
df = df.reset_index()
df['index'] = classes[:]
g = sns.jointplot(X, y, kind='reg', scatter = False )
for i, subdata in df.groupby("index"):
    sns.kdeplot(subdata.iloc[:,1], ax=g.ax_marg_x, legend=False)
    sns.kdeplot(subdata.iloc[:,2], ax=g.ax_marg_y, vertical=True, legend=False)
    g.ax_joint.plot(subdata.iloc[:,1], subdata.iloc[:,2], "o", ms = 8)
plt.tight_layout()
plt.show()
    
Thomas Matthew
Thomas Matthew
发布于 2021-04-27
0 人赞同

以欧内斯特的回答为基础。

sns.jointplot 中设置 scatter = False 后,使用 sns.scatterplot 构建散点图,参数 hue = classes 等于分类变量阵列。 我发现将你的数据合并到一个带有 x y classes 列的pandas数据框架中,并将其作为散点图的 data 是最干净的,但你不一定要这样做...

classes = np.array([1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
                    2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 
                    2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 
                    3., 3., 3., 3., 3., 3., 3.])
# make them look a little more 'categorical'
classes = classes.astype('int')
x = np.array([5.2945 , 3.6013 , 3.9675 , 5.1602 , 4.1903 , 4.4995 , 4.5234 ,
              4.6618 , 0.76131, 0.42036, 0.71092, 0.60899, 0.66451, 0.55388,
              0.63863, 0.62504, 0.     , 0.     , 0.49364, 0.44828, 0.43066,
              0.57368, 0.     , 0.     , 0.64824, 0.65166, 0.64968, 0.     ,
              0.     , 0.52522, 0.58259, 1.1309 , 0.     , 0.     , 1.0514 ,
              0.7519 , 0.78745, 0.94873, 1.0169 , 0.     , 0.     , 1.0416 ,
              0.     , 0.     , 0.93648, 0.92801, 0.     , 0.     , 0.89594,
              0.     , 0.80455, 1.0103 ])
y = np.array([ 93, 115, 107, 115, 110, 107, 102, 113,  95, 101, 116,  74, 102,
               102,  78,  85, 108, 110, 109,  80,  91,  88,  99, 110, 108,  96,
               105,  93, 107,  98,  88,  75, 106,  92,  82,  84,  84,  92, 115,
               107,  97, 115,  85, 133, 100,  65,  96, 105, 112, 107, 107, 105])
sns.jointplot(x, y, kind='reg', scatter = False )
sns.scatterplot(x, y, hue=classes)
    
slavny_coder
slavny_coder
发布于 2021-04-27
0 人赞同
       label  Method 2  Method 1
0    Label 2  1.484914 -1.069439
1    Label 2  0.273158  1.139414
2    Label 2  1.089244  0.161752
3    Label 2  1.184306 -0.981758
4    Label 2  1.424435  0.300742
..       ...       ...       ...
111  Label 2 -0.201226  0.852319
112  Label 2  0.016911  0.985805
113  Label 2 -0.263775  0.248942
114  Label 2  3.283341 -1.247014
115  Label 2  0.325648  1.793694
[116 rows x 3 columns]
sns.jointplot(data=data, x="Method 1, y="Method 2", "hue="label", palette={
    'Label 1': '#d7191c',
    'Label 2': '#2b83ba'