Hi, first of all, thanks to the author for such excellent work. I now want to run some experiments on top of it.
However, after rewriting the Dataset class, I hit a baffling exception:
Traceback (most recent call last):
File "run_train.py", line 377, in main
torch.multiprocessing.spawn(fn=subprocess_fn, args=(args,), nprocs=args.num_gpus)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
while not context.join():
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 118, in join
raise Exception(msg)
Exception:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
fn(i, *args)
File "/home/notebook/code/personal/80299039/MoFaStyleNeRF/run_train.py", line 301, in subprocess_fn
training_loop.training_loop(**args)
File "/home/notebook/code/personal/80299039/MoFaStyleNeRF/training/training_loop.py", line 150, in training_loop
dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//world_size, **data_loader_kwargs))
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 352, in __iter__
return self._get_iterator()
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 294, in _get_iterator
return _MultiProcessingDataLoaderIter(self)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 801, in __init__
w.start()
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/process.py", line 112, in start
self._popen = self._Popen(self)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/context.py", line 284, in _Popen
return Popen(process_obj)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/popen_spawn_posix.py", line 32, in __init__
super().__init__(process_obj)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/popen_fork.py", line 20, in __init__
self._launch(process_obj)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/popen_spawn_posix.py", line 47, in _launch
reduction.dump(process_obj, fp)
File "/home/notebook/code/personal/80299039/conda/envs/StyleNeRF/lib/python3.7/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
TypeError: cannot serialize '_io.BufferedReader' object
Note that I can run your original code without a hitch on either a single GPU or 8 GPUs, so I suspect the problem comes from the Dataset class I rewrote. From the traceback, the error is raised while the spawned DataLoader worker pickles the dataset object, so it looks like an open file handle is being carried along. Here is the Dataset I have rewritten:
class ImageParamFolderDataset(Dataset):
    def __init__(self,
        path,                   # Image path to directory or zip.
        param_path,             # Param path to directory or zip.
        resolution = None,      # Ensure specific resolution, None = highest available.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._param_path = param_path
        self._zipfile = None
        self._param_zipfile = None
        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
            self._all_pnames = {os.path.relpath(os.path.join(root, fname), start=self._param_path) for root, _dirs, files in os.walk(self._param_path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
            self._all_pnames = set(self._get_param_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        self._param_fnames = sorted(pname for pname in self._all_pnames if self._file_ext(pname) == '.mat')
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')
        if len(self._param_fnames) == 0:
            raise IOError('No param files found in the specified path')
        if len(self._image_fnames) != len(self._param_fnames):
            raise IOError('Num of image files and num of param files are not equal')

        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image_param(0)[0].shape)
        if resolution is not None:
            raw_shape[2] = raw_shape[3] = resolution
        # if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
        #     raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _get_param_zipfile(self):
        assert self._type == 'zip'
        if self._param_zipfile is None:
            self._param_zipfile = zipfile.ZipFile(self._param_path)
        return self._param_zipfile

    def _open_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def _open_param_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._param_path, fname), 'rb')
        if self._type == 'zip':
            return self._get_param_zipfile().open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
            if self._param_zipfile is not None:
                self._param_zipfile.close()
        finally:
            self._zipfile = None
            self._param_zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def __getitem__(self, idx):
        image, param = self._load_raw_image_param(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), param, self.get_label(idx), idx

    def _load_raw_image_param(self, raw_idx):
        fname = self._image_fnames[raw_idx]
        pname = self._param_fnames[raw_idx]
        assert os.path.splitext(fname)[0] == os.path.splitext(pname)[0], 'Path of image and param must be the same'
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        with self._open_param_file(pname) as f:
            param_dict = sio.loadmat(f)
            param = self._process_param_dict(param_dict)
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        if hasattr(self, '_raw_shape') and image.shape[0] != self.resolution: # resize input image
            image = cv2.resize(image, (self.resolution, self.resolution), interpolation=cv2.INTER_AREA)
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image, param

    def _process_param_dict(self, param_dict):
        id = param_dict['id']; exp = param_dict['exp']
        tex = param_dict['tex']; gamma = param_dict['gamma']
        angle = param_dict['angle']; trans = param_dict['trans']
        return np.concatenate((id, exp, tex, gamma, angle, trans), axis=None)

    def _load_raw_labels(self):
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

    def get_dali_dataloader(self, batch_size, world_size, rank, gpu): # TODO
        from nvidia.dali import pipeline_def, Pipeline
        import nvidia.dali.fn as fn
        import nvidia.dali.types as types
        from nvidia.dali.plugin.pytorch import DALIGenericIterator

        @pipeline_def
        def pipeline():
            jpegs, _ = fn.readers.file(
                file_root=self._path,
                files=list(self._all_fnames),
                random_shuffle=True,
                shard_id=rank,
                num_shards=world_size,
                name='reader')
            images = fn.decoders.image(jpegs, device='mixed')
            mirror = fn.random.coin_flip(probability=0.5) if self.xflip else False
            images = fn.crop_mirror_normalize(
                images.gpu(), output_layout="CHW", dtype=types.UINT8, mirror=mirror)
            labels = np.zeros([1, 0], dtype=np.float32)
            return images, labels

        dali_pipe = pipeline(batch_size=batch_size//world_size, num_threads=2, device_id=gpu)
        dali_pipe.build()
        training_set_iterator = DALIGenericIterator([dali_pipe], ['img', 'label'])
        for data in training_set_iterator:
            yield data[0]['img'], data[0]['label']
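One guess (I'm not sure this is the actual cause): my __getstate__ only drops _zipfile but not _param_zipfile, so a param ZipFile that was already opened in __init__ might still be attached when the dataset object gets pickled into the spawned worker processes. A minimal sketch of the change I have in mind, assuming the base Dataset's __getstate__ behaves as in the original code:

    def __getstate__(self):
        # Also drop the cached param ZipFile so no open file objects
        # (with their underlying _io.BufferedReader) get pickled into the workers.
        return dict(super().__getstate__(), _zipfile=None, _param_zipfile=None)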
Could you please give me some suggestions on how to fix this? Thanks in advance!