Inconsistent Gradient Updates with tf.GradientTape and Mixed Precision in TensorFlow 2.12
After trying several suggestions I found online, I still can't figure this out, and it feels like it should be something simple. I'm experiencing inconsistent gradient updates when using `tf.GradientTape` in combination with mixed precision training in TensorFlow 2.12. I've set up my model using the `tf.keras.mixed_precision` API, and my training loop looks something like this:

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Enable mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Simple model
model = models.Sequential([
    layers.Dense(128, activation='relu', input_shape=(784,)),
    layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam()

# Custom training loop
# (num_epochs, train_data, train_labels and loss_fn are defined elsewhere in my script)
for epoch in range(num_epochs):
    with tf.GradientTape() as tape:
        predictions = model(train_data, training=True)
        loss = loss_fn(train_labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```

The model trains, but the loss does not decrease consistently across epochs, and I sometimes get very large updates that cause the loss to oscillate.

What I've tried so far:

- Switching the optimizer between Adam and SGD; the issue persists with both.
- Adjusting the learning rate with `tf.keras.callbacks.LearningRateScheduler`, which only marginally improved stability.
- Verifying that my input data is properly normalized; I didn't find any anomalies there.

On inspecting the gradients (see the snippet at the end of this post), I noticed that some of them are `NaN` or have extremely large magnitudes, which I didn't expect. When I disable mixed precision by setting the global policy back to `'float32'`, training behaves as expected, but I want to keep the performance benefits of mixed precision.

Does anyone have insights into what could be causing this and how to stabilize gradient updates while using mixed precision training? Are there specific configurations or practices I should follow to avoid these inconsistencies?

For context: I'm running Python on Debian. I appreciate any insights!
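In case it's relevant, this is roughly how I'm inspecting the gradients after `tape.gradient(...)`. The `check_gradients` helper is just a throwaway debugging function I wrote, not part of the actual training code:

```python
import tensorflow as tf

def check_gradients(gradients, variables):
    """Print NaN/Inf status and the largest absolute value of each gradient."""
    for grad, var in zip(gradients, variables):
        if grad is None:
            print(f"{var.name}: no gradient")
            continue
        has_nan = tf.reduce_any(tf.math.is_nan(grad)).numpy()
        has_inf = tf.reduce_any(tf.math.is_inf(grad)).numpy()
        max_abs = tf.reduce_max(tf.abs(grad)).numpy()
        print(f"{var.name}: nan={has_nan}, inf={has_inf}, max_abs={max_abs:.3e}")

# Called right after computing the gradients in the training loop:
# check_gradients(gradients, model.trainable_variables)
```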
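For completeness, the only change I make for the "stable" run is switching the global policy back to full precision before building the model; everything else in the script stays the same:

```python
# Reverting to full float32 -- with just this one change the loss decreases smoothly
tf.keras.mixed_precision.set_global_policy('float32')
```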