I want to train an image classification model using the SciKeras interface.
SciKeras is a wrapper that lets us use TensorFlow/Keras models with scikit-learn.
I have a dataset of images on my disk which I would like to use for training.
TensorFlow has `ImageDataGenerator()`, which loads images in batches and applies augmentation (rotation, flipping, zoom, scaling) in real time.
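To make "load in batches and augment in real time" concrete without depending on TensorFlow, here is a toy, pure-Python sketch: a generator that loads only one batch of items at a time and randomly flips each image horizontally as it is yielded. The names `augmented_batches` and the fake loader are hypothetical stand-ins, not part of the Keras API, and the "images" are small nested lists rather than real JPEGs.

```python
import random
from typing import Callable, Iterator, List, Sequence

Image = List[List[int]]  # toy grayscale image: rows of pixel values


def augmented_batches(
    paths: Sequence[str],
    load_image: Callable[[str], Image],
    batch_size: int,
    flip_prob: float = 0.5,
    seed: int = 0,
) -> Iterator[List[Image]]:
    """Yield batches lazily, loading and augmenting one batch at a time."""
    rng = random.Random(seed)
    for start in range(0, len(paths), batch_size):
        batch = []
        for path in paths[start:start + batch_size]:
            img = load_image(path)  # only this batch is ever in memory
            if rng.random() < flip_prob:
                img = [row[::-1] for row in img]  # random horizontal flip
            batch.append(img)
        yield batch


# Usage with a fake loader that "reads" a 2x3 image for each path.
fake_disk = {f"img{i}.jpg": [[i, i + 1, i + 2], [i + 3, i + 4, i + 5]] for i in range(5)}
batches = list(augmented_batches(sorted(fake_disk), fake_disk.__getitem__, batch_size=2))
print([len(b) for b in batches])  # [2, 2, 1]
```

Unlike this sketch, which stops after one pass, the Keras iterator keeps cycling through the data, which matters for the training loop later on.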
SciKeras exposes `KerasClassifier()`, a scikit-learn-compatible API for model training.
In scikit-learn, `fit()` expects the entire dataset to be loaded in memory.
If our image dataset is huge, our RAM will not be able to hold all of it.
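A quick back-of-the-envelope calculation shows why this becomes a problem. Assuming images are decoded to `float32` arrays (4 bytes per value) with 3 color channels, the hypothetical helper below (not a library call) estimates the RAM needed to hold a set of images at once. The image counts are illustrative, not taken from the dataset used here.

```python
def dataset_bytes(n_images: int, height: int, width: int,
                  channels: int = 3, bytes_per_value: int = 4) -> int:
    """Bytes needed to hold n_images decoded images in one array."""
    return n_images * height * width * channels * bytes_per_value


# 100,000 RGB images at 224x224, stored as float32
big = dataset_bytes(100_000, 224, 224)
print(f"{big / 1e9:.1f} GB")  # 60.2 GB

# A single batch of 32 RGB images at 32x32, stored as float32
one_batch = dataset_bytes(32, 32, 32)
print(f"{one_batch / 1e3:.1f} KB")  # 393.2 KB
```

Streaming batches keeps memory use proportional to the batch size rather than to the size of the whole dataset.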
So, we would like to call `partial_fit()` on smaller batches of images generated by `ImageDataGenerator`.
`partial_fit()` also continues training from the model’s current weights, whereas `fit()` by default re-initializes the model every time it’s called.
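The difference between the two methods is easiest to see with a toy model. The sketch below is a hypothetical pure-Python `ToyEstimator` (not SciKeras code) that learns the mean of the targets it has seen: `fit()` resets its state before training, mirroring the default re-initialization, while `partial_fit()` keeps the state and continues from it.

```python
class ToyEstimator:
    """Hypothetical estimator that learns the mean of the targets it has seen."""

    def __init__(self):
        self._reset()

    def _reset(self):
        self.n_seen_ = 0
        self.mean_ = 0.0

    def partial_fit(self, y):
        # Continue from the current state with an incremental mean update.
        for value in y:
            self.n_seen_ += 1
            self.mean_ += (value - self.mean_) / self.n_seen_
        return self

    def fit(self, y):
        # Start from scratch, then train on this data only.
        self._reset()
        return self.partial_fit(y)


batches = [[1.0, 2.0], [3.0, 4.0]]

incremental = ToyEstimator()
for b in batches:
    incremental.partial_fit(b)
print(incremental.n_seen_, incremental.mean_)  # 4 2.5

refit = ToyEstimator()
for b in batches:
    refit.fit(b)
print(refit.n_seen_, refit.mean_)  # 2 3.5
```

Calling `fit()` batch by batch only remembers the last batch, which is why the training loop below uses `partial_fit()`.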
Let’s begin.
Download dataset
Download a sample dataset and store it in `data/`. Your directory structure should look like the following:
```text
data/Pistachio_Image_Dataset
├── Kirmizi_Pistachio/*.jpg
└── Siirt_Pistachio/*.jpg
```
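`flow_from_directory()` treats each sub-directory as one class and, when `classes` is not given, assigns integer labels in alphanumeric order of the directory names (the mapping is exposed as `class_indices`). The stdlib sketch below uses a hypothetical helper, `infer_class_indices`, to reproduce that mapping for the layout above in a temporary directory, so you can check which pistachio type becomes label 0 and which becomes label 1.

```python
import os
import tempfile


def infer_class_indices(data_dir: str) -> dict:
    """Map each sub-directory name to an integer label, in sorted order."""
    classes = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )
    return {name: idx for idx, name in enumerate(classes)}


# Recreate the expected layout in a temporary directory.
with tempfile.TemporaryDirectory() as root:
    data_dir = os.path.join(root, "Pistachio_Image_Dataset")
    for cls in ("Siirt_Pistachio", "Kirmizi_Pistachio"):
        os.makedirs(os.path.join(data_dir, cls))
    class_indices = infer_class_indices(data_dir)

print(class_indices)  # {'Kirmizi_Pistachio': 0, 'Siirt_Pistachio': 1}
```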
Import necessary libraries
```python
from math import ceil

import tensorflow as tf
from matplotlib import pyplot as plt
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.preprocessing.image import ImageDataGenerator
```
Set up constants
```python
DATA_DIR = "data/Pistachio_Image_Dataset"
BATCH_SIZE = 32
EPOCHS = 10
```
Loader for reading data in batches
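```python
# Rescale pixel values from [0, 255] to [0, 1] and stream batches from disk.
# Class labels are inferred from the sub-directory names.
image_generator = ImageDataGenerator(rescale=1.0 / 255).flow_from_directory(
    DATA_DIR,
    target_size=(32, 32),
    batch_size=BATCH_SIZE,
    class_mode="binary",
)

total_images = len(image_generator.filenames)
# Use true division so that a final, partial batch is still counted.
total_batches = ceil(total_images / BATCH_SIZE)
```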
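The number of batches per epoch should round *up*, so that a final, smaller batch still counts. Floor division (`//`) rounds down before `ceil()` ever sees the value, which silently drops that last batch from the count. The check below uses an illustrative count of 1,000 images, not the real dataset size. The Keras iterator's `len(image_generator)` should report the same rounded-up number of batches.

```python
from math import ceil

BATCH_SIZE = 32
total_images = 1000  # illustrative count, not the real dataset size

floor_then_ceil = ceil(total_images // BATCH_SIZE)  # 1000 // 32 = 31, ceil(31) = 31
true_ceil = ceil(total_images / BATCH_SIZE)         # 1000 / 32 = 31.25, rounded up to 32
int_only = -(-total_images // BATCH_SIZE)           # integer-only ceiling division

print(floor_then_ceil, true_ceil, int_only)  # 31 32 32
```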
Define your TensorFlow model architecture
We’ll use a shallow network for demonstration purposes.
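```python
model = tf.keras.Sequential(
    [
        # Matches target_size=(32, 32) with 3 RGB channels.
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Flatten(),
        # One sigmoid unit: probability of class 1 for binary labels.
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)
```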
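To see just how shallow this network is, count its trainable parameters: `Flatten` turns each 32×32×3 image into a 3,072-value vector and has no parameters of its own, and the single sigmoid unit of `Dense(1)` has one weight per input plus a bias. The arithmetic below (plain Python, no TensorFlow needed) should match the total that `model.summary()` reports.

```python
height, width, channels = 32, 32, 3
units = 1  # Dense(1, activation="sigmoid")

flattened = height * width * channels     # inputs to the Dense layer
dense_params = flattened * units + units  # weights + biases

print(flattened, dense_params)  # 3072 3073
```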
Define the SciKeras interface
```python
sk_clf = KerasClassifier(
    model=model,
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
```
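With `loss="binary_crossentropy"`, the single sigmoid output p is treated as the probability of class 1, and each batch's loss is the mean of −[y·log p + (1−y)·log(1−p)] over its samples. The hand-rolled sketch below is a hypothetical helper, not the Keras implementation; it clips p with a small epsilon, similar in spirit to what Keras does to avoid log(0), and computes the loss for a two-sample batch.

```python
import math
from typing import Sequence


def binary_crossentropy(y_true: Sequence[float], p_pred: Sequence[float],
                        eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over a batch of sigmoid outputs."""
    total = 0.0
    for y, p in zip(y_true, p_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(y_true)


loss = binary_crossentropy([1, 0], [0.9, 0.2])
print(round(loss, 4))  # 0.1643
```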
Training loop
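```python
batch = 0
epoch = 0
histories = []

# flow_from_directory() yields batches indefinitely, so count batches
# manually to know when an epoch is complete.
for X, y in image_generator:
    # Train on one batch, continuing from the current weights.
    sk_clf.partial_fit(X, y, verbose=False)
    history = sk_clf.model_.history.history
    histories.append(history)

    batch += 1
    if batch == total_batches:
        batch = 0
        epoch += 1
        # These metrics come from the last batch of the epoch.
        print(
            f"epoch {epoch}/{EPOCHS}, loss {history['loss'][0]} accuracy {history['accuracy'][0]}"
        )
        if epoch == EPOCHS:
            break
```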
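Because the Keras iterator cycles through the data indefinitely, the loop above needs its manual batch and epoch counters to stop. The stdlib sketch below isolates that bookkeeping: a hypothetical infinite `fake_batches()` generator stands in for `image_generator`, and a hypothetical `train_step()` stands in for `partial_fit()`. Instead of reporting only the last batch's loss, it averages the per-batch losses over each epoch, which is one possible variation on the original loop.

```python
import itertools
from typing import Iterator, List, Tuple

EPOCHS = 3
TOTAL_BATCHES = 4


def fake_batches() -> Iterator[Tuple[int, int]]:
    """Hypothetical stand-in for image_generator: cycles forever."""
    return itertools.cycle([(i, i) for i in range(TOTAL_BATCHES)])


def train_step(step: int) -> float:
    """Hypothetical stand-in for partial_fit(); returns a decreasing 'loss'."""
    return 1.0 / (step + 1)


def run() -> Tuple[int, List[float]]:
    epoch_losses: List[float] = []
    batch_losses: List[float] = []
    steps = 0
    for _X, _y in fake_batches():
        batch_losses.append(train_step(steps))
        steps += 1
        if len(batch_losses) == TOTAL_BATCHES:  # one full pass over the data
            epoch_losses.append(sum(batch_losses) / len(batch_losses))
            print(f"epoch {len(epoch_losses)}/{EPOCHS}, mean loss {epoch_losses[-1]:.4f}")
            batch_losses = []
            if len(epoch_losses) == EPOCHS:
                break
    return steps, epoch_losses


steps, epoch_losses = run()
```

In the real loop, `history['loss'][0]` from each `partial_fit()` call would take the place of `train_step()`'s return value.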