Unexpected behavior when using Keras' EarlyStopping with custom metrics in TensorFlow 2.9
I've been struggling with this for a few days now and could really use some help. I've searched everywhere and can't find a clear answer.

I'm facing an issue with the `EarlyStopping` callback in Keras while training a model with TensorFlow 2.9. I want to monitor a custom metric for early stopping, but the callback does not stop training when I expect it to. My custom metric combines accuracy with a penalty term for high-variance predictions: the batch accuracy minus the mean per-class variance of `y_pred`.

Here is the custom metric I implemented:

```python
import tensorflow as tf


class CustomMetric(tf.keras.metrics.Metric):
    def __init__(self, name='custom_metric', **kwargs):
        super().__init__(name=name, **kwargs)
        self.accuracy = self.add_weight(name='accuracy', initializer='zeros')
        self.variance_penalty = self.add_weight(name='variance_penalty', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Batch accuracy for one-hot encoded labels
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(y_true, axis=1), tf.argmax(y_pred, axis=1)), tf.float32)
        )
        # Per-class variance of the predictions across the batch, averaged over classes
        variance_penalty = tf.reduce_mean(tf.math.reduce_variance(y_pred, axis=0))
        # assign() overwrites the stored values with this batch's numbers
        self.accuracy.assign(accuracy)
        self.variance_penalty.assign(variance_penalty)

    def result(self):
        return self.accuracy - self.variance_penalty

    def reset_state(self):
        self.accuracy.assign(0.0)
        self.variance_penalty.assign(0.0)
```

When setting up training, I added this metric and configured `EarlyStopping` to monitor it:

```python
from tensorflow.keras.callbacks import EarlyStopping

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=[CustomMetric()])

early_stopping = EarlyStopping(monitor='custom_metric', patience=3, restore_best_weights=True)

model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50, callbacks=[early_stopping])
```

However, training runs for all 50 epochs, even when the validation loss has not improved for several epochs, which is not what I expected.
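Because `update_state` in the metric above uses `assign()`, each batch overwrites the stored values, so the number reported at the end of an epoch reflects only the last batch rather than the whole epoch. Built-in Keras metrics such as `tf.keras.metrics.Mean` instead accumulate a running total and count across batches. The sketch below contrasts the two behaviors in plain Python (no TensorFlow); the class names and the per-batch values are hypothetical, chosen only for illustration.

```python
# Plain-Python stand-ins (no TensorFlow) for two ways a stateful metric can
# treat successive batches. Class names are hypothetical.

class LastBatchMetric:
    """Mirrors an assign()-based update_state: keeps only the latest batch."""

    def __init__(self):
        self.value = 0.0

    def update_state(self, batch_value):
        self.value = batch_value  # overwrite, like Variable.assign

    def result(self):
        return self.value

    def reset_state(self):
        self.value = 0.0


class RunningMeanMetric:
    """Accumulates a sum and a count, like assign_add-based Keras metrics."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, batch_value):
        self.total += batch_value  # accumulate, like Variable.assign_add
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


# Made-up per-batch scores for one "epoch"
batch_scores = [0.60, 0.70, 0.80, 0.50]
last, mean = LastBatchMetric(), RunningMeanMetric()
for score in batch_scores:
    last.update_state(score)
    mean.update_state(score)

print(last.result())             # 0.5  (last batch only)
print(round(mean.result(), 4))   # 0.65 (average over all batches)
```

With the overwrite version, the epoch-level value can jump around depending on whichever batch happens to come last, which makes the monitored series noisier than an epoch average would be.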
The logs show the custom metric fluctuating, but early stopping never triggers:

```
Epoch 12/50 - 1s - loss: 0.2345 - custom_metric: 0.6789 - val_loss: 0.2103 - val_custom_metric: 0.6700
Epoch 13/50 - 1s - loss: 0.2321 - custom_metric: 0.6795 - val_loss: 0.2055 - val_custom_metric: 0.6750
Epoch 14/50 - 1s - loss: 0.2310 - custom_metric: 0.6800 - val_loss: 0.2080 - val_custom_metric: 0.6705
Epoch 15/50 - 1s - loss: 0.2301 - custom_metric: 0.6815 - val_loss: 0.2032 - val_custom_metric: 0.6710
```

As far as I can tell, `EarlyStopping` is supposed to work with custom metrics, but the callback doesn't seem to register changes in the metric properly. I also verified that the metric returns a single scalar value.

Is there something I'm missing in the configuration or the implementation that would cause this behavior? This is part of a larger application I'm building, and I'm using a stable Python release. How would you solve this? I'd love to hear your thoughts.
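For reference, the stop rule can be checked against the logged numbers with a simplified model: track the best value in a chosen direction, count consecutive epochs without improvement, and stop once that count reaches `patience`. The function below is a hypothetical pure-Python sketch of that rule (no `baseline` handling), not the Keras source. Its `mode` parameter corresponds to the `EarlyStopping` argument of the same name; the configuration in the question leaves it at the default `'auto'`, in which Keras infers the direction from the monitored key's name. The callback reads exactly the key given in `monitor`, and Keras logs validation values with a `val_` prefix, so `'custom_metric'` and `'val_custom_metric'` are different series. Only epochs 12 to 15 are available, so this cannot reproduce the full 50-epoch run.

```python
import math


def simulate_early_stopping(values, patience, mode, min_delta=0.0):
    """Return the 0-based index of the epoch at which training would stop,
    or None if no stop is triggered within `values`.

    Simplified sketch of a patience counter (no baseline); not the Keras source.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = math.inf if mode == "min" else -math.inf
    wait = 0  # epochs since the last improvement
    for epoch, current in enumerate(values):
        if mode == "min":
            improved = current < best - min_delta
        else:
            improved = current > best + min_delta
        if improved:
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Values copied from epochs 12-15 of the log above (index 0 = epoch 12).
train_metric = [0.6789, 0.6795, 0.6800, 0.6815]
val_metric = [0.6700, 0.6750, 0.6705, 0.6710]

print(simulate_early_stopping(train_metric, patience=3, mode="min"))  # 3
print(simulate_early_stopping(train_metric, patience=3, mode="max"))  # None
print(simulate_early_stopping(val_metric, patience=3, mode="max"))    # None
```

The same series can trigger a stop under one direction and never under the other, which is why it may help to pass `mode='max'` or `mode='min'` explicitly rather than rely on name-based inference.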