# 📅  Last modified: 2022-03-11 14:46:25.397000             🧑  Author: Mango
# Random-search cross-validation of a neural-network model
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.wrappers.scikit_learn import KerasClassifier
def report(results, n_top=3):
    """Print the top-ranked candidates of a hyper-parameter search.

    For every rank from 1 to ``n_top`` (inclusive), each candidate holding
    that rank is printed with the following, then a blank line:

    - its mean validation score;
    - the standard deviation of that score;
    - its parameter dict.

    Several candidates may share a rank, so more than ``n_top`` models can be
    printed.

    Args:
        results: Mapping with the keys ``'rank_test_score'``,
            ``'mean_test_score'``, ``'std_test_score'`` and ``'params'``. Each
            value is indexable by candidate position (as in a search's
            ``cv_results_``). Plain lists are accepted as well as NumPy
            arrays.
        n_top: Highest rank to report.
    """
    # Imported locally so this helper does not depend on a module-level
    # ``import numpy`` that the script never performs.
    import numpy

    # asarray makes ``ranks == rank`` an element-wise comparison even when the
    # caller passes a plain list; ``[1, 2] == 1`` would be a single False and
    # silently report nothing.
    ranks = numpy.asarray(results['rank_test_score'])
    for rank in range(1, n_top + 1):
        for candidate in numpy.flatnonzero(ranks == rank):
            print("Model with rank: {0}".format(rank))
            print("Mean validation score: {0:.3f} (std: {1:.3f})"
                  .format(results['mean_test_score'][candidate],
                          results['std_test_score'][candidate]))
            print("Parameters: {0}".format(results['params'][candidate]))
            print("")
def nn_model(activation='relu', neurons=32, optimizer='Adam', dropout=0.1,
             init_mode='uniform', input_dim=32):
    """Build and compile a small fully connected binary classifier.

    Architecture:

    - three hidden ``Dense`` layers with widths ``neurons``,
      ``(neurons * 2) // 3`` and ``(neurons * 4) // 9``;
    - a ``Dropout`` layer;
    - a single sigmoid output unit.

    The model is compiled with ``binary_crossentropy`` loss and an
    ``'accuracy'`` metric.

    Args:
        activation: Activation name used by the three hidden layers.
        neurons: Width of the first hidden layer; the next two widths are
            derived from it.
        optimizer: Optimizer passed to ``model.compile``.
        dropout: Rate of the ``Dropout`` layer placed before the output layer.
        init_mode: ``kernel_initializer`` used for every ``Dense`` layer.
        input_dim: Number of input features. Defaults to 32, the value that
            was previously hard-coded.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()
    # The first layer previously used a hard-coded width of 32 and ignored
    # ``neurons``. Searching over ``neurons`` then never changed it, even
    # though the later layers are sized relative to ``neurons``.
    model.add(Dense(neurons, input_dim=input_dim,
                    kernel_initializer=init_mode, activation=activation))
    model.add(Dense((neurons * 2) // 3,
                    kernel_initializer=init_mode, activation=activation))
    model.add(Dense((neurons * 4) // 9,
                    kernel_initializer=init_mode, activation=activation))
    model.add(Dropout(dropout))
    model.add(Dense(1, kernel_initializer=init_mode, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])
    return model
# Hyper-parameter search space. Every key of param_grid except batch_size
# matches a keyword argument of nn_model.
activation = [
    'softmax', 'softplus', 'softsign', 'relu', 'selu',
    'elu', 'tanh', 'sigmoid', 'linear',
]
neurons = range(31, 39)
dropout = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
init_mode = [
    'uniform', 'lecun_uniform', 'normal', 'zero',
    'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform',
]
optimizer = [
    'SGD', 'Adam', 'Adamax', 'RMSprop',
    'Adagrad', 'Adadelta', 'Nadam', 'Ftrl',
]
batch_size = range(10, 101, 10)  # 10, 20, ..., 100

# 9 * 8 * 6 * 8 * 8 * 10 = 276,480 combinations. n_iter is not passed below,
# so only the search's default number of candidates is sampled.
param_grid = {
    'activation': activation,
    'neurons': neurons,
    'optimizer': optimizer,
    'dropout': dropout,
    'init_mode': init_mode,
    'batch_size': batch_size,
}

# Wrap nn_model as a scikit-learn-style classifier trained for 10 epochs.
clf = KerasClassifier(build_fn=nn_model, epochs=10, verbose=1)

# NOTE(review): RandomizedSearchCV, stan_x_train and stan_y_train are not
# defined in the visible part of this file; presumably imported or created
# elsewhere. Confirm before running.
model = RandomizedSearchCV(
    estimator=clf,
    param_distributions=param_grid,
    n_jobs=-1,
    verbose=3,
)
model.fit(stan_x_train, stan_y_train)
report(model.cv_results_)