First of all, thank you for this library!
Description of the bug
When training with Distributed Data Parallel (DDP), the gradients are not correctly synchronized across devices when using RiemannianSGD (or RiemannianAdam). Replacing it with a standard torch.optim.SGD works as expected. Note that with DDP the gradients are synchronized during .backward()
(see this link).
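A minimal sketch of a helper one could use to verify whether the gradients actually agree across ranks after .backward() (this is a hypothetical helper, not part of geoopt or of the script below; it assumes the process group is already initialized):

# Hypothetical check: compare each local gradient against the cross-rank
# average obtained with an explicit all_reduce.
import torch
import torch.distributed as dist

def grads_in_sync(model, atol=1e-7):
    for p in model.parameters():
        if p.grad is None:
            continue
        avg = p.grad.detach().clone()
        dist.all_reduce(avg, op=dist.ReduceOp.SUM)
        avg /= dist.get_world_size()
        if not torch.allclose(p.grad, avg, atol=atol):
            return False
    return True

With torch.optim.SGD this returns True on every rank after each backward pass; with RiemannianSGD it does not (see the outputs below).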
To Reproduce
A simple training script on ImageNet:
import os

import geoopt
import torch
import torch.distributed
import torch.multiprocessing as mp
import torchvision
import torchvision.models as models
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils import data
from torchvision import transforms


def process_ddp(master_port, local_rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(master_port)
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group("nccl", rank=local_rank, world_size=world_size, init_method='env://')
    assert world_size == torch.distributed.get_world_size()
    return device


def main(local_rank, world_size):
    path_dataset = '/path/to/ImageNet'  # Any other dataset should result in a similar behavior
    master_port = 9999
    device = process_ddp(master_port, local_rank, world_size)

    model = models.resnet18()
    model = model.to(device)
    # optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=0.1, stabilize=10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Data parallelization
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # Prepare dataset
    transform = transforms.Compose([
        transforms.CenterCrop(size=256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = torchvision.datasets.ImageNet(split='train', root=path_dataset, transform=transform)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False, num_workers=8)

    # Train part of the first epoch
    model.train()
    for idx, (images, labels) in enumerate(data_loader):
        if idx >= 10:
            break
        images = images.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(True):
            features = model(images)
            loss = torch.nn.functional.cross_entropy(features, labels)
        loss.backward()
        print(f'grad iteration {idx} on gpu {device}: {model.module.conv1.weight.grad.mean()}', flush=True)
        optimizer.step()
        optimizer.zero_grad()
        print(f'weight iteration {idx} on gpu {device}: {model.module.conv1.weight.mean()}', flush=True)

    # cleanup
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    world_size_main = torch.cuda.device_count()
    mp.spawn(main,
             args=(world_size_main,),
             nprocs=world_size_main,
             join=True)
To run it on two GPUs, use:
CUDA_VISIBLE_DEVICES=0,1 python run.py
Expected behavior
The expected behavior is the one obtained when the line optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
is uncommented and the line optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=0.1, stabilize=10)
is commented out. In that case, the output is:
grad iteration 0 on gpu cuda:1: 0.011769304051995277
grad iteration 0 on gpu cuda:0: 0.011769304051995277
weight iteration 0 on gpu cuda:1: -0.001249525579623878
weight iteration 0 on gpu cuda:0: -0.001249525579623878
grad iteration 1 on gpu cuda:0: -0.015764284878969193
grad iteration 1 on gpu cuda:1: -0.015764284878969193
weight iteration 1 on gpu cuda:1: 0.0003269027511123568
weight iteration 1 on gpu cuda:0: 0.0003269027511123568
grad iteration 2 on gpu cuda:1: -0.006310341879725456
grad iteration 2 on gpu cuda:0: -0.006310341879725456
weight iteration 2 on gpu cuda:1: 0.000957937038037926
weight iteration 2 on gpu cuda:0: 0.000957937038037926
grad iteration 3 on gpu cuda:0: 0.0021547293290495872
grad iteration 3 on gpu cuda:1: 0.0021547293290495872
weight iteration 3 on gpu cuda:1: 0.000742464151699096
weight iteration 3 on gpu cuda:0: 0.000742464151699096
grad iteration 4 on gpu cuda:1: -0.002606849418953061
grad iteration 4 on gpu cuda:0: -0.002606849418953061
weight iteration 4 on gpu cuda:1: 0.001003148965537548
weight iteration 4 on gpu cuda:0: 0.001003148965537548
grad iteration 5 on gpu cuda:1: 0.00043087091762572527
grad iteration 5 on gpu cuda:0: 0.00043087091762572527
weight iteration 5 on gpu cuda:1: 0.0009600619086995721
weight iteration 5 on gpu cuda:0: 0.0009600619086995721
grad iteration 6 on gpu cuda:0: 0.00014396056940313429
grad iteration 6 on gpu cuda:1: 0.00014396056940313429
weight iteration 6 on gpu cuda:1: 0.0009456658735871315
weight iteration 6 on gpu cuda:0: 0.0009456658735871315
grad iteration 7 on gpu cuda:1: -0.002603260101750493
grad iteration 7 on gpu cuda:0: -0.002603260101750493
weight iteration 7 on gpu cuda:1: 0.001205991953611374
weight iteration 7 on gpu cuda:0: 0.001205991953611374
grad iteration 8 on gpu cuda:0: 0.000458348571555689
grad iteration 8 on gpu cuda:1: 0.000458348571555689
weight iteration 8 on gpu cuda:0: 0.0011601571459323168
weight iteration 8 on gpu cuda:1: 0.0011601571459323168
grad iteration 9 on gpu cuda:1: -0.0004215179360471666
grad iteration 9 on gpu cuda:0: -0.0004215179360471666
weight iteration 9 on gpu cuda:1: 0.0012023089220747352
weight iteration 9 on gpu cuda:0: 0.0012023089220747352
The gradients on the two GPUs are correctly synchronized. However, when using RiemannianSGD, the output is:
grad iteration 0 on gpu cuda:1: 0.0035285688936710358
grad iteration 0 on gpu cuda:0: 0.0035285688936710358
weight iteration 0 on gpu cuda:0: -7.928906597953755e-06
weight iteration 0 on gpu cuda:1: -7.928906597953755e-06
grad iteration 1 on gpu cuda:0: -0.04444637894630432
grad iteration 1 on gpu cuda:1: 0.002550020581111312
weight iteration 1 on gpu cuda:0: 0.00018905977776739746
weight iteration 1 on gpu cuda:1: -2.1470928913913667e-05
grad iteration 2 on gpu cuda:0: 0.009863540530204773
grad iteration 2 on gpu cuda:1: 0.0026304360944777727
weight iteration 2 on gpu cuda:0: 0.00013691956701222807
weight iteration 2 on gpu cuda:1: -3.7374120438471437e-05
grad iteration 3 on gpu cuda:0: -0.0161017756909132
grad iteration 3 on gpu cuda:1: -0.0023103044368326664
weight iteration 3 on gpu cuda:0: 0.0002405263076070696
weight iteration 3 on gpu cuda:1: -2.4383840354857966e-05
grad iteration 4 on gpu cuda:0: 0.010763526894152164
grad iteration 4 on gpu cuda:1: -0.017034146934747696
weight iteration 4 on gpu cuda:0: 0.00017218326684087515
weight iteration 4 on gpu cuda:1: 7.665574958082289e-05
grad iteration 5 on gpu cuda:0: 0.008465449325740337
weight iteration 5 on gpu cuda:0: 0.00012061965389875695
grad iteration 5 on gpu cuda:1: 0.0011690922547131777
weight iteration 5 on gpu cuda:1: 6.942617619642988e-05
grad iteration 6 on gpu cuda:0: 0.0013559082290157676
weight iteration 6 on gpu cuda:0: 0.00011242596519878134
grad iteration 6 on gpu cuda:1: 0.0008932517375797033
weight iteration 6 on gpu cuda:1: 6.43157254671678e-05
grad iteration 7 on gpu cuda:0: 0.02651313878595829
weight iteration 7 on gpu cuda:0: -1.233588955074083e-05
grad iteration 7 on gpu cuda:1: -0.007853103801608086
weight iteration 7 on gpu cuda:1: 0.00010782096069306135
grad iteration 8 on gpu cuda:0: 0.009321866557002068
weight iteration 8 on gpu cuda:0: -7.130965968826786e-05
grad iteration 8 on gpu cuda:1: -0.0039948648773133755
grad iteration 9 on gpu cuda:0: -0.0119229881092906
weight iteration 8 on gpu cuda:1: 0.00013168319128453732
weight iteration 9 on gpu cuda:0: -1.202533780997328e-06
grad iteration 9 on gpu cuda:1: 0.002446404891088605
weight iteration 9 on gpu cuda:1: 0.00011651107342913747
There is some problem with the gradient synchronization: from iteration 1 onwards the gradients on cuda:0 and cuda:1 differ, which causes the weights on the two devices to diverge.
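As a possible temporary workaround (just a sketch under the assumption that nothing else touches the gradients, not geoopt's intended fix), the gradients can be re-averaged across ranks manually after loss.backward() and before optimizer.step():

# Hypothetical workaround sketch: force the gradients back into sync before the
# Riemannian update. Assumes torch.distributed is already initialized.
import torch.distributed as dist

def average_gradients(model):
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)

Calling average_gradients(model) right after loss.backward() in the loop above keeps the two replicas identical, but of course it duplicates the reduction that DDP is supposed to do during the backward pass.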
Library version information:
- python -c 'import torch;print("torch:", torch.version.__version__, end=" ");print("cuda:", torch.version.cuda)': torch: 1.8.1 cuda: 11.1
- the way you installed geoopt (github, pip): pip
- OS: Ubuntu 18.04.5 LTS
EDIT: I simplified the code a bit by removing mixed precision.