Seaborn其实是在matplotlib的基础上进行了更高级的API封装,从而使得作图更加容易,在大多数情况下使用seaborn就能做出很具有吸引力的图,而使用matplotlib能制作具有更多特色的图。应该把Seaborn视为matplotlib的补充,而不是替代物。
一般来说,seaborn能满足数据分析90%的绘图需求,够用了,如果需要复杂的自定义图形,还是要matplotlit。这里也只是对seaborn官网的绘图API简单翻译整理下,对很多参数使用方法都没有说到,如果需要精细绘图,还是需要参照其seaborn的文档的。
在seaborn中图形大概分这么几类,因子变量绘图,数值变量绘图,两变量关系绘图,时间序列图,热力图,分面绘图等。
因子变量绘图
-
箱线图boxplot
-
小提琴图violinplot
-
散点图striplot
-
带分布的散点图swarmplot
-
直方图barplot
-
计数的直方图countplot
-
两变量关系图factorplot
回归图
回归图只要探讨两连续数值变量的变化趋势情况,绘制x-y的散点图和回归曲线。
-
线性回归图lmplot
-
线性回归图regplot
分布图
包括单变量核密度曲线,直方图,双变量多变量的联合直方图,和密度图
热力图
1. 热力图heatmap
聚类图
1. 聚类图clustermap
时间序列图
1. 时间序列图tsplot
2. 我的时序图plot_ts_d , plot_ts_m
分面绘图
1.分面绘图FacetGrid
import seaborn as sns
sns.set_style("whitegrid")
tips = sns.load_dataset("tips")
ax = sns.boxplot(x=tips["total_bill"])
ax = sns.boxplot(y=tips["total_bill"])
ax = sns.boxplot(x="day", y="total_bill", data=tips)
ax = sns.boxplot(x="day", y="total_bill", hue="smoker",
data=tips, palette="Set3")
ax = sns.boxplot(x="day", y="total_bill", hue="time",
data=tips, linewidth=2.5)
ax = sns.boxplot(x="time", y="tip", data=tips,
order=["Dinner", "Lunch"])
iris = sns.load_dataset("iris")
ax = sns.boxplot(data=iris, orient="h", palette="Set2")
箱线图+有分布趋势的散点图–>的组合图
ax = sns.boxplot(x="day", y="total_bill", data=tips)
ax = sns.swarmplot(x="day", y="total_bill", data=tips, color=".25")
小提琴图
其实是
箱线图
与
核密度图
的结合,箱线图展示了分位数的位置,小提琴图则展示了任意位置的密度,通过小提琴图可以知道哪些位置的密度较高。在图中,白点是中位数,黑色盒型的范围是下四分位点到上四分位点,细黑线表示须。外部形状即为核密度估计(在概率论中用来估计未知的密度函数,属于非参数检验方法之一)。
import seaborn as sns
sns.set_style("whitegrid")
tips = sns.load_dataset("tips")
ax = sns.violinplot(x=tips["total_bill"])
ax = sns.violinplot(x="day", y="total_bill", data=tips)
ax = sns.violinplot(x="day", y="total_bill", hue="smoker",
data=tips, palette="muted")
ax = sns.violinplot(x="day", y="total_bill", hue="smoker",
data=tips, palette="muted", split=True)
ax = sns.violinplot(x="time", y="tip", data=tips,
order=["Dinner", "Lunch"])
其他的样式不常用,就不贴上来了。
需要注意的是,seaborn中有两个散点图,一个是普通的散点图,另一个是可以看出分布密度的散点图。下面把它们花在一起就明白了。
ax1 = sns.stripplot(x=tips["total_bill"])
ax2 = sns.swarmplot(x=tips["total_bill"])
ax = sns.stripplot(x="day", y="total_bill", data=tips)
ax = sns.stripplot(x="day", y="total_bill", data=tips, jitter=True)
ax = sns.stripplot(x="total_bill", y="day", data=tips,jitter=True)
ax = sns.stripplot(x="sex", y="total_bill", hue="day",
data=tips, jitter=True)
ax = sns.stripplot(x="day", y="total_bill", hue="smoker",
data=tips, jitter=True,palette="Set2", split=True)
ax = sns.violinplot(x="day", y="total_bill", data=tips,inner=None, color=".8")
ax = sns.stripplot(x="day", y="total_bill", data=tips,jitter=True)
swarmplt的参数和用法和stripplot的用法是一样的,只是表现形式不一样而已。
import seaborn as sns
sns.set_style("whitegrid")
tips = sns.load_dataset("tips")
ax = sns.swarmplot(x=tips["total_bill"])
ax = sns.swarmplot(x="day", y="total_bill", data=tips)
ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf)
ax = sns.swarmplot(x="tip", y="day", data=tips)
ax = sns.violinplot(x="day", y="total_bill", data=tips, inner=None)
ax = sns.swarmplot(x="day", y="total_bill", data=tips,
color="white", edgecolor="gray")
我不喜欢显示直方图上面的置信度线,难看,所以下面的图形我都设置ci=0.(Size of confidence intervals to draw around estimated values)
直方图的统计函数,绘制的是变量的均值 estimator=np.mean
import seaborn as sns
sns.set_style("whitegrid")
tips = sns.load_dataset("tips")
ax = sns.barplot(x="day", y="total_bill", data=tips,ci=0)
ax = sns.barplot(x="day", y="total_bill", hue="sex", data=tips,ci=0)
from numpy import median
ax = sns.barplot(x="day", y="tip", data=tips,
estimator=median, ci=0)
ax = sns.barplot("size", y="total_bill", data=tips,
palette="Blues_d")
这个很重要,对因子变量计数,然后绘制条形图
import seaborn as sns
sns.set(style="darkgrid")
titanic = sns.load_dataset("titanic")
ax = sns.countplot(x="class", data=titanic)
ax = sns.countplot(x="class", hue="who", data=titanic)
ax = sns.countplot(y="class", hue="who", data=titanic)
这是一类重要的变量联合绘图。
绘制 因子变量-数值变量 的分布情况图。
import seaborn as sns
sns.set(style="ticks")
exercise = sns.load_dataset("exercise")
g = sns.factorplot(x="time", y="pulse", hue="kind",
data=exercise, kind="violin")
titanic = sns.load_dataset("titanic")
g = sns.factorplot(x="alive", col="deck", col_wrap=4,
data=titanic[titanic.deck.notnull()],
kind="count", size=2.5, aspect=.8)
import seaborn as sns; sns.set(color_codes=True)
tips = sns.load_dataset("tips")
g = sns.lmplot(x="total_bill", y="tip", data=tips)
g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips)
g = sns.lmplot(x="total_bill", y="tip", hue="smoker",
data=tips,markers=["o", "x"])
g = sns.lmplot(x="total_bill", y="tip", col="smoker", data=tips)
g = sns.lmplot(x="size", y="total_bill", hue="day",
col="day",data=tips, aspect=.4, x_jitter=.1)
g = sns.lmplot(x="total_bill", y="tip", col="day", hue="day",
data=tips, col_wrap=2, size=3)
g = sns.lmplot(x="total_bill", y="tip", row="sex",
col="time", data=tips, size=3)
Plot the relationship between two variables in a DataFrame:
import seaborn as sns; sns.set(color_codes=True)
tips = sns.load_dataset("tips")
ax = sns.regplot(x="total_bill", y="tip", data=tips)
import numpy as np; np.random.seed(8)
mean, cov = [4, 6], [(1.5, .7), (.7, 1)]
x, y = np.random.multivariate_normal(mean, cov, 80).T
ax = sns.regplot(x=x, y=y, color="g", marker="+")
ax = sns.regplot(x=x, y=y, ci=68)
ans = sns.load_dataset("anscombe")
ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "II"],
scatter_kws={"s": 80},order=2, ci=None, truncate=True)
直方图hist=True,核密度曲线rug=True
import seaborn as sns, numpy as np
sns.set(rc={"figure.figsize": (8, 4)}); np.random.seed(0)
x = np.random.randn(100)
ax = sns.distplot(x)
ax = sns.distplot(x, rug=True, hist=False)
ax = sns.distplot(x, vertical=True)
import numpy as np; np.random.seed(10)
import seaborn as sns; sns.set(color_codes=True)
mean, cov = [0, 2], [(1, .5), (.5, 1)]
x, y = np.random.multivariate_normal(mean, cov, size=50).T
ax = sns.kdeplot(x)
ax = sns.kdeplot(x, shade=True, color="r")
ax = sns.kdeplot(x, y, shade=True)
iris = sns.load_dataset("iris")
setosa = iris.loc[iris.species == "setosa"]
virginica = iris.loc[iris.species == "virginica"]
ax = sns.kdeplot(setosa.sepal_width, setosa.sepal_length,
cmap="Reds", shade=True, shade_lowest=False)
ax = sns.kdeplot(virginica.sepal_width, virginica.sepal_length,
cmap="Blues", shade=True, shade_lowest=False)
-
1
-
2
-
3
-
4
-
5
-
6
-
7
-
8
-
9
-
10
-
11
-
12
joint,顾名思义,就是联合呀。
Draw a plot of two variables with bivariate and univariate graphs.
kind参数可以使用不同的图形反应两变量的关系,比如点图,线图,核密度图。
import numpy as np, pandas as pd; np.random.seed(0)
import seaborn as sns; sns.set(style="white", color_codes=True)
tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips)
g = sns.jointplot("total_bill", "tip", data=tips, kind="reg")
g = sns.jointplot("total_bill", "tip", data=tips, kind="hex")
iris = sns.load_dataset("iris")
g = sns.jointplot("sepal_width", "petal_length", data=iris,
kind="kde", space=0, color="g")
g = sns.jointplot("total_bill", "tip", data=tips,
size=5, ratio=3, color="g")
就是绘制dataframe中各个变量两两之间的关系图。
在变量关系图中,最常见的就是 x-y的线图,x-y的散点图,x-y的回归图。其实这三者都可以通过lmplot绘制,只是控制不同的参数而已。x-y的线图,其实就是时间序列图,这里就不说了。
这里又说一遍散点图,是为了和前面的因子变量散点图相区分,前面的因子变量散点图,讲的是不同因子水平的值绘制的散点图,而这里是两个数值变量值散点图关系。为什么要用lmplot呢,说白了就是,先将这些散点画出来,然后在根据散点的分布情况拟合出一条直线。但是用lmplot总觉得不好,没有用scatter来得合适。
tips = sns.load_dataset("tips")
g = sns.lmplot(x="total_bill", y="tip", data=tips,
fit_reg=False,hue='smoker',scatter=True)
g = sns.lmplot(x="total_bill", y="tip", data=tips,
fit_reg=True,hue='smoker',scatter=False)
import seaborn as sns; sns.set(style="ticks", color_codes=True)
iris = sns.load_dataset("iris")
g = sns.pairplot(iris)
g = sns.pairplot(iris, hue="species")
g = sns.pairplot(iris, hue="species", markers=["o", "s", "D"])
g = sns.pairplot(iris, vars=["sepal_width", "sepal_length"])
g = sns.pairplot(iris, diag_kind="kde")
import numpy as np; np.random.seed(0)
import seaborn as sns; sns.set()
uniform_data = np.random.rand(10, 12)
ax = sns.heatmap(uniform_data)
ax = sns.heatmap(uniform_data, vmin=0, vmax=1)
Plot a dataframe with meaningful row and column labels:
flights = sns.load_dataset("flights")
flights = flights.pivot("month", "year", "passengers")
ax = sns.heatmap(flights)
ax = sns.heatmap(flights, annot=True, fmt="d")
data = np.random.randn(50, 20)
ax = sns.heatmap(data, xticklabels=2, yticklabels=False)
import numpy as np; np.random.seed(22)
import seaborn as sns; sns.set(color_codes=True)
x = np.linspace(0, 15, 31)
data = np.sin(x) + np.random.rand(10, 31) + np.random.randn(10, 1)
ax = sns.tsplot(data=data)
gammas = sns.load_dataset("gammas")
ax = sns.tsplot(time="timepoint", value="BOLD signal",
unit="subject", condition="ROI", data=gammas)
ax = sns.tsplot(data=data, ci=[68, 95], color="m")
ax = sns.tsplot(data=data, estimator=np.median)
这里重点讲一下。如果时序中每天的数据都有还好说,如果没有,就需要采样了。
def plot_ts_day(x,y):
"""绘制每天的时间序列图。
需要注意的是,序列是不是连续的,也就是说某天的数据是没有的,因此需要采样至每天都有记录,原来数据没有的就填充0
x:时间轴,string或者time类型,是一个seires
x=[pd.to_datetime(str(i)) for i in x]
y=[i for i in y]
s=pd.Series(y,index=x)
s = s.resample(rule='D',fill_method='ffill')
s[s.index]=0
s[x]=y
x2 = [i.strftime('%Y-%m-%d') for i in s.index]
s.index = x2
s.plot()
-
1
-
2
-
3
-
4
-
5
-
6
-
7
-
8
-
9
-
10
-
11
-
12
-
13
-
14
-
15
-
16
-
17
-
18
-
19
def plot_ts_month(x,y):
"""绘制月的时间序列图,每月一个数据点,而不是每天一个"""
try:
x = [pd.to_datetime(str(i)) for i in x]
except:
x=[pd.to_datetime(str(i)+'01') for i in x]
y=[i for i in y]
s=pd.Series(y,index=x)
s = s.resample('M', label='right').sum().fillna(0)
s.index=[i.strftime('%Y%m') for i in s.index]
s.plot()
-
1
-
2
-
3
-
4
-
5
-
6
-
7
-
8
-
9
-
10
-
11
-
12
-
13
-
14
-
15
说实话,到现在还没搞懂怎么用sns.tsplot绘制分组线图,但是任务紧急,就用pandas的dataframe自带方法plot来绘图了,其实也挺简单的。
主要注意的是,尽量给dataframe或者series建立时间索引,不然x轴很难看的。
data.index = data['year'].map(str)+data['month2'].map(lambda x: str(x) if x>=10 else '0'+str(x))
data['salecount'].plot()
分组的线图,比如seaborn中的hue参数,方法是,先将dataframe长表格式转成宽表格式(透视表),每列是不同的年。
data.pivot(index='month2',columns='year',values='salecount').plot(title='销量')
data.pivot_table(index='month2',columns='year',values='salecount',aggfunc='sum') \
.plot(title='销量',style='o-')
图形格式选项
-
1
-
2
-
3
-
4
-
5
-
6
-
7
-
8
-
9
-
10
-
11
-
12
-
13
-
14
-
15
-
16
-
17
-
18
-
19
-
20
-
21
-
22
-
23
-
24
data.pivot(index='month2',columns='year',values='salecount')\
.plot(title='销量',style='-o')