数据分析利器 pandas 系列教程(六):合并上百万个 csv 文件,如何提速上百倍
如开篇初衷,这个系列教程对于基础知识的引导,不求细致而大全,但求细致而实用,
过完基础知识以后就是实战 tricks 的集锦,这些都是笔者在实际工作中用到的解决方案,求小而精,抛砖引玉。
所以后续的更新本来就应该是可遇不可求的,但是我不想以此作为拖更的借口,因为事实上,这大半年我是一直有更新的。
这一年半在我的 BuyiXiao Blog 上更新了差不多 10 篇(标签是 pandas,地址如下),但是几乎都没有发布在公众号上。
https://buyixiao.github.io/tags/pandas/
还是那个原因, 代码工程永远是追求最佳实践的,或者更准确的来说应该是更佳实践,因为我觉得脱离了时间背景,没有最佳实践。
所以即使是一个讲解功能点的教程,需要频繁地对一篇教程进行反复修改,不然就是以讹传讹了,公众号只能修改一次太差强人意,所以就都发布在博客上,不定期搬运到公众号上。
所以可以把上面这个链接加入收藏夹吗?
回到今天的正题,加速 pandas 合并 csv ~
在上一篇的教程 数据分析利器 pandas 系列教程(五):合并相同结构的 csv 分享了合并的思路和代码,
# -*- coding: utf-8 -*-
# author: inspurer(月小水长)
# create_time: 2022/4/13 10:33
# 运行环境 Python3.6+
# github https://github.com/inspurer
# website https://buyixiao.github.io/
# 微信公众号 月小水长
import os
import pandas as pd
result_csv = 'all.csv'
all_cols = []
for file in os.listdir('.'):
if file.endswith('.csv') and not file == result_csv:
df = pd.read_csv(file)
all_cols = df.columns.values.tolist()
if len(all_cols) == 0:
raise Exception("当前目录下没有要合并的 csv 文件")
all_cols.insert(0, 'origin_file_name')
all_df = pd.DataFrame({col: [] for col in all_cols})
for file in os.listdir('.'):
if file.endswith('.csv') and not file == result_csv:
df = pd.read_csv(file)
df.insert(0, 'origin_file_name', [file for _ in range(df.shape[0])])
all_df = all_df.append(df, ignore_index=True)
all_df.to_csv(result_csv, index=False, encoding='utf-8')
但是最近我遇到一个工程问题, 需要合并超过 1000,000 (上百万)个 csv 文件,最大的 10M 左右,最小的 5KB 左右,最开始用的上面这现成的代码,运行了一天之后,我觉得照目前这速度,差不多得合并到元旦 。
所以探索更佳实践使得我逐行分析了代码耗时,发现大量或者说 99.99% 的耗时集中在下面这行代码上:
all_df = all_df.append(df, ignore_index=True)
pandas 官方已经不推荐使用 append 来连接
dataframe
了,转而使用concat
,即all_df = pd.concat([all_df,df], ignore_index=True)
但是这不是今天讨论的重点
最开始我为什么要设计成 for 循环中读一个 csv 就合并一次呢,因为我觉得读取全部文件到内存中再合并非常吃内存,设计成这样保存每次只有一个两个
dataframe
即
df
和
all_df
驻留在内存中。
最开始几百个几千个文件合并的时候这份代码运行没有问题,时间也非常短,但是几十上百万个文件合并时,问题就暴露出来了。
问题在于,
append
或者
concat
每执行一次,都需要复制一份当前结果
dataframe
的副本,上百个文件复制尚可,上百万个文件,到后面每复制一次当前已合并的结果
dataframe
,耗时可想而知。
找到问题所在,解决办法就很简单了, 把 pandas 的连接放到 for 循环外只集中连接一次即可,这就意味着,需要加载完所有的 csv 文件后再连接,改良后合并原来那些上百万个 csv 文件只用不到一个下午 , 测算过耗时减少超过 99% 。
concat
中有非常多的耗时处理,复制副本仅是比较重要其中一项,这里仅以复制代指这些过程。
定量分析下,假设合并第一个 csv 文件时耗时 1 个时间单位,合并第 N 个 csv 文件时耗时 N 个单位(第一次复制时只合并了 1 个 csv,第 N 次复制时已合并 N 个 csv,假定所有文件大小相同,
concat
耗时仅和复制有关,复制仅和文件大小线性相关),
那么执行 N 次合并耗时
1+2+3+4+...+N-1+N = (N-1)*N/2
个时间单位;如果把连接放在 for 循环外,则只需要第 N 次的耗时 N 个时间单位即可,也就是说,改进后耗时仅是原来的
(N-1)*N/(2*N)=(N-1)/2
分之一,仅和文件总数 N 相关。
按照上面的分析,待合并的 csv 文件夹越多,也就是 N 越大,相比较把连接放在 for 循环,只连接一次的耗时减少得越多(N 很小的时候减少不明显),代码如下:
# -*- coding: utf-8 -*-
# author: inspurer(月小水长)
# create_time: 2023/10/30 15:23
# 运行环境 Python3.6+
# github https://github.com/inspurer
# website https://buyixiao.github.io/
# 微信公众号 月小水长
import pandas as pd
import os
def do_merge(input_folder, output_file='all.csv', append_file_name_col=True, file_name_col='origin_file_name'):
result_csv = output_file
all_cols = []
if not os.path.exists(input_folder):
raise Exception(f"目录 {input_folder} 不存在")
file_cnt = len(os.listdir(input_folder))
for file in os.listdir(input_folder):
if file.endswith('.csv') and not file == result_csv:
df = pd.read_csv(os.path.join(input_folder, file))
all_cols = df.columns.values.tolist()
break
if len(all_cols) == 0:
raise Exception(f"当前目录 {os.path.abspath(input_folder)}下没有要合并的 csv 文件")
save_cols = all_cols
if append_file_name_col:
all_cols.insert(0, file_name_col)
save_cols.insert(0, file_name_col)
df_list = []
for index, file in enumerate(os.listdir(input_folder)):
print(f'{index + 1}/ {file_cnt} {file}')
if file.endswith('.csv') and not file == result_csv:
file_name = file[:file.rindex('.')]
df = pd.read_csv(os.path.join(input_folder, file), float_precision='high')
if append_file_name_col:
df.insert(0, file_name_col, [file_name for _ in range(df.shape[0])])
df = df[save_cols]
df_list.append(df)
all_df = pd.concat(df_list, ignore_index=True)
print(all_df.shape[0])
# subset_ = ['unique id colums name of your dataframe']
subset_ = []
if append_file_name_col:
subset_.append(file_name_col)
all_df.drop_duplicates(subset=subset_, inplace=True, keep='first')
print(all_df.shape[0])
all_df.to_csv(result_csv, index=False, encoding='utf-8-sig')