TensorFlow 2.12: Strange Behavior in Model Predictions After Fine-Tuning with Custom Loss Function
I've been researching this for a while, but I've hit a roadblock on a project. After hours of debugging, I'm still seeing unexpected behavior from my TensorFlow model after fine-tuning it with a custom loss function. I initially trained a CNN for image classification on a standard dataset and got good performance. However, after fine-tuning with a custom loss that combines cross-entropy with a regularization term, the predictions during evaluation no longer match the expected outputs.

Here is the custom loss function I defined:

```python
import tensorflow as tf

def custom_loss(y_true, y_pred):
    # Per-sample binary cross-entropy (averaged over the last axis)
    cross_entropy = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    # Penalty on the squared predictions, summed over every element of the batch
    regularization = 0.01 * tf.reduce_sum(tf.square(y_pred))
    return cross_entropy + regularization
```

I compile the model with this loss as follows:

```python
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
```

After training for several epochs, I evaluate the model on a validation set:

```python
val_loss, val_accuracy = model.evaluate(validation_data)
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')
```

Surprisingly, the validation accuracy is significantly lower than what I observed during the initial training phase, even though the training loss keeps decreasing. When I inspect the predictions with:

```python
predictions = model.predict(validation_data)
print(predictions)
```

I find that many of the predicted outputs are very close to either zero or one, which suggests overfitting or an imbalance between the two loss components. I've tried adjusting the weight of the regularization term, but it hasn't helped.

Has anyone run into something similar, or does anyone have suggestions on how to balance the custom loss better? Could the problem be in how I'm incorporating the regularization term? I've been using Python for about a year now. What am I doing wrong? Thanks for any help you can provide!
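One way to check whether the regularization term is swamping the cross-entropy is to compute the two components separately for a few batch sizes. The sketch below is a NumPy re-implementation, not TensorFlow code, and rests on several assumptions: `bce_per_sample` only approximates `tf.keras.losses.binary_crossentropy` (clip the predictions, then average over the last axis), the 10-class one-hot labels and random softmax-style predictions are made-up illustration data, and `reg_per_sample` is a hypothetical alternative that penalizes each sample separately. Because `tf.reduce_sum(tf.square(y_pred))` sums over the whole batch, it yields a single scalar that is broadcast onto every sample's loss, so its size grows with the batch size while the averaged cross-entropy does not.

```python
import numpy as np

EPS = 1e-7  # clipping constant; Keras's default backend epsilon is also 1e-7


def bce_per_sample(y_true, y_pred):
    """Approximate binary cross-entropy, averaged over the last axis (one value per sample)."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p), axis=-1)


def reg_batch_sum(y_pred, weight=0.01):
    """Penalty as written in custom_loss: summed over the whole batch, giving one scalar."""
    return weight * np.sum(np.square(y_pred))


def reg_per_sample(y_pred, weight=0.01):
    """Hypothetical alternative: penalty summed over the classes only, one value per sample."""
    return weight * np.sum(np.square(y_pred), axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    num_classes = 10  # assumed number of classes, for illustration only
    for batch_size in (8, 32, 128):
        labels = rng.integers(0, num_classes, size=batch_size)
        y_true = np.eye(num_classes)[labels]  # one-hot targets
        # Random logits turned into softmax-like probabilities (fake data)
        logits = rng.normal(size=(batch_size, num_classes))
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        y_pred = exp / exp.sum(axis=-1, keepdims=True)

        ce = bce_per_sample(y_true, y_pred).mean()
        print(
            f"batch={batch_size:4d}  cross-entropy={ce:.3f}  "
            f"batch-sum reg={reg_batch_sum(y_pred):.3f}  "
            f"per-sample reg={reg_per_sample(y_pred).mean():.3f}"
        )
```

With this setup, the batch-summed term grows roughly in proportion to the batch size, while the per-sample variant stays on the same scale as the cross-entropy. This only shows how the two terms scale; it does not by itself confirm that the scaling explains the drop in validation accuracy.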