Distributed DataLoader for PyTorch Based on Ray




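The gap between GPU and CPU compute keeps widening, and the preprocessing pipelines used in model training keep growing more complex. As a result, CPU-side data preprocessing has gradually become the bottleneck of model training, and upgrading the GPUs in a single machine no longer brings the expected linear speedup. At its core, the preprocessing bottleneck exists because the CPU compute available to each GPU is limited. NVIDIA addressed this with a scale-up solution, the GPU data preprocessing library DALI. TensorFlow offered a scale-out solution, the distributed data preprocessing component DataService. Here we present a scale-out solution for the PyTorch ecosystem: the distributed data preprocessing component Dpex.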

2. Architecture Overview (the architecture of PyTorch's own DataLoader and of DistDataLoader)



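Dpex is fully compatible with PyTorch's DataLoader in its implementation as well as its design. For parallel data preprocessing, setting distribute_mode to True makes DpexDataLoader use _RayDataLoaderIter to preprocess data in a distributed way. Setting it to False makes DpexDataLoader fall back to PyTorch's own _MultiProcessingDataLoaderIter for parallel data preprocessing and loading.

Using Dpex in PyTorch training is simple: replace every use of PyTorch's DataLoader with Dpex's DpexDataLoader. When your training machine is itself a node in a Ray cluster, set distribute_mode=True to enable distributed data preprocessing. Below are examples of using Dpex for three cases: single-GPU training, multi-GPU training with DataParallel, and multi-GPU training with DDP. See the test files for details.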
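PyTorch's DataLoader chooses its iterator in the private method `_get_iterator`, and Dpex appears to switch iterators at that point based on `distribute_mode`. Below is a minimal pure-Python sketch of that selection rule. It rests on one assumption not stated in the text: like PyTorch, a loader with `num_workers=0` loads in the main process, and Ray is used only when parallel workers are requested. `select_iterator` is a hypothetical helper that returns class names rather than real iterator objects.

```python
def select_iterator(distribute_mode: bool, num_workers: int) -> str:
    """Name of the iterator class a DpexDataLoader-style loader would use.

    Hypothetical helper: mirrors the rule in PyTorch's DataLoader._get_iterator
    and adds the distribute_mode switch described in the text.
    """
    if num_workers == 0:
        # No worker processes: PyTorch loads batches in the main process.
        return "_SingleProcessDataLoaderIter"
    if distribute_mode:
        # Assumption: parallel preprocessing is delegated to Ray workers.
        return "_RayDataLoaderIter"
    # Fallback: PyTorch's local multiprocessing workers.
    return "_MultiProcessingDataLoaderIter"


for mode in (True, False):
    print(mode, select_iterator(distribute_mode=mode, num_workers=10))
```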
The DpexDataLoader constructor mirrors the parameters of PyTorch's DataLoader and adds distribute_mode and head_address:

class DpexDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset: Dataset[T_co],
                 distribute_mode: Optional[bool] = False,
                 head_address="auto",
                 batch_size: Optional[int] = 1,
                 shuffle: bool = False,
                 sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0,
                 collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False,
                 drop_last: bool = False,
                 timeout: float = 0,
                 worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None,
                 generator=None,
                 *, prefetch_factor: int = 2):
        ...  # body omitted in this excerpt
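`num_workers` and `prefetch_factor` together bound how many batches are in flight. In PyTorch's DataLoader, each worker loads up to `prefetch_factor` batches in advance, so about `num_workers * prefetch_factor` batches are outstanding at once.

The sketch below illustrates that bounded, in-order prefetching pattern. It swaps in `concurrent.futures.ThreadPoolExecutor` as a stand-in for Ray workers, so it is not Dpex's implementation. `load_batch` and `prefetching_loader` are hypothetical names.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterator, List, Sequence


def load_batch(dataset: Sequence[int], indices: List[int]) -> List[int]:
    """Hypothetical per-batch work: fetch samples and 'preprocess' them."""
    return [dataset[i] * 2 for i in indices]


def prefetching_loader(
    dataset: Sequence[int],
    batches: List[List[int]],
    num_workers: int = 2,
    prefetch_factor: int = 2,
    fn: Callable[[Sequence[int], List[int]], List[int]] = load_batch,
) -> Iterator[List[int]]:
    """Yield processed batches in sampler order with a bounded prefetch window."""
    window = num_workers * prefetch_factor  # max batches in flight
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = deque()
        batch_iter = iter(batches)
        # Fill the prefetch window.
        for indices in batch_iter:
            pending.append(pool.submit(fn, dataset, indices))
            if len(pending) >= window:
                break
        while pending:
            # Wait on the oldest request so output order matches sampler order.
            result = pending.popleft().result()
            # Keep the window full by submitting the next batch, if any.
            nxt = next(batch_iter, None)
            if nxt is not None:
                pending.append(pool.submit(fn, dataset, nxt))
            yield result


data = list(range(10))
batches = [data[i:i + 3] for i in range(0, len(data), 3)]
print(list(prefetching_loader(data, batches)))
```

Waiting on the oldest future first keeps the output deterministic even when later batches finish earlier. The fixed window stops fast workers from running arbitrarily far ahead of training.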

3.1 Single-GPU Training


from torchvision import datasets
from torchvision.transforms import ToTensor
from Dpex import dataloader

training_data = datasets.FashionMNIST(
    root="data",  # assumed local directory for the dataset
    train=True,
    download=True,
    transform=ToTensor(),
)
# use DpexDataLoader
train_loader = dataloader.DpexDataLoader(training_data, distribute_mode=True, num_workers=10, batch_size=100, shuffle=True)

for epoch in range(3):
    for index, (image, label) in enumerate(train_loader):
        # your train process
        pass
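With `shuffle=True`, `batch_size=100` and the default `drop_last=False`, the loader draws a fresh permutation of sample indices each epoch and cuts it into batches. PyTorch's `RandomSampler` combined with `BatchSampler` does the same thing.

Below is a pure-Python sketch of that index flow. `epoch_batches` is a hypothetical helper, and it uses `random.Random` instead of PyTorch's generator, so the actual permutations differ.

```python
import random
from typing import List


def epoch_batches(num_samples: int, batch_size: int, shuffle: bool,
                  drop_last: bool, seed: int) -> List[List[int]]:
    """Index batches a DataLoader-style loader would request in one epoch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # a new seed gives a new order per epoch
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch
    return batches


# FashionMNIST's training split has 60,000 images, so batch_size=100 yields 600 batches.
for epoch in range(3):
    batches = epoch_batches(60000, batch_size=100, shuffle=True, drop_last=False, seed=epoch)
    print(epoch, len(batches), batches[0][:5])
```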

3.2 Multi-GPU Training with DataParallel


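import torch
import torch.nn as nn
from torch.utils.data import Dataset
from Dpex import dataloader

# input_size, output_size and data_size as in the DDP example; batch size as in section 4.1
input_size, output_size = 5, 2
batch_size, data_size = 100, 90000

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

# use DpexDataLoader
data_loader = dataloader.DpexDataLoader(dataset=RandomDataset(input_size, data_size),
                                        distribute_mode=True, batch_size=batch_size, shuffle=True, num_workers=10)

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        return self.fc(input)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # split each batch across all visible GPUs
model.to(device)

for data in data_loader:
    # train your own model
    output = model(data.to(device))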
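`nn.DataParallel` splits each input batch along dimension 0 and sends one piece to each visible GPU. This sketch assumes the split follows `torch.chunk` semantics: every chunk has size `ceil(N / k)` except a smaller last one, and there may be fewer chunks than GPUs. Under that assumption, the per-GPU share of a batch is computed below. `per_device_sizes` is a hypothetical helper.

The practical point: the benchmark's batch size of 100 (section 4.1) splits evenly enough, but very small batches can leave some GPUs idle.

```python
import math
from typing import List


def per_device_sizes(batch_size: int, num_devices: int) -> List[int]:
    """Sizes of the pieces a batch is cut into, following torch.chunk rules."""
    chunk = math.ceil(batch_size / num_devices)  # size of every chunk but the last
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


print(per_device_sizes(100, 3))  # [34, 34, 32]
print(per_device_sizes(5, 4))    # [2, 2, 1]: only 3 of 4 GPUs receive data
print(per_device_sizes(1, 2))    # [1]: a batch of one keeps a single GPU busy
```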

3.3 Multi-GPU Training with DDP


import torch
import torch.nn as nn
from torch.utils.data import Dataset
from Dpex.dataloader import DpexDataLoader
from torch.utils.data.distributed import DistributedSampler

# start command: CUDA_VISIBLE_DEVICES=1,6,7 python -m torch.distributed.launch --nproc_per_node=2 pytorch_ddp.py
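# 1) initialize the process group (one process per GPU, NCCL backend)
torch.distributed.init_process_group(backend="nccl")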

input_size = 5
output_size = 2
batch_size = 1
data_size = 90000

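# 2) configure the GPU used by each process
local_rank = torch.distributed.get_rank()  # global rank; equals the local rank on a single node
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)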

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

dataset = RandomDataset(input_size, data_size)
# 3) use DistributedSampler so that each process loads a different shard of the dataset
rand_loader = DpexDataLoader(dataset=dataset, distribute_mode=True, batch_size=batch_size, sampler=DistributedSampler(dataset), num_workers=10)

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("  In Model: input size", input.size(),
              "output size", output.size())
        return output

model = Model(input_size, output_size)

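# 4) move the model to its GPU before wrapping it
model.to(device)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # 5) wrap the model with DDP
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)
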
for data in rand_loader:
    if torch.cuda.is_available():
        input_var = data.to(device)
    else:
        input_var = data

    output = model(input_var)
    print("Outside: input size", input_var.size(), "output_size", output.size())
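Each DDP process only sees the indices that `DistributedSampler` assigns to its rank. The sketch below reimplements that assignment in plain Python to show what `rand_loader` iterates over on each process:

1. The index list is padded, by wrapping around to the start, until it divides evenly by the number of processes.
2. Rank `r` then takes every `num_replicas`-th index starting at `r`.

This is a simplified illustration that omits shuffling. With `shuffle=True` (the default), the real sampler permutes indices using a generator seeded with `seed + epoch`, which is why `sampler.set_epoch(epoch)` should be called at the start of each epoch. `shard_indices` is a hypothetical name.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices DistributedSampler(shuffle=False, drop_last=False) yields on one rank."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)  # per-rank sample count
    total_size = num_samples * num_replicas
    # Pad by repeating indices from the start so every rank gets num_samples items.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Interleaved assignment: rank r takes positions r, r + R, r + 2R, ...
    return indices[rank:total_size:num_replicas]


# The launch command above starts 2 processes on a dataset of 90000 samples.
for rank in range(2):
    shard = shard_indices(90000, num_replicas=2, rank=rank)
    print(rank, len(shard), shard[:3])  # 0 45000 [0, 2, 4] / 1 45000 [1, 3, 5]
print(shard_indices(5, num_replicas=2, rank=1))  # [1, 3, 0]: index 0 is reused as padding
```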



  • The impact of DpexDataLoader on model training accuracy
  • The impact of DpexDataLoader on model training speed


4.1 Model Accuracy Benchmark


Accuracy (%)  Loss   GPU setting   distribute_mode  Epochs  Learning rate  Batch size
90.65         0.137  Single GPU    True             40      0.001          100
91.09         0.112  Single GPU    False            40      0.001          100
90.67         0.016  DataParallel  True             40      0.001          100
90.32         0.008  DataParallel  False            40      0.001          100
88.98         0.034  DDP           True             40      0.001          100
89.84         0.030  DDP           False            40      0.001          100
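To read the table, it helps to compare each GPU setting with and without distributed preprocessing. The snippet below only does arithmetic on the accuracy values reported above; it adds no new measurements. The differences (`distribute_mode=True` minus `False`) come out to:

- Single GPU: -0.44 percentage points
- DataParallel: +0.35 percentage points
- DDP: -0.86 percentage points

All three are within one percentage point after 40 epochs.

```python
# Accuracy (%) from the table above: (distribute_mode=True, distribute_mode=False)
accuracy = {
    "Single GPU": (90.65, 91.09),
    "DataParallel": (90.67, 90.32),
    "DDP": (88.98, 89.84),
}

# Difference in percentage points, rounded to the table's precision.
deltas = {setting: round(on - off, 2) for setting, (on, off) in accuracy.items()}
for setting, delta in deltas.items():
    print(f"{setting:<12} {delta:+.2f} pp")
```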

4.2 Training Speed Benchmark



