When I attempt distributed training (including training on a Google Colab TPU), the error shown below occurs:
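From the traceback, the failure seems to happen when a metric that wraps a `Loss` object is built under the strategy. A setup along these lines can trigger it (a minimal hypothetical sketch rather than my exact script; the model and metric choice are placeholders):

```python
import tensorflow as tf

# Standard Colab TPU setup.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='softmax', input_shape=(20,)),
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        # A Loss *instance* passed as a metric gets wrapped in
        # MeanMetricWrapper; its default reduction then triggers the
        # ValueError in the traceback below.
        metrics=[tf.keras.losses.CategoricalCrossentropy()],
    )
```

The full traceback: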
```
/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
528 self._self_setattr_tracking = False # pylint: disable=protected-access
529 try:
--> 530 result = method(self, *args, **kwargs)
531 finally:
532 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
434 targets=self._targets,
435 skip_target_masks=self._prepare_skip_target_masks(),
--> 436 masks=self._prepare_output_masks())
437
438 # Prepare sample weight modes. List with the same length as model outputs.
/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in _handle_metrics(self, outputs, targets, skip_target_masks, sample_weights, masks, return_weighted_metrics, return_weighted_and_unweighted_metrics)
1962 metric_results.extend(
1963 self._handle_per_output_metrics(self._per_output_metrics[i],
-> 1964 target, output, output_mask))
1965 if return_weighted_and_unweighted_metrics or return_weighted_metrics:
1966 metric_results.extend(
/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in _handle_per_output_metrics(self, metrics_dict, y_true, y_pred, mask, weights)
1913 with backend.name_scope(metric_name):
1914 metric_result = training_utils_v1.call_metric_function(
-> 1915 metric_fn, y_true, y_pred, weights=weights, mask=mask)
1916 metric_results.append(metric_result)
1917 return metric_results
/usr/local/lib/python3.7/dist-packages/keras/engine/training_utils_v1.py in call_metric_function(metric_fn, y_true, y_pred, weights, mask)
1175
1176 if y_pred is not None:
-> 1177 return metric_fn(y_true, y_pred, sample_weight=weights)
1178 # `Mean` metric only takes a single value.
1179 return metric_fn(y_true, sample_weight=weights)
/usr/local/lib/python3.7/dist-packages/keras/metrics.py in __call__(self, *args, **kwargs)
235 from keras.distribute import distributed_training_utils # pylint:disable=g-import-not-at-top
236 return distributed_training_utils.call_replica_local_fn(
--> 237 replica_local_fn, *args, **kwargs)
238
239 def __str__(self):
/usr/local/lib/python3.7/dist-packages/keras/distribute/distributed_training_utils.py in call_replica_local_fn(fn, *args, **kwargs)
58 with strategy.scope():
59 return strategy.extended.call_for_each_replica(fn, args, kwargs)
---> 60 return fn(*args, **kwargs)
61
62
/usr/local/lib/python3.7/dist-packages/keras/metrics.py in replica_local_fn(*args, **kwargs)
215 update_op = None
216 else:
--> 217 update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
218 update_ops = []
219 if update_op is not None:
/usr/local/lib/python3.7/dist-packages/keras/utils/metrics_utils.py in decorated(metric_obj, *args, **kwargs)
71
72 with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
---> 73 update_op = update_state_fn(*args, **kwargs)
74 if update_op is not None: # update_op will be None in eager execution.
75 metric_obj.add_update(update_op)
/usr/local/lib/python3.7/dist-packages/keras/metrics.py in update_state_fn(*args, **kwargs)
175 control_status = tf.__internal__.autograph.control_status_ctx()
176 ag_update_state = tf.__internal__.autograph.tf_convert(obj_update_state, control_status)
--> 177 return ag_update_state(*args, **kwargs)
178 else:
179 if isinstance(obj.update_state, tf.__internal__.function.Function):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
694 try:
695 with conversion_ctx:
--> 696 return converted_call(f, args, kwargs, options=options)
697 except Exception as e: # pylint:disable=broad-except
698 if hasattr(e, 'ag_error_metadata'):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
381
382 if not options.user_requested and conversion.is_allowlisted(f):
--> 383 return _call_unconverted(f, args, kwargs, options)
384
385 # internal_convert_user_code is for example turned off when issuing a dynamic
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
462
463 if kwargs is not None:
--> 464 return f(*args, **kwargs)
465 return f(*args)
466
/usr/local/lib/python3.7/dist-packages/keras/metrics.py in update_state(self, y_true, y_pred, sample_weight)
723
724 ag_fn = tf.__internal__.autograph.tf_convert(self._fn, tf.__internal__.autograph.control_status_ctx())
--> 725 matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
726 return super(MeanMetricWrapper, self).update_state(
727 matches, sample_weight=sample_weight)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
694 try:
695 with conversion_ctx:
--> 696 return converted_call(f, args, kwargs, options=options)
697 except Exception as e: # pylint:disable=broad-except
698 if hasattr(e, 'ag_error_metadata'):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
381
382 if not options.user_requested and conversion.is_allowlisted(f):
--> 383 return _call_unconverted(f, args, kwargs, options)
384
385 # internal_convert_user_code is for example turned off when issuing a dynamic
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
462
463 if kwargs is not None:
--> 464 return f(*args, **kwargs)
465 return f(*args)
466
/usr/local/lib/python3.7/dist-packages/keras/losses.py in __call__(self, y_true, y_pred, sample_weight)
141 losses = call_fn(y_true, y_pred)
142 return losses_utils.compute_weighted_loss(
--> 143 losses, sample_weight, reduction=self._get_reduction())
144
145 @classmethod
/usr/local/lib/python3.7/dist-packages/keras/losses.py in _get_reduction(self)
182 self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)):
183 raise ValueError(
--> 184 'Please use `tf.keras.losses.Reduction.SUM` or '
185 '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are '
186 'used with `tf.distribute.Strategy` outside of the built-in training '
ValueError: Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when losses are used with `tf.distribute.Strategy` outside of the built-in training loops. You can implement `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch size like:
with strategy.scope():
loss_obj = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)
Please see https://www.tensorflow.org/tutorials/distribute/custom_training for more details.
```
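Following the suggestion in the error message, a workaround is a custom training loop roughly like the sketch below (`model`, `optimizer`, `dist_dataset`, and `global_batch_size` are placeholders, not from my actual script):

```python
import tensorflow as tf

# Assumes `strategy`, `model`, `optimizer`, and the distributed dataset
# `dist_dataset` were already created under `strategy.scope()`, and that
# `global_batch_size` is the per-replica batch size times the replica count.
with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    # Per-example losses, averaged manually over the *global* batch size so
    # the gradient scale stays consistent across replicas.
    per_example = loss_obj(labels, predictions)
    return tf.nn.compute_average_loss(
        per_example, global_batch_size=global_batch_size)

def train_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = compute_loss(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(inputs):
    # On TPU, strategy.run must be called inside a tf.function.
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for batch in dist_dataset:
    distributed_train_step(batch)
```

This works, but it means giving up `model.fit` entirely.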
It seems that `SUM_OVER_BATCH_SIZE` loss reduction simply isn't supported in this code path under `tf.distribute.Strategy` yet. It may be a little tricky, but it would be nice if you could add this enhancement. Thank you!