Saikat Kumar Dey

Memory-efficient scikeras model training

I want to train an image classification model using the SciKeras interface.

SciKeras is a wrapper that lets us use TensorFlow (Keras) models through the scikit-learn API.

I have a dataset of images on my disk which I would like to use for training.

TensorFlow has ImageDataGenerator(), which loads images in batches and applies augmentation (rotation, flipping, zoom, scaling) in real time.
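To make the idea of batch loading concrete, here is a minimal pure-Python sketch of what such a loader does: it keeps only a list of (path, label) pairs in memory and reads the files for one batch at a time. The function `batch_generator` and its `load` callback are hypothetical names for illustration, not part of TensorFlow, and augmentation is left out. Like the Keras directory iterator, it loops over the data indefinitely and reshuffles on every pass.

```python
import random
from typing import Callable, Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def batch_generator(
    samples: Sequence[Tuple[str, int]],
    load: Callable[[str], T],
    batch_size: int,
    seed: int = 0,
) -> Iterator[Tuple[List[T], List[int]]]:
    """Yield (images, labels) batches forever, loading one batch at a time.

    `samples` holds (path, label) pairs; `load` turns a path into an image.
    This is a hypothetical helper, not a TensorFlow API.
    """
    rng = random.Random(seed)
    order = list(range(len(samples)))
    while True:
        rng.shuffle(order)  # new random order on every pass over the data
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            # Only the files belonging to this batch are read into memory.
            images = [load(samples[i][0]) for i in idx]
            labels = [samples[i][1] for i in idx]
            yield images, labels
```

The last batch of each pass may be smaller than `batch_size` when the number of samples is not a multiple of it.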

SciKeras exposes KerasClassifier(), a scikit-learn-compatible API for model training.

scikit-learn's fit() expects the entire dataset to be loaded into memory at once.

If the image dataset is large, it will not fit in RAM.
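A back-of-the-envelope calculation shows the gap. Assuming images are held as float32 arrays (4 bytes per value) and using a hypothetical dataset of 10,000 RGB images at 224x224, the full array is measured in gigabytes while a single batch of 32 is only tens of megabytes.

```python
def array_bytes(n_images: int, height: int, width: int, channels: int = 3,
                bytes_per_value: int = 4) -> int:
    """Bytes needed to hold n_images as a dense array (4 bytes = float32)."""
    return n_images * height * width * channels * bytes_per_value


full = array_bytes(10_000, 224, 224)  # the whole (hypothetical) dataset
batch = array_bytes(32, 224, 224)     # one batch of 32 images
print(f"full dataset: {full / 1e9:.1f} GB, one batch: {batch / 1e6:.1f} MB")
# prints "full dataset: 6.0 GB, one batch: 19.3 MB"
```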

So we would like to call partial_fit() on smaller batches of images generated by ImageDataGenerator. partial_fit() also continues training from the model's current weights, whereas fit() re-initializes the model every time it is called.
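The difference between the two methods can be illustrated with a toy model. The `ToyRegressor` class below is hypothetical and only mimics the scikit-learn convention (fit() starts over, partial_fit() continues from the current state); it is not how SciKeras is implemented.

```python
from typing import List


class ToyRegressor:
    """Hypothetical 1-D linear model y = w * x trained by gradient descent."""

    def __init__(self, lr: float = 0.1) -> None:
        self.lr = lr
        self.w = 0.0

    def _step(self, X: List[float], y: List[float]) -> None:
        # One gradient-descent step on the mean squared error.
        grad = sum(2 * (self.w * x - t) * x for x, t in zip(X, y)) / len(X)
        self.w -= self.lr * grad

    def fit(self, X: List[float], y: List[float], epochs: int = 1) -> "ToyRegressor":
        self.w = 0.0  # reset: everything learned so far is discarded
        for _ in range(epochs):
            self._step(X, y)
        return self

    def partial_fit(self, X: List[float], y: List[float]) -> "ToyRegressor":
        self._step(X, y)  # no reset: continue from the current weight
        return self
```

Calling partial_fit() twice on the same batch moves `w` twice (0.0 to 0.4 to 0.72 with the defaults and X=[1.0], y=[2.0]), while a subsequent fit() throws that progress away and lands back at 0.4.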

Let’s begin.

Download dataset

Download a sample dataset (here, the Pistachio Image Dataset) and store it in data/. Your directory structure should look like this:

data/Pistachio_Image_Dataset
├── Kirmizi_Pistachio/*.jpg
└── Siirt_Pistachio/*.jpg
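flow_from_directory() treats each subdirectory as one class and, unless the classes are given explicitly, orders the class names alphanumerically, so Kirmizi_Pistachio should map to label 0 and Siirt_Pistachio to label 1. A minimal stdlib sketch of that mapping follows; `list_labeled_files` is a hypothetical helper, and unlike Keras it only looks one directory level deep. After creating the real loader, `image_generator.class_indices` shows the mapping Keras actually chose.

```python
import os
from typing import Dict, List, Tuple


def list_labeled_files(
    root: str, extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png")
) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Return (path, label) pairs and a class-name -> label mapping.

    Each immediate subdirectory of `root` is one class; names are sorted,
    so the alphabetically first class gets label 0.
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_indices = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Skip non-image files such as notes or metadata.
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(class_dir, fname), class_indices[name]))
    return samples, class_indices
```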

Import necessary libraries

from math import ceil
import tensorflow as tf
from matplotlib import pyplot as plt
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Set up constants

DATA_DIR = "data/Pistachio_Image_Dataset"
BATCH_SIZE = 32
EPOCHS = 10

Loader for reading data in batches

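# Stream images from disk: rescale pixels to [0, 1], resize to 32x32,
# and infer binary labels from the two class subdirectories.
image_generator = ImageDataGenerator(rescale=1.0 / 255).flow_from_directory(
    DATA_DIR,
    target_size=(32, 32),
    batch_size=BATCH_SIZE,
    class_mode="binary",
)
total_images = len(image_generator.filenames)
# True division so ceil() rounds the final partial batch up
# (equivalent to len(image_generator)).
total_batches = ceil(total_images / BATCH_SIZE)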
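Computing the number of batches needs true division before `ceil()`. With floor division `//` the quotient is already rounded down, so `ceil()` has nothing left to round and the final partial batch is not counted, which shifts the epoch boundaries in the training loop. The numbers below are hypothetical.

```python
from math import ceil

total_images, batch_size = 100, 32  # hypothetical numbers

wrong = ceil(total_images // batch_size)  # 100 // 32 == 3, ceil(3) == 3
right = ceil(total_images / batch_size)   # 100 / 32 == 3.125, ceil -> 4
same = -(-total_images // batch_size)     # integer-only ceiling division -> 4

print(wrong, right, same)  # prints "3 4 4"
```

For a Keras directory iterator, `len(image_generator)` should return the same ceiling value.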

Define your TensorFlow model architecture

We'll use a shallow network for demonstration purposes.

model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)
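As a sanity check on how small this network is, its parameter count can be worked out by hand: Flatten() has no weights and turns each 32x32x3 image into 3,072 values, and the single-unit Dense layer adds one weight per input plus one bias. `dense_params` is just an illustrative helper.

```python
def dense_params(in_features: int, units: int) -> int:
    """Weights plus biases of a fully connected (Dense) layer."""
    return in_features * units + units


flattened = 32 * 32 * 3            # Flatten(): 32x32 RGB image -> 3072 values
print(dense_params(flattened, 1))  # Dense(1): 3072 weights + 1 bias -> 3073
```

`model.summary()` should report the same total of 3,073 trainable parameters.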

Define the SciKeras interface

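# Wrap the Keras model in a scikit-learn-style estimator; the compile
# settings (optimizer, loss, metrics) are passed to the wrapper.
sk_clf = KerasClassifier(
    model=model,
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)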
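With a sigmoid output, `loss="binary_crossentropy"` measures how far the predicted probability p is from the 0/1 label y, averaging -(y log p + (1 - y) log(1 - p)) over the batch. The sketch below computes that by hand in pure Python; the clipping constant `eps` is an assumption added here to avoid log(0), not a claim about Keras internals.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def binary_crossentropy(y_true: Sequence[float], y_prob: Sequence[float],
                        eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


# An untrained-looking model: logit 0, i.e. probability 0.5 for every image.
probs = [sigmoid(0.0)] * 4
print(round(binary_crossentropy([0, 1, 1, 0], probs), 3))  # prints 0.693
```

A model that cannot yet separate the classes (p = 0.5 everywhere) has a loss of log 2, about 0.693, which is a useful reference point when reading the first losses printed during training.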

Training loop

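batch = 0
epoch = 0
histories = []
# flow_from_directory() yields batches forever, so we count batches
# ourselves and stop after EPOCHS full passes over the data.
for X, y in image_generator:
    # One incremental update on this batch; weights carry over between calls.
    sk_clf.partial_fit(X, y, verbose=False)
    history = sk_clf.model_.history.history
    histories.append(history)
    batch += 1
    if batch == total_batches:  # end of one pass over the dataset
        batch = 0
        epoch += 1
        # Loss and accuracy of the last batch in this epoch
        print(
            f"epoch {epoch}/{EPOCHS}, loss {history['loss'][0]:.4f} "
            f"accuracy {history['accuracy'][0]:.4f}"
        )
    if epoch == EPOCHS:
        break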
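The same training logic can also be written as two nested loops, which makes the epoch boundaries explicit and removes the manual counters. This is a sketch under stated assumptions: `train_epochs` and `train_on_batch` are hypothetical names, and the loop reports the mean loss over all batches of an epoch instead of the last batch's loss. Because `islice()` pulls exactly `total_batches` items at a time, `batches` must be an iterator (one that continues where it left off) rather than a re-iterable collection.

```python
from itertools import islice
from statistics import mean
from typing import Callable, Iterator, List, Tuple


def train_epochs(
    batches: Iterator[Tuple[list, list]],
    train_on_batch: Callable[[list, list], float],
    total_batches: int,
    epochs: int,
) -> List[float]:
    """Run `epochs` passes over an endless batch iterator.

    `train_on_batch` performs one update and returns that batch's loss.
    Returns the mean batch loss of each epoch.
    """
    epoch_losses = []
    for epoch in range(1, epochs + 1):
        # Take exactly one epoch's worth of batches from the endless stream.
        losses = [train_on_batch(X, y) for X, y in islice(batches, total_batches)]
        epoch_losses.append(mean(losses))
        print(f"epoch {epoch}/{epochs}, mean batch loss {epoch_losses[-1]:.4f}")
    return epoch_losses
```

With the objects defined above, `batches` would be `image_generator`, and `train_on_batch` could be a small function that calls `sk_clf.partial_fit(X, y, verbose=False)` and returns `sk_clf.model_.history.history["loss"][0]`.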