Saikat Kumar Dey

Training Image Classification Models with scikeras

June 25, 2022

Introduction to scikeras

Scikeras is a library that bridges TensorFlow and scikit-learn: it wraps Keras models behind scikit-learn's estimator API, so you can train deep learning models with TensorFlow while keeping the ease of use and flexibility of scikit-learn's familiar interface.

In this blog post, we will demonstrate how to use scikeras to train a model on a large dataset of images stored on disk. We will show how to use TensorFlow's ImageDataGenerator class to load images in batches and apply real-time augmentation, and how to use scikeras's KerasClassifier() to create a scikit-learn compatible interface for training the model. We will also demonstrate how to use partial_fit() to train the model incrementally, one batch at a time, without re-initializing it between calls.

Download dataset

Download a sample dataset and store the dataset in data/. Your directory structure should look like the following:

data/Pistachio_Image_Dataset
├── Kirmizi_Pistachio/*.jpg
└── Siirt_Pistachio/*.jpg

Import necessary libraries

from math import ceil
import tensorflow as tf
from matplotlib import pyplot as plt
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Setup Constants

Next, we need to set some constants that will be used throughout the training process. In this example, we will be using a batch size of 32 and training the model for 10 epochs. You may need to adjust these values depending on your dataset and the performance of your model.

DATA_DIR = "data/Pistachio_Image_Dataset"
BATCH_SIZE = 32
EPOCHS = 10

Loader for reading data in batches

One of the key advantages of using scikeras is the ability to train a model on datasets that do not fit into memory. To do this, we can use TensorFlow's ImageDataGenerator class to load images in batches and apply real-time augmentation (in this example we only rescale the pixels, but the same class also supports rotations, flips, zooms, and more). This allows us to train the model on smaller chunks of the dataset, without having to load the entire dataset into memory.

image_generator = ImageDataGenerator(rescale=1.0 / 255).flow_from_directory(
    DATA_DIR,
    target_size=(32, 32),
    batch_size=BATCH_SIZE,
    class_mode="binary",
)
total_images = len(image_generator.filenames)
total_batches = ceil(total_images / BATCH_SIZE)

In this code, ImageDataGenerator() rescales each pixel value by 1/255, and flow_from_directory() yields batches of BATCH_SIZE images from DATA_DIR, resized to 32×32, together with binary (0/1) labels as requested by class_mode="binary". We then calculate the total number of images in the dataset and the number of batches per epoch. These values will be used later in the training loop.
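The batch arithmetic deserves care: the number of batches per epoch is the ceiling of a true division, since a final, partially filled batch still has to be counted. A quick pure-Python sanity check (the image counts below are made-up examples, not the actual dataset sizes):

```python
from math import ceil

def num_batches(total_images: int, batch_size: int) -> int:
    """Number of batches needed to cover every image once.

    ceil() must be applied to a true division (/); using floor
    division (//) inside ceil() would silently drop the final,
    partially filled batch.
    """
    return ceil(total_images / batch_size)

# 2148 images in batches of 32 need 68 batches:
# 67 full batches cover 2144 images, plus one final batch of 4.
print(num_batches(2148, 32))  # -> 68
print(num_batches(2048, 32))  # -> 64 (divides evenly, no partial batch)
```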

Define your Tensorflow model architecture

Next, we need to define the architecture of our TensorFlow model. For the purposes of this example, we will be using a simple shallow network: the pixels are flattened and fed into a single dense sigmoid unit, which amounts to logistic regression on the raw pixel values. However, you can use any architecture that you prefer, and you can experiment with different architectures to see which one performs best on your dataset.

model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)
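Since Flatten() has no weights, all of this model's parameters live in the single dense unit: one weight per flattened pixel value plus one bias term. The count is easy to verify by hand (pure Python, no TensorFlow required):

```python
# Parameter count for Flatten + Dense(1) on 32x32 RGB inputs.
# Flatten has no parameters; Dense(1) has one weight per input
# feature plus a single bias term.
height, width, channels = 32, 32, 3
flattened_features = height * width * channels   # 32 * 32 * 3
dense_params = flattened_features * 1 + 1        # weights + bias

print(flattened_features)  # -> 3072
print(dense_params)        # -> 3073
```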

Define scikeras interface

Once we have defined our TensorFlow model, we can use scikeras’s KerasClassifier() function to create a scikit-learn compatible interface for training the model. This allows us to use the familiar fit() and predict() methods from scikit-learn, while taking advantage of the capabilities of TensorFlow.

sk_clf = KerasClassifier(
    model=model,
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

In this code, we are creating a KerasClassifier object and passing it our TensorFlow model, along with the optimizer and loss function to use during training. We are also specifying that we want to track the accuracy metric during training. You can adjust these parameters as needed to suit your specific use case.

Training loop

Now that we have set up the necessary components for training our model, we can implement the main training loop. This loop will iterate over the batches of images generated by ImageDataGenerator(), and will use partial_fit() to train the model on each batch. partial_fit() trains the model incrementally: unlike fit(), which rebuilds and re-initializes the model on every call, partial_fit() keeps the weights learned from previous calls, which is exactly what we need when feeding the data one batch at a time.

batch = 0
epoch = 0
histories = []
for X, y in image_generator:
    sk_clf.partial_fit(X, y, verbose=False)
    history = sk_clf.model_.history.history
    histories.append(history)
    batch += 1
    if batch == total_batches:
        batch = 0
        epoch += 1
        print(
            f"epoch {epoch}/{EPOCHS}, loss {history['loss'][0]:.4f}, "
            f"accuracy {history['accuracy'][0]:.4f}"
        )
    if epoch == EPOCHS:
        break

In this code, we loop over the batches generated by the image generator and call partial_fit() on each one. After every call we read the single-batch training history from the underlying Keras model and append it to histories, and at the end of each epoch we print the most recent loss and accuracy. Once we have reached the specified number of epochs, the training loop exits and the model is trained. At this point, you can use the predict() method to make predictions on new data, or you can continue training the model with partial_fit() to improve its performance further.
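Because each partial_fit() call produces a single-batch history dict, the histories list ends up holding one entry per batch. As a sketch of post-processing (assuming Keras-style entries of the form {'loss': [...], 'accuracy': [...]}), a small helper can average them into per-epoch curves, ready for plotting with matplotlib:

```python
def per_epoch_means(histories, batches_per_epoch, key="loss"):
    """Average a per-batch metric into per-epoch values.

    histories         -- list of Keras-style history dicts, one per
                         batch, e.g. {"loss": [0.69], "accuracy": [0.5]}
    batches_per_epoch -- number of batches that make up one epoch
    """
    values = [h[key][0] for h in histories]
    chunks = [
        values[i : i + batches_per_epoch]
        for i in range(0, len(values), batches_per_epoch)
    ]
    # Divide by the chunk length so a trailing partial epoch is
    # averaged correctly too.
    return [sum(chunk) / len(chunk) for chunk in chunks]

# Toy example: two epochs of three batches each.
toy = [{"loss": [v]} for v in (1.0, 0.75, 0.5, 0.5, 0.25, 0.0)]
print(per_epoch_means(toy, 3))  # -> [0.75, 0.25]
```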

Conclusion

In this blog post, we have demonstrated how to use scikeras to train a TensorFlow model on a large dataset of images stored on disk. We have shown how to use ImageDataGenerator() to load images in batches and apply real-time augmentation, how to use KerasClassifier() to create a scikit-learn compatible interface for training the model, and how to use partial_fit() to train the model incrementally, one batch at a time, while keeping the per-batch training history.

Overall, scikeras is a convenient and effective tool for training image classification models, and it allows users to take advantage of the strengths of both TensorFlow and scikit-learn. There are many potential avenues for further improvement, such as experimenting with different model architectures, fine-tuning the training parameters, and applying more advanced augmentation techniques. We hope that this blog post has provided a useful introduction to scikeras and has given you some ideas for how to use it in your own projects.