Issues with TensorFlow 2.8 when training a multi-class classification model with imbalanced data
I'm prototyping a solution and could really use some help; I've been researching and struggling with this for a few days now. I'm trying to train a multi-class classification model using TensorFlow 2.8, but I've run into issues with my imbalanced dataset. The classes are heavily skewed: class 0 holds 80% of the data, while classes 1 and 2 hold only 10% each. I've implemented a simple feedforward neural network with categorical cross-entropy as the loss function. However, when I evaluate the model, accuracy is high but the model appears to be biased towards class 0. Here's the code for my model and training process:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# Generate some synthetic data for illustration
X = ...  # feature matrix
Y = ...  # target labels (0, 1, 2)

# One-hot encode the target variable
encoder = OneHotEncoder(sparse=False)
Y_encoded = encoder.fit_transform(Y.reshape(-1, 1))

# Split the dataset
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y_encoded, test_size=0.2, random_state=42
)

model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    Dense(64, activation='relu'),
    Dense(3, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(X_train, Y_train, epochs=50, batch_size=32, validation_split=0.2)
```

After training, the classification report shows an F1 score of 0.90 for class 0, but only 0.10 for classes 1 and 2.
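To be clear about why I don't trust the accuracy number: with an 80/10/10 split, a classifier that always predicts class 0 already scores 80% accuracy while its F1 for the minority classes is zero. Here is a quick sanity-check sketch I ran (synthetic label counts chosen to match my split, not my real data):

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical labels matching my 80/10/10 distribution
y_true = np.array([0] * 800 + [1] * 100 + [2] * 100)
y_pred = np.zeros_like(y_true)  # a degenerate "model" that always predicts class 0

print(accuracy_score(y_true, y_pred))  # 0.8 despite learning nothing
print(f1_score(y_true, y_pred, average=None, zero_division=0))
# class 0 gets F1 ≈ 0.89 (precision 0.8, recall 1.0); classes 1 and 2 get 0.0
```

So my model's per-class F1 of 0.10 for classes 1 and 2 suggests it is behaving close to this majority-class baseline.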
I've tried using class weights to balance the training process:

```python
class_weight = {0: 1, 1: 8, 2: 8}
model.fit(X_train, Y_train, epochs=50, batch_size=32,
          validation_split=0.2, class_weight=class_weight)
```

However, I still observe the same issue: the model remains biased towards class 0. I also tried oversampling classes 1 and 2 with SMOTE and undersampling class 0, but the model's performance still doesn't improve significantly.

What strategies can I employ to ensure that my model learns effectively from all classes in this imbalanced setting? Are there specific techniques or best practices in TensorFlow that can help me address this issue? For context: I'm running Python on Ubuntu. Is there a better approach?
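In case my hand-picked weights `{0: 1, 1: 8, 2: 8}` are part of the problem, I also tried deriving them automatically with scikit-learn's `compute_class_weight`. A minimal sketch, using synthetic integer labels (before one-hot encoding) that match my 80/10/10 split:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Integer labels matching my 80/10/10 split (illustrative counts, not my real data)
y = np.array([0] * 800 + [1] * 100 + [2] * 100)

# 'balanced' assigns each class the weight n_samples / (n_classes * count_c)
weights = compute_class_weight(class_weight='balanced',
                               classes=np.unique(y), y=y)
class_weight = dict(enumerate(weights))
print(class_weight)
# class 0 -> 1000 / (3 * 800) ≈ 0.417, classes 1 and 2 -> 1000 / (3 * 100) ≈ 3.333
```

The resulting dict can be passed directly to `model.fit(..., class_weight=class_weight)`, but the bias towards class 0 persisted for me either way.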