相关文章推荐
激动的充值卡  ·  alpine linux ...·  1 年前    · 
活泼的青蛙  ·  mysql中SELECT INTO 和 ...·  1 年前    · 
要出家的煎饼果子  ·  使用SPSS ...·  1 年前    · 

Webdataset 加速深度学习数据加载

导言 :在大规模数据上进行深度学习通常会因为IO瓶颈而拖慢训练的速度,本文介绍了webdataset是如何在深度学习中加速大规模数据加载的。

webdataset 简介

webdataset是什么 webdataset 是一个数据加载的库,其可以从tar文件中直接读取数据样本而无需将tar包中的所有文件释放出来。从某个角度看,webdataset制定了一种基于tar包的大规模数据格式,其实就是翻版的tfrecord,只不过tfrecord是google专门搞出来的格式,而webdataset直接使用tar这种通用的数据格式,没有自己另外再搞一种二进制格式。此外,webdataset是专门为PyTorch写的,可以很容易集成到已有的PyTorch代码中(其实稍微改改应该很容易集成到任何深度学习框架中)。webdataset的主要目的是为了解决传统数据加载方式(就是直接从磁盘中加载大量数据集文件)存在的一些问题。

传统数据加载方式有什么问题? 当今的大规模数据集包含了大量的数据样本,例如ImageNet包括约130万图片,OpenImage包括约900万图片,这还只是开胃菜,在大公司里面还有比这些大得多的数据集。如果这些图片样本直接存放在文件系统/对象存储系统上,数据读取会给这些系统带来极大的压力。原因包括以下几点:

  • 在深度学习训练过程中,我们需要打乱样本顺序,这要求我们能够随机访问任意一个样本。随机访问文件效率一般十分低下,无论是在文件系统/对象存储系统上,都要花费大量代价寻找文件所在的位置(因为所有的硬盘都是随机读写性能低于顺序读写性能,因此随机读取无法最大程度地利用硬盘的性能)。
  • 在训练过程中需要打开和关闭大量文件,这其中的系统调用的时间也是不能忽略的。此外,在分布式文件系统中(例如ceph),读取文件时还需要读取其对应的元数据,而这些元数据在训练过程中毫无意义,只会降低读取效率。
  • 过小的文件还会浪费磁盘空间,因为一般磁盘存储都会将文件对齐到一个大小上(例如4KB)。
  • 如果是网络文件系统或者对象存储系统,大量小文件将无法充分利用网络传输带宽,降低了传输效率。
  • 这个严格来说不算是数据加载问题:大量文件不利于数据集分发,所以在数据集分发时需要将数据样本文件打包,而使用的时候又要将其解包(在数据集规模很大的时候这一步骤也非常耗时)。

webdataset是如何解决上述问题的 :webdataset将数据样本文件打包,但是这里注意不是将所有文件打成一个特别大的包,而是将其打成若干个包。以ImageNet为例,我们可以将130万个文件打包为256个tar包,平均每个tar文件包含5k个样本。在读取的时候,webdataset将这256个tar包顺序打乱,然后按照打乱的顺序依次读取tar包。在读取每一个tar包的时候,里面存储的样本将会被顺序读取(因此很快),但是这样的话达不到打乱整个数据集的目的。因此webdataset维护了一个buffer,新读取的样本将会和buffer中的一个随机样本交换,达到打乱数据集的目的:

# read sample from the given tar file
k = rng.randint(0, len(buf) - 1)
sample, buf[k] = buf[k], sample
# return sample here

这样webdataset将把上述传统数据加载方提到的缺点都解决了,需要注意的是 webdataset的数据集打乱程度是和这个buffer的大小有关系,在实际中需要设置一个足够大的数值 。其实webdataset的工作原理和tfrecord是一模一样的,用tensorflow的同学应该是很容易理解。

性能对比:webdataset vs. 原生数据加载

为了突显webdataset的优秀特性,我们将imagenet打包为webdataset支持的一系列tar包并比较使用webdataset加载和使用pytorch原生的ImageFolder加载的速度。注意这里为了更好的对比IO速度,我们把图片文件的所有字节加载到内存中就够了,并没有进行图片解码和任何的预处理操作。下面我们分别给出了在机械硬盘和固态硬盘上用webdataset和原生数据加载方式的速度对比。

机械硬盘对比结果 :在机械硬盘上,webdataset基本上带来了 10倍 的读取速度提升。如此巨大的性能提升是因为机械硬盘的顺序读取速度比随机读取快太多了,而webdataset这个库很好地利用了这一点,几乎把所有的文件读取都变成了顺序读取。从每秒加载的图片文件大小来看,webdataset已经非常接近这块机械硬盘的读取上限(~170MB/s),基本做到了极致。

每秒加载的图片数量对比:

线程数 1 2 4 6 8
原生加载 83.20 86.19 104.40 112.96 120.42
wds加载 1447.39 1423.57 1215.70 1160.79 1020.50

每秒加载的图片文件大小 (MB/s) 对比:

线程数 1 2 4 6 8
原生加载 9.17 9.36 11.48 12.35 13.31
wds加载 159.77 155.47 134.51 125.99 112.43

固态硬盘对比结果 :在固态硬盘上,webdataset带来了从27%到56%不等的读取速度提升,这个提升远没有机械硬盘来的惊艳,但是有提升总好过没有是不是。提升比较小地原因时固态硬盘的随机读写性能相对于机械硬盘已经好了太多太多(见附录中的硬盘读写性能测试)。另外要说一点这块是SATA的固态硬盘,如果是NVME的固态硬盘,这一差距还会继续的缩小。

每秒加载的图片数量对比:

线程数 1 2 4 6 8
原生加载 1936.34 2339.95 3299.90 3515.42 3536.51
wds加载 2567.04 3665.97 4383.44 4539.08 4503.89

每秒加载的图片文件大小 (MB/s) 对比:

线程数 1 2 4 6 8
原生加载 153.54 255.70 361.70 385.44 387.74
wds加载 281.63 403.49 482.24 496.40 495.29

结论 :在机械硬盘上强烈推荐使用webdataset作为数据加载方式,在固态硬盘上也十分推荐(其实固态硬盘的原生ImageFolder的加载速度已经非常够用了)。此外,上面我们讨论的主要是本地加载数据的情况,如果在云上进行机器学习模型训练,数据文件往往会直接从分布式文件系统或者对象存储上进行读取。如果数据集文件过多也会导致分布式文件/对象存储处理过多无用的元数据,并且小文件过多也无法一直保持网络带宽的最大化利用,这些问题都会导致数据加载变成训练过程中的瓶颈,而webdataset也能很好处理这个场景(事实上,这个库就是为这类场景发明的,所以叫做webdataset)。

附录

硬件环境 :这里列出本次测试的硬件环境:

硬件名称 具体型号
CPU Intel(R) Core(TM) i5-8500 CPU @ 3.00GHz
内存 Kingston 2666MHz 8G x 2
机械硬盘 Western Digital Blue 1T 7200 rpm
固态硬盘 Intel 545s Series 256G

本次对比测试将会在机械硬盘和固态硬盘上进行,下面给出一些 fio脚本 测试得到的数据,以便更好对后面的实验对比结果进行分析:

  • 固态硬盘:顺序读 ~500MB, 随机读 ~300M
|Name          |  Read(MB/s)| Write(MB/s)|
|--------------|------------|------------|
|  SEQ1M Q1 T1 |     484.345|     352.739|
|  SEQ1M Q8 T1 |     516.135|     425.775|
| RND4K Q32T16 |     311.884|     282.172|
| . IOPS       |   76143.543|   68889.666|
| . latency us |       6.709|       7.422|
| RND4K Q1 T1  |      42.470|     106.964|
| . IOPS       |   10368.547|   26114.252|
| . latency us |       0.096|       0.037|
  • 机械硬盘:顺序读 ~175MB, 随机读 ~2.3M (太拉跨了)
|Name          |  Read(MB/s)| Write(MB/s)|
|--------------|------------|------------|
|  SEQ1M Q1 T1 |     179.546|      79.943|
|  SEQ1M Q8 T1 |     172.259|      79.952|
| RND4K Q32T16 |       2.335|       0.929|
| . IOPS       |     569.994|     226.896|
| . latency us |     823.919|    1837.834|
| RND4K Q1 T1  |       0.686|       1.094|
| . IOPS       |     167.532|     267.147|
| . latency us |       5.962|       3.734|

测试说明 :为了测试的公平性,每一次测试之前都会使用命令 sync; echo 3 > /proc/sys/vm/drop_caches 清空所有缓冲区,包括页面缓存,目录项和inode以保证数据确实是从硬盘加载而不是来自于内存缓存。

webdataset 构建 :我们使用下面的代码将imagenet的train部分打包为tar包,这里的打包代码是自己写的,用了多进程,这个库给的代码是单进程的,慢得离谱。。。

import os
import random
import datetime
from multiprocessing import Process
from torchvision import datasets
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import ImageFolder
from webdataset import TarWriter
def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs):
    random.shuffle(samples)
    samples_per_shards = [samples[i::num_shards] for i in range(num_shards)]
    shard_ids = list(range(num_shards))
    processes = [
        Process(
            target=write_partial_samples,
            args=(
                pattern,
                shard_ids[i::num_workers],
                samples_per_shards[i::num_workers],
                map_func,
                kwargs
        for i in range(num_workers)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs):
    for shard_id, samples in zip(shard_ids, samples):
        write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs)
def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs):
    fname = pattern % shard_id
    print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}")
    stream = TarWriter(fname, **kwargs)
    size = 0
    for item in samples:
        size += stream.write(map_func(item))
    stream.close()
    print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}")
    return size
if __name__ == "__main__":
    root = "/gdata/ImageNet2012/train"
    items = []
    dataset = ImageFolder(root=root,  loader=lambda x:x)
    for i in range(len(dataset)):
        items.append(dataset[i])
    print(dataset[0],os.path.splitext(os.path.basename(dataset[0][0]))[0])
    def map_func(item):
        name, class_idx = item
        with open(os.path.join(name), "rb") as stream:
            image = stream.read()
        sample = {
            "__key__": os.path.splitext(os.path.basename(name))[0],
            "jpg": image,
            "cls": str(class_idx).encode("ascii")
        return sample
    make_wds_shards(
        pattern="/userhome/tars/imagenet-1k-%06d.tar",
        num_shards=256, # 设置分片数量
        num_workers=8, # 设置创建wds数据集的进程数
        samples=items,
        map_func=map_func,
    )

测试代码 :我们测试随机从数据集中读取N张图的耗时(在机械硬盘上N=30000,在固态硬盘上N=300000),根据耗时计算每秒读取图片的数量和吞吐量

import os
import time
import torch
import webdataset as wds
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
def get_ori_loader(disk, num_workers):
    def read_bytes(path):
        with open(path, "rb") as f:
            return f.read()
    root = "/mnt/extend/imagenet/train" if disk == "hdd" else "/home/chenyaofo/webdataset-test/train"
    dataset = ImageFolder(root,  loader=read_bytes, is_valid_file=lambda x: True)
    dataloader = DataLoader(dataset, num_workers=num_workers, shuffle=True, batch_size=128)
    return dataloader
def get_wds_loader(disk, num_workers):
    url = "/mnt/extend/tars/imagenet-1k-{000000..000256}.tar" if disk == "hdd" else "/home/chenyaofo/webdataset-test/tars/imagenet-1k-{000000..000256}.tar"
    def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value
    dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=128)
    return dataloader
def run_test(loader, disk):
    N_stop = 30000 if disk == "hdd" else 300000
    start = time.perf_counter()
    total_batch_size = 0
    total_bytes = 0
    for items in loader:
        if isinstance(items, dict):
            batch_size = len(items['jpeg.cls'])
            n_bytes = sum(map(lambda x: len(x), items['jpeg.jpg']))
        else:
            batch_size = len(items[1])
            n_bytes = sum(map(lambda x: len(x), items[0]))
        total_batch_size += batch_size
        # print(total_batch_size)
        total_bytes += n_bytes
        if total_batch_size > N_stop:
            end = time.perf_counter()
            return total_batch_size, total_bytes, end-start
for disk in ["ssd", "hdd"]: