When I try to run the following cell:
cifar10_test_prelogits, cifar10_test_logits, cifar10_test_labels = standalone_get_prelogits(
    params,
    cifar10_ds_test,
    image_count=N_test
)
I get the error message below. Please help me resolve this issue :)
ERROR MESSAGE:
UnfilteredStackTrace: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[128,24,24,1024]{2,1,3,0}, u8[0]{0}) custom-call(f32[128,384,384,3]{2,1,3,0} %copy, f32[16,16,3,1024]{1,0,2,3} %copy.1), window={size=16x16 stride=16x16}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(16, 16) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(128, 384, 384, 3) rhs_shape=(16, 16, 3, 1024) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=371}, backend_config="{"conv_result_scale":1,"activation_mode":"0","side_input_scale":0}"
Original error: UNIMPLEMENTED: DNN library is not found.
To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
369 dimension_numbers=dimension_numbers,
370 feature_group_count=self.feature_group_count,
--> 371 precision=self.precision
372 )
373 else:
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for: [same %cudnn-conv instruction, "DNN library is not found" error and XLA_FLAGS hint as in the first message above]
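One way to check whether the problem lies in the JAX/cuDNN setup rather than in `standalone_get_prelogits` is to run the failing convolution on its own. This is a minimal sketch that assumes only `jax`/`jaxlib` are installed; the input and kernel shapes, the 16x16 window with stride 16, the zero padding and the `NHWC`/`HWIO` layouts are read off the `%cudnn-conv` line in the log, with the batch reduced from 128 to 1 to keep memory use small:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Show which jaxlib build and backend are actually in use.
print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # "cpu" means JAX is not using the GPU at all
print("Devices:", jax.devices())

# Same convolution as in the error message (NHWC input, HWIO kernel,
# 16x16 window, stride 16, no padding), but with batch size 1 instead of 128.
x = jnp.ones((1, 384, 384, 3), dtype=jnp.float32)
w = jnp.ones((16, 16, 3, 1024), dtype=jnp.float32)

y = lax.conv_general_dilated(
    x,
    w,
    window_strides=(16, 16),
    padding="VALID",  # equivalent to padding=((0, 0), (0, 0)) in the log
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print("Output shape:", y.shape)  # (1, 24, 24, 1024) when it succeeds
```

If this small convolution fails with the same "DNN library is not found" error on the GPU, the cause is most likely the CUDA/cuDNN installation or a jaxlib build that does not match it, not the model code. The `XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false` workaround suggested in the message has to be in the environment before JAX initialises its GPU backend (for example, set through `os.environ` before `import jax`); it only relaxes the algorithm picker, so it may not help if cuDNN cannot be loaded at all.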