CodexBloom - Programming Q&A Platform

How to Set a Learning Rate Schedule in TensorFlow 2.12 with a Custom Training Loop

👀 Views: 100 💬 Answers: 1 📅 Created: 2025-06-17
tensorflow keras machine-learning Python

I'm working through a tutorial and I'm sure I'm missing something obvious here, but I'm having trouble implementing a learning rate schedule in my custom training loop using TensorFlow 2.12. I want to use an exponential decay schedule, but it seems like my learning rate isn't updating as expected. Here's a snippet of my training code:

```python
import tensorflow as tf
import numpy as np

# Sample data
x_train = np.random.rand(1000, 10)
y_train = np.random.rand(1000, 1)

# Model definition
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])

# Custom training loop
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Learning rate schedule
initial_learning_rate = 0.1
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100,
    decay_rate=0.96,
    staircase=True
)

for epoch in range(10):
    print(f'Epoch {epoch + 1}')
    for step in range(100):
        with tf.GradientTape() as tape:
            y_pred = model(x_train)
            loss = tf.keras.losses.mean_squared_error(y_train, y_pred)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Update learning rate
        lr = lr_schedule(epoch * 100 + step)
        optimizer.learning_rate.assign(lr)
        print(f'Step {step}, Learning Rate: {lr.numpy()}, Loss: {loss.numpy().mean()}')
```

However, when I run this code, I notice that the learning rate printed remains constant at 0.01 throughout the training, and the loss doesn't decrease as I would expect. I've tried adjusting the `decay_steps` and even setting `learning_rate` directly in the optimizer, but nothing seems to apply the learning rate schedule correctly.

Is there something I'm missing in how to properly leverage the `ExponentialDecay` schedule in this context? Any insights would be appreciated! Thanks in advance!

I'm working on a service that needs to handle this. I'm using Python 3.11 in this project.
Any help would be greatly appreciated! This issue appeared after updating to Python 3.9. Is this even possible?
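
For reference when checking the printed values, here is a minimal plain-Python sketch of the value that `ExponentialDecay` is documented to compute: `initial_learning_rate * decay_rate ** (step / decay_steps)`, with the exponent floored when `staircase=True`. The helper name `exponential_decay` is hypothetical and not part of TensorFlow; its defaults simply mirror the parameters in the snippet above, and it ignores float32 rounding.

```python
import math


def exponential_decay(step, initial_learning_rate=0.1, decay_steps=100,
                      decay_rate=0.96, staircase=True):
    """Hypothetical helper (not a TensorFlow API).

    Mirrors the documented ExponentialDecay formula in pure Python so the
    expected learning rate at a given global step can be checked by hand.
    """
    exponent = step / decay_steps
    if staircase:
        # Staircase mode decays in discrete jumps every `decay_steps` steps.
        exponent = math.floor(exponent)
    return initial_learning_rate * decay_rate ** exponent


# Global step = epoch * 100 + step, as in the training loop above.
for global_step in (0, 99, 100, 200, 999):
    print(f"step {global_step:>3}: lr = {exponential_decay(global_step):.6f}")
```

With these parameters the value stays at 0.1 for steps 0 through 99 and drops to about 0.096 at step 100. Over the 1000 steps of the loop above it stays well above 0.01, so a reading of 0.01 would point at the optimizer's initial value rather than at the schedule.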