We run out of GPU memory when we run too many SCATrEx tree inferences in the same Python session. The error only appears when using a GPU, and it is raised on the call to do_grad, which is a jitted function.
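For context, a minimal sketch of the usage pattern that triggers it. The construction of the SCATrEx object is an assumption here (shown as a hypothetical make_scatrex_model helper), and the n_iters value is made up; the learn_tree call and the random_seed/verbosity entries match what appears in the trace below.

import logging

search_kwargs = {'n_iters': 100,        # assumed value; any search settings apply
                 'random_seed': 1,
                 'verbosity': logging.DEBUG}

for run in range(20):                   # enough repetitions to exhaust GPU memory
    sca = make_scatrex_model(run)       # hypothetical helper that builds a SCATrEx object
    sca.learn_tree(reset=True, batch_key='batch', search_kwargs=search_kwargs)
    # The SCATrEx object goes out of scope after each iteration, but GPU memory
    # does not seem to be released, and eventually do_grad fails to compile.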
Full error trace:
RuntimeError Traceback (most recent call last)
<ipython-input-216-9b90a41ab615> in <module>
47 'random_seed': 1,
48 'verbosity': logging.DEBUG}
---> 49 sca.learn_tree(reset=True, batch_key='batch', search_kwargs=search_kwargs)
/cluster/work/bewi/members/pedrof/miniconda3/envs/py38/lib/python3.8/site-packages/scatrex/scatrex.py in learn_tree(self, observed_tree, reset, cell_filter, filter_genes, max_genes, batch_key, search_kwargs)
289 logger.info("Will continue search from where it left off.")
290
--> 291 self.ntssb = self.search.run_search(**search_kwargs)
292 self.ntssb.create_augmented_tree_dict()
293
/cluster/work/bewi/members/pedrof/miniconda3/envs/py38/lib/python3.8/site-packages/scatrex/ntssb/search.py in run_search(self, n_iters, n_iters_elbo, factor_delay, posterior_delay, global_delay, joint_init, thin, local, num_samples, step_size, verbosity, tol, mb_size, max_nodes, debug, callback, alpha, Tmax, anneal, restart_step, move_weights, weighted, merge_n_tries, opt, search_callback, add_rule, add_rule_thres, random_seed, **callback_kwargs)
290 "log_baseline_mean"
291 ] = init_log_baseline
--> 292 self.tree.optimize_elbo(
293 root_node=None,
294 sticks_only=True,
/cluster/work/bewi/members/pedrof/miniconda3/envs/py38/lib/python3.8/site-packages/scatrex/ntssb/ntssb.py in optimize_elbo(self, root_node, local_node, global_only, sticks_only, unique_node, num_samples, n_iters, thin, step_size, debug, tol, run, max_nodes, init, opt, opt_triplet, mb_size, callback, **callback_kwargs)
1543 # data_mask_subset = data_mask
1544 # start = time.time()
-> 1545 opt_state, g, params, elbo = self.update(
1546 obs_params,
1547 parent_vector,
/cluster/work/bewi/members/pedrof/miniconda3/envs/py38/lib/python3.8/site-packages/scatrex/ntssb/ntssb.py in update(self, obs_params, parent_vector, children_vector, ancestor_nodes_indices, tssb_indices, previous_branches_indices, tssb_weights, dp_alphas, dp_gammas, node_mask, data_mask_subset, indices, do_global, global_only, sticks_only, num_samples, i, opt_state, opt_update, get_params)
1259 # print("Recompiling update!")
1260 params = get_params(opt_state)
-> 1261 value, gradient = self.do_grad(
1262 obs_params,
1263 parent_vector,
[... skipping hidden 8 frame]
RuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
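Not part of the trace, but for reference, a sketch of JAX/XLA GPU memory settings that could be tried as a workaround. These environment variables are documented by JAX; whether they avoid the CUBIN load failure in this particular case is an assumption.

import os

# Keep JAX from preallocating most of the GPU memory per process;
# allocations then grow on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: use the platform allocator, which frees device memory as soon
# as buffers are deleted (slower, but avoids accumulation over a long session).
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # must be imported only after these variables are set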