Hey guys!
Really epic work in this repo! I'm currently working on integrating this into Lightning (any assistance would be appreciated). From what I can see, ORTModule just wraps the forward function and converts it to ONNX format — is that right? So, internally in Lightning, I've wrapped the model so that the user-defined functions (`training_step`, `validation_step`, `test_step`) are run inside the wrapped module's forward function.
Currently I'm running into an error:
/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_io.py:473: UserWarning: This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX. Compute will continue, but unexpected results may occur!
warnings.warn("This model cannot be deep copied (or pickled), "
2021-07-19 13:20:57.590381944 [E:onnxruntime:, inference_session.cc:1341 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/framework/session_state_utils.cc:143 onnxruntime::common::Status onnxruntime::session_state_utils::SaveInitializedTensors(const onnxruntime::Env&, const std::basic_string<char>&, const onnxruntime::GraphViewer&, const AllocatorPtr&, const onnxruntime::OrtValueNameIdxMap&, const std::vector<int>&, onnxruntime::ITensorAllocator&, const std::function<onnxruntime::common::Status(int, const OrtValue&, const onnxruntime::OrtCallback&, bool)>&, const onnxruntime::logging::Logger&, const onnxruntime::DataTransferManager&, const onnxruntime::ExecutionPlanBase&, const onnxruntime::SessionOptions&) ort_value_name_idx_map.MaxIdx() > -1 was false. OrtValue indexes should have been populated.
Traceback (most recent call last):
File "reproduce_test.py", line 99, in <module>
run()
File "reproduce_test.py", line 94, in run
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 515, in fit
self._run(model)
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 896, in _run
self._dispatch()
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 963, in _dispatch
self.accelerator.start_training(self)
File "/data/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 97, in start_training
self.training_type_plugin.start_training(trainer)
File "/data/pytorch-lightning/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 973, in run_stage
return self._run_train()
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1008, in _run_train
self._run_sanity_check(self.lightning_module)
File "/data/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1084, in _run_sanity_check
self._evaluation_loop.run()
File "/data/pytorch-lightning/pytorch_lightning/loops/base.py", line 112, in run
self.advance(*args, **kwargs)
File "/data/pytorch-lightning/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 122, in advance
self.num_dataloaders,
File "/data/pytorch-lightning/pytorch_lightning/loops/base.py", line 112, in run
self.advance(*args, **kwargs)
File "/data/pytorch-lightning/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self.evaluation_step(batch, batch_idx, dataloader_idx)
File "/data/pytorch-lightning/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 162, in evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/data/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 220, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "reproduce_test.py", line 74, in validation_step
return self.model(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/ortmodule.py", line 41, in _forward
return self._execution_manager(self._is_training()).forward(*inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_inference_manager.py", line 86, in forward
self._create_execution_agent()
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_inference_manager.py", line 115, in _create_execution_agent
session_options, providers, provider_options)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_execution_agent.py", line 52, in __init__
self.create_inference_agent(path_or_bytes, session_options, providers, provider_options)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_execution_agent.py", line 56, in create_inference_agent
providers, provider_options)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 283, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.6/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 321, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/framework/session_state_utils.cc:143 onnxruntime::common::Status onnxruntime::session_state_utils::SaveInitializedTensors(const onnxruntime::Env&, const std::basic_string<char>&, const onnxruntime::GraphViewer&, const AllocatorPtr&, const onnxruntime::OrtValueNameIdxMap&, const std::vector<int>&, onnxruntime::ITensorAllocator&, const std::function<onnxruntime::common::Status(int, const OrtValue&, const onnxruntime::OrtCallback&, bool)>&, const onnxruntime::logging::Logger&, const onnxruntime::DataTransferManager&, const onnxruntime::ExecutionPlanBase&, const onnxruntime::SessionOptions&) ort_value_name_idx_map.MaxIdx() > -1 was false. OrtValue indexes should have been populated.
With the following script (requires PyTorch Lightning: `pip install pytorch-lightning`):
import os
import pickle
import torch
from torch.utils.data import DataLoader, Dataset
from torch_ort import ORTModule
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins import SingleDevicePlugin
class RandomDataset(Dataset):
    """Map-style dataset holding ``length`` random vectors, each of width ``size``."""

    def __init__(self, size, length):
        # All samples are drawn once here; __getitem__ only indexes the stored tensor.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.data[index]
class BoringModel(LightningModule):
    """Minimal LightningModule: one Linear(32, 2) layer whose summed output serves as the loss."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        step_loss = torch.sum(self(batch))
        self.log("train_loss", step_loss)
        return {"loss": step_loss}

    def validation_step(self, batch, batch_idx):
        # NOTE(review): this step returns None. If it is executed as an ORTModule-wrapped
        # forward, the exported graph presumably has no outputs — possibly related to the
        # "OrtValue indexes should have been populated" error. TODO confirm.
        step_loss = torch.sum(self(batch))
        self.log("valid_loss", step_loss)

    def test_step(self, batch, batch_idx):
        # NOTE(review): also returns None; see the note on validation_step.
        step_loss = torch.sum(self(batch))
        self.log("test_loss", step_loss)

    def configure_optimizers(self):
        trainable = self.layer.parameters()
        return torch.optim.SGD(trainable, lr=0.1)
def unwrap_lightning_module(wrapped_model) -> LightningModule:
    """Peel ``LightningDistributedModule`` / ``ORTModule`` wrappers off a model.

    Each wrapper layer is unwrapped recursively, so nesting in either order
    (e.g. ``ORTModule(LightningDistributedModule(m))``) resolves to ``m``.

    Args:
        wrapped_model: A LightningModule, possibly wrapped in the two wrapper types.

    Returns:
        The innermost module once no recognized wrapper remains. Any other
        module type is returned unchanged.
    """
    model = wrapped_model
    if isinstance(model, LightningDistributedModule):
        model = unwrap_lightning_module(model.module)
    if isinstance(model, ORTModule):
        # NOTE(review): relies on ORTModule's private ``_module_metadata`` attribute;
        # this may break across onnxruntime-training versions.
        model = unwrap_lightning_module(model._module_metadata.original_module)
    return model
class ORTPlugin(SingleDevicePlugin):
    """Single-device plugin that routes the Lightning step methods through ORTModule.

    The plugin's model is wrapped as ``ORTModule(LightningDistributedModule(model))``.
    Each ``*_step`` hook below then simply calls that wrapped model. The
    LightningDistributedModule layer is assumed to dispatch ``forward`` to the active
    step method (TODO confirm against the installed Lightning version).
    """

    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
        """Wrap the model with ORTModule, move it to the target device, and return it.

        Args:
            model: The module handed over by the trainer. Kept for interface
                compatibility; see the note below.

        Returns:
            The ORTModule-wrapped model.
        """
        # NOTE(review): as in the original, ``self.model`` (not the ``model`` argument)
        # is wrapped. This assumes the base plugin has already set ``self.model`` to the
        # LightningModule before calling setup — confirm.
        self.model = ORTModule(LightningDistributedModule(self.model))
        self.model_to_device()
        return self.model

    @property
    def lightning_module(self):
        """The underlying LightningModule with all known wrappers removed."""
        return unwrap_lightning_module(self._model)

    def training_step(self, *args, **kwargs):
        """Run the training step through the ORT-wrapped model."""
        return self.model(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        """Run the validation step through the ORT-wrapped model."""
        return self.model(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        """Run the test step through the ORT-wrapped model."""
        return self.model(*args, **kwargs)
def run():
    """Fit BoringModel on random data via ORTPlugin on ``cuda:0``, then run the test loop."""

    def make_loader():
        # 64 random samples of width 32, served in batches of 2.
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    train_data = make_loader()
    val_data = make_loader()
    test_data = make_loader()

    model = BoringModel()
    plugin = ORTPlugin(device=torch.device('cuda:0'))
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        plugins=plugin,
        gpus=1,
    )

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == '__main__':
    run()
I'll continue to debug in the meantime :)