As the title says, I get the following errors when my module is wrapped with ORTModule:
[E custom_kernel.cpp:123] default_program(14): error: identifier "tensor" is undefined
1 error detected in the compilation of "default_program".
Failed to use NVRTC for JIT compilation in this Pytorch version, try another approach using CUDA compiler.. (To always disable NVRTC, please: export USE_NVRTC=0)
/tmp/torch-tutel-o0geuH.cu(14): error: identifier "tensor" is undefined
1 error detected in the compilation of "/tmp/torch-tutel-o0geuH.cu".
/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_training_manager.py:224: UserWarning: Fast path enabled - skipping checks. Rebuild graph: True, Execution agent: True, Device check: True
f" Device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}", UserWarning)
RuntimeError: There was an error while exporting the PyTorch model to ONNX:
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_utils.py", line 254, in get_exception_as_string
raise exception
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 389, in _get_exported_model
**self._export_extra_kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/__init__.py", line 280, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 695, in _export
dynamic_axes=dynamic_axes)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 459, in _model_to_graph
_retain_param_name)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 422, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 373, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 1160, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 132, in forward
self._force_outplace,
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_io.py", line 430, in forward
return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs))
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
result = self.forward(*input, **kwargs)
File "test_ort.py", line 17, in forward
x_cumsum = fast_cumsum_sub_one(x, dim=0)
File "/opt/conda/lib/python3.7/site-packages/tutel/jit_kernels/gating.py", line 83, in fast_cumsum_sub_one
return get_cumsum_kernel(data.size(0), data.size(1))(data)
File "/opt/conda/lib/python3.7/site-packages/tutel/jit_kernels/gating.py", line 72, in optimized_cumsum
base_kernel(mask1.to(torch.int32).contiguous(), locations1)
File "/opt/conda/lib/python3.7/site-packages/tutel/impls/jit_compiler.py", line 26, in func
tutel_custom_kernel.invoke(inputs, __ctx__)
RuntimeError: (true) == (fp != nullptr)INTERNAL ASSERT FAILED at "/tmp/pip-req-build-qjgbz25n/tutel/custom/custom_kernel.cpp":39, please report a bug to PyTorch. CHECK_EQ fails.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test_ort.py", line 24, in <module>
output = cumsum_module(input)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_utils.py", line 309, in _forward
return ortmodule._torch_module.forward(*inputs, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_utils.py", line 289, in _forward
torch_module_ort.is_training()).forward(*inputs, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 292, in forward
log_level=self._debug_options.logging.log_level)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 151, in handle_exception
raise exception
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 231, in forward
build_gradient_graph = self._export_model(*inputs, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 322, in _export_model
schema, *inputs, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 392, in _get_exported_model
RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: '
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_fallback_exceptions.py", line 72, in wrap_exception
raise new_exception(raised_exception) from raised_exception
onnxruntime.training.ortmodule._fallback_exceptions.ORTModuleONNXModelException: There was an error while exporting the PyTorch model to ONNX:
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_utils.py", line 254, in get_exception_as_string
raise exception
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 389, in _get_exported_model
**self._export_extra_kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/__init__.py", line 280, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 695, in _export
dynamic_axes=dynamic_axes)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 459, in _model_to_graph
_retain_param_name)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 422, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 373, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 1160, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 132, in forward
self._force_outplace,
File "/opt/conda/lib/python3.7/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/onnxruntime/training/ortmodule/_io.py", line 430, in forward
return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs))
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
result = self.forward(*input, **kwargs)
File "test_ort.py", line 17, in forward
x_cumsum = fast_cumsum_sub_one(x, dim=0)
File "/opt/conda/lib/python3.7/site-packages/tutel/jit_kernels/gating.py", line 83, in fast_cumsum_sub_one
return get_cumsum_kernel(data.size(0), data.size(1))(data)
File "/opt/conda/lib/python3.7/site-packages/tutel/jit_kernels/gating.py", line 72, in optimized_cumsum
base_kernel(mask1.to(torch.int32).contiguous(), locations1)
File "/opt/conda/lib/python3.7/site-packages/tutel/impls/jit_compiler.py", line 26, in func
tutel_custom_kernel.invoke(inputs, __ctx__)
RuntimeError: (true) == (fp != nullptr)INTERNAL ASSERT FAILED at "/tmp/pip-req-build-qjgbz25n/tutel/custom/custom_kernel.cpp":39, please report a bug to PyTorch. CHECK_EQ fails.
To reproduce the problem, please run the following code. Thanks.
# Repro setup: wrap a module with ORTModule and enable custom-autograd export
# so the tutel JIT kernel can be traced. Import order matters here:
# ortmodule settings must be applied before the module is wrapped.
from torch_ort import ORTModule
from onnxruntime.training import ortmodule
# Pin the ONNX opset used by ORTModule's exporter.
ortmodule.ONNX_OPSET_VERSION=12
# NOTE(review): private ortmodule API — presumably needed so torch.autograd.Function
# subclasses (used by tutel's kernels) survive ONNX export; confirm against ORT docs.
from onnxruntime.training.ortmodule._custom_autograd_function import enable_custom_autograd_support
enable_custom_autograd_support()
# tutel's fused cumsum kernel — the call that triggers the NVRTC failure above.
from tutel.jit_kernels.gating import fast_cumsum_sub_one
import torch
class CumsumModule(torch.nn.Module):
    """Minimal module reproducing the issue: add a learnable 5x5 parameter
    to the input, then apply tutel's fused cumsum-minus-one kernel along dim 0."""

    def __init__(self):
        super().__init__()
        # Parameter of ones matching the expected (5, 5) input shape.
        self.param = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        shifted = x + self.param
        # tutel JIT kernel; this is the call that fails under ORTModule tracing.
        return fast_cumsum_sub_one(shifted, dim=0)
# Driver: integer input on GPU, module moved to the same device, then wrapped.
# NOTE(review): `input` shadows the builtin — kept as-is to preserve the repro.
input = torch.randint(0, 5, (5, 5), device='cuda:0')
cumsum_module = CumsumModule().to(device='cuda:0')
# Wrapping with ORTModule is what triggers the ONNX export and the failure.
cumsum_module = ORTModule(cumsum_module)
output = cumsum_module(input)