Issues with TensorFlow's Model.fit() when using custom callbacks
This might be a silly question, but I'm working on a machine learning project using TensorFlow 2.10, and I've run into an issue with `Model.fit()` when using a custom callback to monitor training. I defined a simple callback that logs the loss and accuracy at the end of each epoch, but it seems to interfere with training. The model doesn't save the best weights as expected, and training appears to stop prematurely.

Here's my custom callback class:

```python
import tensorflow as tf


class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        # Lowest training loss seen so far; any finite first loss beats it.
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # logs can be None, so guard before calling .get()
        current_loss = logs.get('loss')
        current_acc = logs.get('accuracy')
        print(f'Epoch: {epoch + 1}, Loss: {current_loss}, Accuracy: {current_acc}')

        # Save weights only when the training loss improves.
        if current_loss is not None and current_loss < self.best_loss:
            self.best_loss = current_loss
            print('Best model weights updated!')
            self.model.save_weights('best_model.h5')
        else:
            print('No improvement found.')
```

When I call `model.fit()` with this custom callback included:

```python
model.fit(train_data, train_labels, epochs=10, callbacks=[CustomCallback()])
```

I see the printed output as expected, but the weights don't seem to be saved correctly, and sometimes I even get this warning:

```
UserWarning: ModelCheckpoint is not saving the best weights as the monitored quantity did not improve.
```

Also, the loss does decrease, but not the way I'd expect across epochs, and I suspect the callback logic is interfering somehow. I've tried debugging by checking whether `current_loss` and `self.best_loss` update correctly, and they do. However, I'm not sure whether there are best practices or known issues with callbacks in this TensorFlow version that might be affecting this. Any insights or suggestions would be greatly appreciated! Is there a better approach? Could this be a known issue?
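Since the comparison between `current_loss` and `self.best_loss` reportedly updates correctly, one way to double-check that claim is to pull the comparison out of the callback into plain Python and exercise it with a hand-made sequence of losses, with no `fit()` call needed. The `BestTracker` class below is a hypothetical helper, not a Keras API; its `min_delta` parameter borrows the idea (and name) of the `min_delta` argument on `tf.keras.callbacks.EarlyStopping`, and the sample losses are made up for illustration.

```python
import math
from typing import Optional


class BestTracker:
    """Hypothetical helper that tracks the lowest value seen so far.

    Mirrors the comparison in CustomCallback.on_epoch_end, plus two guards:
    a missing (None) or NaN value never counts as an improvement, and
    min_delta ignores decreases too small to matter.
    """

    def __init__(self, min_delta: float = 0.0) -> None:
        self.best = math.inf
        self.min_delta = min_delta

    def update(self, value: Optional[float]) -> bool:
        """Record `value` and return True if it beats the best by more than min_delta."""
        if value is None or math.isnan(value):
            return False
        if value < self.best - self.min_delta:
            self.best = value
            return True
        return False


# Usage: feed it made-up per-epoch losses, including a missing and a NaN entry.
tracker = BestTracker(min_delta=0.01)
losses = [0.9, 0.7, 0.695, None, 0.5, float("nan")]
improved = [tracker.update(x) for x in losses]
print(improved)      # [True, True, False, False, True, False]
print(tracker.best)  # 0.5
```

Inside `on_epoch_end`, a check such as `if tracker.update(logs.get('loss')):` could then replace the inline comparison, so the saving logic and the comparison logic can be tested separately.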
How would you solve this?
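One common approach, sketched here under stated assumptions: keep the custom callback for logging only, and let the built-in `tf.keras.callbacks.ModelCheckpoint` (with `save_best_only=True`) and `tf.keras.callbacks.EarlyStopping` handle saving and stopping. The toy data, the tiny model, the `LoggingCallback` name, and the `best_model.weights.h5` filename are illustrative assumptions, not taken from the post. Monitoring `val_loss` assumes a validation split is available; the `.weights.h5` suffix is used because newer Keras releases require it for weights-only saving, and TF 2.10 also treats any `.h5` path as HDF5.

```python
import numpy as np
import tensorflow as tf


class LoggingCallback(tf.keras.callbacks.Callback):
    """Hypothetical logging-only callback: prints metrics, never saves or stops."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"Epoch {epoch + 1}: loss={logs.get('loss')}, "
              f"accuracy={logs.get('accuracy')}, val_loss={logs.get('val_loss')}")


# Toy binary-classification data (an assumption; substitute your own arrays).
rng = np.random.default_rng(0)
train_data = rng.normal(size=(200, 4)).astype("float32")
train_labels = (train_data.sum(axis=1) > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

callbacks = [
    LoggingCallback(),
    # Writes weights only when val_loss reaches a new minimum.
    tf.keras.callbacks.ModelCheckpoint(
        "best_model.weights.h5",
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    ),
    # Stops after 3 epochs without val_loss improvement; when it stops,
    # restore_best_weights=True rolls the model back to the best epoch.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    ),
]

history = model.fit(
    train_data, train_labels,
    epochs=10, validation_split=0.2,
    callbacks=callbacks, verbose=0,
)

# Reload the best checkpoint written by ModelCheckpoint.
model.load_weights("best_model.weights.h5")
```

Keeping the custom callback free of side effects makes it easier to tell whether saving or stopping behavior comes from the checkpoint and early-stopping configuration rather than from the logging code.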