
I wish to implement early stopping with Keras and scikit-learn's GridSearchCV.

The working code example below is modified from "How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras". The data set may be downloaded from here.

The modification adds the Keras EarlyStopping callback to prevent overfitting. For this to be effective, it requires the monitor='val_acc' argument so that validation accuracy is monitored (Keras 2.3 and later, including tf.keras, log this metric as val_accuracy when the model is compiled with metrics=['accuracy']). For val_acc to be available, KerasClassifier needs validation_split=0.1 to hold out part of the training data and compute validation accuracy; otherwise EarlyStopping issues RuntimeWarning: Early stopping requires val_acc available!. Note the FIXME code comment!

Note that val_acc could equally be replaced by val_loss.
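To make the monitor/patience behaviour concrete, here is a minimal pure-Python sketch of patience-based early stopping. The helper name stopping_epoch and its signature are hypothetical, not part of Keras; it roughly mirrors the logic of recent Keras versions with min_delta=0 (older releases counted patience slightly differently). mode="max" corresponds to monitoring val_acc and mode="min" to val_loss, which is why Keras's EarlyStopping also accepts a mode argument (its default "auto" infers the direction from the metric name).

```python
from typing import Optional, Sequence


def stopping_epoch(history: Sequence[float], patience: int = 3,
                   mode: str = "max") -> Optional[int]:
    """Hypothetical helper: return the 0-based epoch at which patience-based
    early stopping would halt, or None if it never triggers.

    history  -- per-epoch values of the monitored metric (e.g. val_acc)
    patience -- number of epochs without improvement tolerated
    mode     -- "max" for metrics to maximise (accuracy), "min" for losses
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best = None
    wait = 0
    for epoch, value in enumerate(history):
        # The first epoch always counts as an improvement.
        improved = best is None or (value > best if mode == "max" else value < best)
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

For example, the validation accuracies [0.60, 0.65, 0.70, 0.69, 0.70, 0.68] with patience=3 stop at epoch 5 (0-based), because epochs 3 to 5 bring no improvement over 0.70.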

Question: How can I use the cross-validation data set generated by the GridSearchCV k-fold algorithm instead of wasting 10% of the training data on a separate early-stopping validation set?

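# Use scikit-learn to grid search the learning rate and momentum
# (Keras 2.x-era API: keras.wrappers.scikit_learn was later removed; SciKeras is its successor)
import numpy
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping

# Function to create model, required for KerasClassifier
def create_model(learn_rate=0.01, momentum=0):
    # create model: 8 inputs -> 12 ReLU units -> 1 sigmoid output
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    # Compile model
    optimizer = SGD(lr=learn_rate, momentum=momentum)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

# Early stopping: stop when validation accuracy has not improved for 3 epochs
stopper = EarlyStopping(monitor='val_acc', patience=3, verbose=1)

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

# load dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
Y = dataset[:, 8]

# create model
model = KerasClassifier(
    build_fn=create_model,
    epochs=100, batch_size=10,
    validation_split=0.1,  # FIXME: Instead use GridSearchCV k-fold validation data.
    verbose=0)

# define the grid search parameters (illustrative values, assumed)
param_grid = dict(learn_rate=[0.001, 0.01, 0.1], momentum=[0.0, 0.4, 0.9])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

# extra keyword arguments to fit() are forwarded to KerasClassifier.fit(),
# so the callback reaches model.fit() for every fold and parameter combination
grid_result = grid.fit(X, Y, callbacks=[stopper])

# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))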
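One way to get what the question asks for is to run the k folds yourself and pass each held-out fold as the early-stopping validation set; in Keras that would mean calling model.fit(..., validation_data=(X_val, y_val), callbacks=[stopper]) per fold instead of using validation_split. The sketch below shows that loop without Keras: plain NumPy logistic regression trained by full-batch gradient descent stands in for the network (an assumption for illustration), and all helper names are hypothetical. Be aware that the held-out fold then both chooses the stopping epoch and scores the model, so the resulting CV estimate is optimistically biased.

```python
import numpy as np


def kfold_indices(n_samples, k, rng):
    """Shuffle sample indices and split them into k roughly equal folds."""
    return np.array_split(rng.permutation(n_samples), k)


def log_loss(w, b, X, y):
    """Mean binary cross-entropy of a logistic model on (X, y)."""
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
    p = np.clip(p, 1e-12, 1 - 1e-12)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def train_with_fold_early_stopping(X_tr, y_tr, X_val, y_val,
                                   lr=0.1, max_epochs=500, patience=10):
    """Gradient descent on logistic regression, stopping once the held-out
    fold's loss has not improved for `patience` epochs.
    Returns (best_epoch, best_val_loss) with a 0-based epoch index."""
    w = np.zeros(X_tr.shape[1])
    b = 0.0
    best_loss, best_epoch, wait = np.inf, 0, 0
    for epoch in range(max_epochs):
        p = 1.0 / (1.0 + np.exp(-(X_tr @ w + b)))
        grad = p - y_tr
        w -= lr * (X_tr.T @ grad) / len(y_tr)
        b -= lr * float(np.mean(grad))
        val_loss = log_loss(w, b, X_val, y_val)  # the CV fold acts as validation set
        if val_loss < best_loss:
            best_loss, best_epoch, wait = val_loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                break
    return best_epoch, best_loss


def cross_validate_with_early_stopping(X, y, k=5, seed=7, **train_kwargs):
    """Run k-fold CV; each fold's held-out part drives early stopping."""
    rng = np.random.default_rng(seed)
    folds = kfold_indices(len(y), k, rng)
    results = []
    for i, val_idx in enumerate(folds):
        tr_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        results.append(train_with_fold_early_stopping(
            X[tr_idx], y[tr_idx], X[val_idx], y[val_idx], **train_kwargs))
    return results


if __name__ == "__main__":
    # Synthetic binary data (assumed) in place of the Pima CSV.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 8))
    y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=200) > 0).astype(float)
    for epoch, loss in cross_validate_with_early_stopping(X, y, k=5):
        print("best epoch %d, validation loss %.3f" % (epoch, loss))
```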


1 Answer


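Cross-validation is implicitly based on an "all else being equal" argument: every candidate configuration is trained and scored under the same conditions, so differences in score can be attributed to the hyperparameters being compared. If you feel that the number of epochs should be a hyperparameter, include it explicitly in your CV grid as such, rather than inserting it through the back door of early stopping, which can compromise the whole process: the validation fold would then be used both to decide when to stop training and to score the model, making the CV estimate optimistically biased.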
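Following this suggestion, the number of epochs becomes one more grid dimension, and every candidate is trained for a fixed number of epochs and scored on the same folds. With KerasClassifier this would amount to adding epochs to param_grid (for example dict(epochs=[10, 50, 100], ...)), since epochs is already passed to KerasClassifier in the question's code. The sketch below shows the same idea without Keras, again assuming NumPy logistic regression as a stand-in model and using hypothetical helper names.

```python
import numpy as np


def train_logreg(X, y, epochs, lr=0.1):
    """Full-batch gradient descent on logistic regression for a fixed epoch count."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        grad = p - y
        w -= lr * (X.T @ grad) / len(y)
        b -= lr * float(np.mean(grad))
    return w, b


def accuracy(w, b, X, y):
    """Fraction of samples whose predicted class (logit > 0) matches y."""
    return float(np.mean(((X @ w + b) > 0) == (y == 1)))


def grid_search_epochs(X, y, epoch_grid, k=5, seed=7):
    """Score every candidate epoch count on the same k folds.
    Returns (best_epochs, {epochs: mean validation accuracy})."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), k)
    scores = {}
    for epochs in epoch_grid:
        fold_scores = []
        for i, val_idx in enumerate(folds):
            tr_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
            w, b = train_logreg(X[tr_idx], y[tr_idx], epochs)
            fold_scores.append(accuracy(w, b, X[val_idx], y[val_idx]))
        scores[epochs] = float(np.mean(fold_scores))
    best = max(scores, key=scores.get)
    return best, scores


if __name__ == "__main__":
    # Synthetic binary data (assumed) in place of the Pima CSV.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 8))
    y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=200) > 0).astype(float)
    best, scores = grid_search_epochs(X, y, [1, 10, 100])
    print("best epochs:", best, scores)
```

Because the stopping point is now an ordinary hyperparameter, each fold trains every candidate for exactly the same number of epochs, which keeps the comparison "all else being equal".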

More details about this can be found in the TensorFlow tutorials.

