@agesmundo, I hope to receive a reply, thanks.
-
Possible duplicates #788 and #4528 do not apply to this case.
-
How to reproduce the bug:
[1] Run mu2Net on 8 A100 GPUs with BENCHMARK = 'ViT large / Chars benchmark'.
[2] An OOM error occurs when the train_step function compiled by jax.jit is executed.
[3] Each of the 8 A100s has 80 GiB of GPU memory, which should be sufficient; the host has 256 GB of RAM and 112+ CPU cores.
[4] I can't understand why the executable needs a 114.44 GiB temp allocation when the seed ViT model has only ~300M parameters.
[5] Setting the JAX memory-allocation environment variables makes no difference (see the snippet below).
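For reference, these are the standard allocator-related variables (my listing, not an exhaustive record of every combination tried), set before importing jax; none of them avoided the OOM, presumably because the compiled executable genuinely requests ~115 GiB rather than merely preallocating it:

# Example of the allocator settings tried before importing jax (listing is illustrative).
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'   # disable the default preallocation
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'   # or raise the preallocated fraction
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'  # allocate/free on demand
import jax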
-
Model hyperparameters:
def get_exp_config_large(benchmark_string_id):
  exp_config = ConfigDict()
  exp_config.experiment_name = EXPERIMENT_NAME
  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR
  # Cap to 1/10th of imagenet train set size to have similar ratio of exps reported in:
  # https://arxiv.org/abs/2106.10270
  exp_config.num_train_examples_between_validations_max = 128_116
  exp_config.num_validations_per_path_training = 4
  exp_config.num_validation_examples_max = 10_000
  # Fit HBM memory: TPUv4 megacore=64, TPUv3=32.
  exp_config.batch_size = 64
  exp_config.num_task_iters = 1
  # Assuming TPUv4 32 cores * 4 generations.
  exp_config.num_samples_per_task = 32 * 4
  exp_config.mutate_adapters = False
  exp_config.force_finetune_components = ['encoder_norm']
  # Population policy params:
  exp_config.policy_class = 'PPDecay'
  exp_config.policy_kwargs = {}
  # Scorer params:
  exp_config.scorer_class = 'ScorerDecay'
  exp_config.scorer_kwargs = dict(
      base=1.0,
      num_params=303_303_682,  # Params in L/16
      )
  # Seed models params:
  exp_config.load_rand_init = False
  exp_config.load_vit_checkpoint = True
  exp_config.load_vit_checkpoint_query = 'name=="L/16" and ds=="i21k" and aug=="medium2" and wd==0.03 and sd==0.1'
  exp_config.load_experiment = False
  exp_config.load_experiment_dir = ''
  set_continue_configs(exp_config)
  # Hyperparameters:
  max_num_layers = get_max_num_layers(exp_config.load_vit_checkpoint_query)
  exp_config.models_default_hparams = {
      '_mu_': 0.2,
      'num_classes': 1,
      'adapter_layers': '',
      'num_layers': max_num_layers,
      'adapter_dim': 16,
      'opt_lr': 0.01,
      'opt_lr_schedule': 'cosine',
      'opt_lr_warmup_ratio': 0.05,
      'opt_momentum': 0.9,
      'opt_nesterov': False,
      'ds_image_size': 384,
      'ds_crop': True,
      'ds_area_range_min': 0.05,
      'ds_aspect_ratio_range_min': 0.75,
      'ds_flip_left_right': True,
      'ds_brightness_delta': 0.0,
      'ds_contrast_delta': 0.0,
      'ds_saturation_delta': 0.0,
      'ds_hue_delta': 0.0,
      }
  return exp_config
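The two values above that determine the size of the attention activations in the OOM dump below are exp_config.batch_size (64) and models_default_hparams['ds_image_size'] (384, i.e. 577 tokens for L/16). A quick scaling check, sketched here under the assumption that get_exp_config_large is called with the notebook's BENCHMARK string, would be to shrink them and confirm the temp allocation shrinks accordingly:

# Sketch of a scaling check, not part of the original notebook.
exp_config = get_exp_config_large(BENCHMARK)
exp_config.batch_size = 32                    # halves every f32[B, 16, 577, 577] buffer
# More speculative: a smaller input resolution shrinks the token count quadratically
# (224/16 -> 14*14 + 1 = 197 tokens), if the checkpoint loader resizes position embeddings.
exp_config.models_default_hparams['ds_image_size'] = 224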
@partial(jax.jit, static_argnames=['model', 'optimizer'], donate_argnums=[0, 2])
def train_step(params, fixed_params, opt_state, images, labels, model, optimizer):
  def loss_fn(params, fixed_params, images, labels):
    logits = model.apply({'params': format_params(params, fixed_params)},
                         images, train=USE_DROPOUT)
    labels = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))
  grads = jax.grad(loss_fn)(params, fixed_params, images, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params=params)
  params = optax.apply_updates(params, updates)
  return params, opt_state
def train_loop(paths, ds_train, ds_validation, devices, exp_config):
  global LOOP_START
  timing = {'start_time': time.time(),
            'start_time_loop': LOOP_START}
  task = paths[0].task
  # The following values should be shared by all paths in this generation batch.
  for path in paths:
    assert task == path.task
    assert paths[0].hparams['ds_image_size'] == path.hparams['ds_image_size']
  gc.collect()
  # Compile.
  compile_train_batches_arr = jax.device_put_replicated(
      get_sample_batch(
          paths[0].hparams['ds_image_size'],
          task.train_batch_size),
      devices)
  compile_eval_batches_arr = jax.device_put_replicated(
      get_sample_batch(
          paths[0].hparams['ds_image_size'],
          task.validation_batch_size),
      devices)
  for p_id, path in enumerate(paths):
    if VERBOSE:
      print('Parent')
      print(prp(path.parent))
      print(prp(path))
    path.device_id = p_id % len(devices)
    path.device = devices[path.device_id]
    print("path:", p_id, "device:", path.device)
    path.optimizer = path.get_optimizer()
    path.optimizer_init_fn = jax.jit(path.optimizer.init, device=path.device)
    path.best_params_local = None
    path.best_opt_state_local = None
    path.best_quality = None
    path.best_score = path.parent.score() if path.task is path.parent.task else -np.inf
    path.evals = []
    # Launch parallel compilation of eval and train step functions.
    params_local = path.get_trainable_params()
    check_is_local(params_local)
    path.compile_params_device = jax.device_put(params_local, path.device)
    path.compile_fixed_params_device = jax.device_put(
        path.get_fixed_params(),
        path.device)
    path.compile_train = Thread(
        target=train_step,
        args=(path.compile_params_device,
              path.compile_fixed_params_device,
              path.optimizer_init_fn(params_local),
              compile_train_batches_arr['image'][path.device_id],
              compile_train_batches_arr['label'][path.device_id],
              path.model,
              path.optimizer))
    path.compile_eval = Thread(
        target=eval_step,
        args=(format_params(
                  path.compile_params_device,
                  path.compile_fixed_params_device),
              compile_eval_batches_arr['image'][path.device_id],
              compile_eval_batches_arr['label'][path.device_id],
              path.model))
    path.compile_eval.start()
  for path in paths:
    path.compile_eval.join()
    del path.compile_eval
    timing['end_compile_eval'] = time.time()
    path.compile_train.start()
  del compile_eval_batches_arr
  for path in paths:
    path.compile_train.join()
    del path.compile_train
    del path.compile_params_device
    del path.compile_fixed_params_device
  timing['end_compile'] = time.time()
  del compile_train_batches_arr
  gc.collect()
  # Parameter transfer.
  for path in paths:
    path.params_device = jax.device_put(
        path.get_trainable_params(),
        path.device)
    path.fixed_params_device = jax.device_put(
        path.get_fixed_params(),
        path.device)
    path.opt_state_device = path.optimizer_init_fn(path.params_device)
    # Set opt state.
    for c in path.components:
      if c.is_trainable():
        assert c.name in path.opt_state_device[1][0].trace.keys()
        if c.opt_state is not None:
          path.opt_state_device = (
              path.opt_state_device[0],
              (optax.TraceState(
                  trace=path.opt_state_device[1][0].trace.copy(
                      {c.name: jax.device_put(c.opt_state,
                                              path.device)})),
               path.opt_state_device[1][1]
              )
          )
    check_is_on_device(path.opt_state_device, path.device)
  iter_ds_validation = iter(ds_validation)
  # TRAIN
  for t_step, train_batch in zip(
      range(exp_config.num_validations_per_path_training
            * task.num_train_batches_between_validations),
      ds_train,
      ):
    train_batch_arr = jax.device_put_replicated(train_batch, devices)
    for p_id, path in enumerate(paths):
      if t_step == 0:
        timing['end_prep'] = time.time()
        t_step_0_time = time.time()
      train_step_start = time.time()
      path.params_device, path.opt_state_device = train_step(
          path.params_device,
          path.fixed_params_device,
          path.opt_state_device,
          train_batch_arr['image'][path.device_id],
          train_batch_arr['label'][path.device_id],
          path.model,
          path.optimizer)
      if t_step == 0 and time.time() - t_step_0_time > 1:
        print(f'WARNING: First train step took: {time.time()-t_step_0_time:.2f} s')
    del train_batch, train_batch_arr
  # EVAL
  # ...
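The exception below is raised from one of the compile threads started above while running train_step (Thread-14). A roughly equivalent single-device, single-path repro, sketched here by reusing the notebook's own objects (path, task, get_sample_batch, train_step), is:

# Sketch of a minimal single-device repro, not code from the notebook.
device = jax.devices()[0]
batch = get_sample_batch(path.hparams['ds_image_size'], task.train_batch_size)
params = jax.device_put(path.get_trainable_params(), device)
fixed_params = jax.device_put(path.get_fixed_params(), device)
opt_state = jax.jit(path.optimizer.init, device=device)(params)
params, opt_state = train_step(params, fixed_params, opt_state,
                               batch['image'], batch['label'],
                               path.model, path.optimizer)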
- Full error messages/tracebacks:
Exception in thread Thread-14:
Traceback (most recent call last):
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/threading.py", line 890, in _bootstrap
self._bootstrap_inner()
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/threading.py", line 926, in _bootstrap_inner
self.run()
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/threading.py", line 870, in run
self._target(*self._args, **self._kwargs)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/_src/api.py", line 476, in cache_miss
donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/core.py", line 678, in process_call
return primitive.impl(f, *tracers, **params)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/_src/dispatch.py", line 185, in _xla_call_impl
return compiled_fun(*args)
File "/mnt/lustre/liujun1/.conda/envs/muNet/lib/python3.7/site-packages/jax/_src/dispatch.py", line 615, in _execute_compiled
out_bufs_flat = compiled.execute(input_bufs_flat)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 122875791936 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 1.43GiB
constant allocation: 8B
maybe_live_out allocation: 581.19MiB
preallocated temp allocation: 114.44GiB
preallocated temp fragmentation: 146.50MiB (0.13%)
total allocation: 115.86GiB
total fragmentation: 146.53MiB (0.12%)
Peak buffers:
Buffer 1:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
Buffer 2:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
Buffer 3:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
Buffer 4:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
Buffer 5:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
Buffer 6:
Size: 1.27GiB
XLA Label: custom-call
Shape: f32[64,16,577,577]
==========================
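For what it's worth, the reported peak buffers match the shape of the per-layer attention logits for this configuration, and a rough back-of-envelope estimate (my own arithmetic, not something the XLA dump reports) suggests the temp allocation is dominated by activations saved for the backward pass:

# Each peak buffer matches the attention logits of one ViT L/16 layer at 384px, batch 64.
batch, heads = 64, 16
tokens = (384 // 16) ** 2 + 1                       # 24*24 patches + [CLS] = 577
buffer_bytes = batch * heads * tokens * tokens * 4  # f32[64, 16, 577, 577]
print(buffer_bytes / 2**30)                         # ~1.27 GiB, matches Buffers 1-6 above
# L/16 has 24 encoder layers; if ~2-3 such buffers per layer stay live for the backward
# pass, attention logits alone account for ~60-90 GiB, so a ~114 GiB temp allocation at
# batch 64 / image 384 looks plausible rather than like a leak.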
-
environment:
python: 3.7.13
jax: 0.3.14 (also tried 0.3.13)
jaxlib: 0.3.14+cuda11.cudnn82 (also tried 0.3.10+cuda11.cudnn82)
-
python packages:
absl-py 1.1.0
aqtp 0.0.7
astunparse 1.6.3
cachetools 5.2.0
certifi 2022.6.15
charset-normalizer 2.0.12
chex 0.1.3
cloudpickle 2.1.0
clu 0.0.3
colorama 0.4.5
commonmark 0.9.1
contextlib2 21.6.0
cycler 0.11.0
dacite 1.6.0
decorator 5.1.1
dill 0.3.5.1
dm-tree 0.1.7
einops 0.3.0
etils 0.6.0
flatbuffers 1.12
flax 0.5.2
flaxformer 0.4.2
fonttools 4.33.3
gast 0.4.0
google-auth 2.8.0
google-auth-oauthlib 0.4.6
google-pasta 0.2.0
googleapis-common-protos 1.56.3
grpcio 1.47.0
h5py 3.7.0
idna 3.3
importlib-metadata 4.12.0
importlib-resources 5.8.0
jax 0.3.14
jaxlib 0.3.14+cuda11.cudnn82
keras 2.9.0
Keras-Preprocessing 1.1.2
kiwisolver 1.4.3
libclang 14.0.1
Markdown 3.3.7
matplotlib 3.5.2
ml-collections 0.1.1
msgpack 1.0.4
numpy 1.21.6
oauthlib 3.2.0
opt-einsum 3.3.0
optax 0.1.2
packaging 21.3
pandas 1.3.5
Pillow 9.1.1
pip 21.2.2
promise 2.3
protobuf 3.19.4
pyasn1 0.4.8
pyasn1-modules 0.2.8
Pygments 2.12.0
pyparsing 3.0.9
python-dateutil 2.8.2
pytz 2022.1
PyYAML 6.0
requests 2.28.0
requests-oauthlib 1.3.1
rich 11.2.0
rsa 4.8
scipy 1.7.3
setuptools 61.2.0
six 1.16.0
tensorboard 2.9.1
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow-addons 0.17.1
tensorflow-cpu 2.9.1
tensorflow-datasets 4.6.0
tensorflow-estimator 2.9.0
tensorflow-hub 0.12.0
tensorflow-io-gcs-filesystem 0.26.0
tensorflow-metadata 1.9.0
tensorflow-probability 0.17.0
tensorflow-text 2.9.0
termcolor 1.1.0
toml 0.10.2
toolz 0.11.2
tqdm 4.64.0
typeguard 2.13.3
typing_extensions 4.2.0
urllib3 1.26.9
Werkzeug 2.1.2
wheel 0.37.1
wrapt 1.14.1
zipp 3.8.0