Изменения в библиотекте для ЛР2:

- Добавлен callback для частичного отображения эпох
- Добавлены kwargs для частичного отображения эпох и ранней остановки по дельте и значению метрики
- Для EarlyStoppingCallback добавлена зависимость от verbose_show
main
Lipisin 3 дней назад
Родитель 339fe963d0
Сommit 047d249b1f

@ -29,12 +29,14 @@ from pandas import DataFrame
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.callbacks import Callback
visual = True
verbose_show = False
# generate 2d classification dataset
def datagen(x_c, y_c, n_samples, n_features):
@ -91,8 +93,27 @@ class EarlyStoppingOnValue(tensorflow.keras.callbacks.Callback):
)
return monitor_value
class VerboseEveryNEpochs(Callback):
def __init__(self, every_n_epochs=1000, verbose=1):
super().__init__()
self.every_n_epochs = every_n_epochs
self.verbose = verbose
def on_epoch_end(self, epoch, logs=None):
if (epoch + 1) % self.every_n_epochs == 0:
if self.verbose:
print(f"\nEpoch {epoch + 1}/{self.params['epochs']}")
if logs:
log_str = ", ".join([f"{k}: {v:.4f}" for k, v in logs.items()])
print(f" - {log_str}")
#создание и обучение модели автокодировщика
def create_fit_save_ae(cl_train, ae_file, irefile, epohs, verbose_show, patience):
def create_fit_save_ae(cl_train, ae_file, irefile, epohs, verbose_show, patience, **kwargs):
verbose_every_n_epochs = kwargs.get('verbose_every_n_epochs', 1000)
early_stopping_delta = kwargs.get('early_stopping_delta', 0.01)
early_stopping_value = kwargs.get('early_stopping_value', 0.0001)
size = cl_train.shape[1]
#ans = '2'
@ -140,22 +161,28 @@ def create_fit_save_ae(cl_train, ae_file, irefile, epohs, verbose_show, patience
optimizer = tensorflow.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False)
ae.compile(loss='mean_squared_error', optimizer=optimizer)
error_stop = 0.0001
epo = epohs
early_stopping_callback_on_error = EarlyStoppingOnValue(monitor='loss', baseline=error_stop)
verbose = 1 if verbose_show else 0
early_stopping_callback_on_error = EarlyStoppingOnValue(monitor='loss', baseline=early_stopping_value)
early_stopping_callback_on_improving = tensorflow.keras.callbacks.EarlyStopping(monitor='loss',
min_delta=0.0001, patience = patience,
verbose=1, mode='auto',
min_delta=early_stopping_delta, patience = patience,
verbose=verbose, mode='min',
baseline=None,
restore_best_weights=False)
restore_best_weights=True)
history_callback = tensorflow.keras.callbacks.History()
verbose = 1 if verbose_show else 0
history_object = ae.fit(cl_train, cl_train,
batch_size=cl_train.shape[0],
epochs=epo,
callbacks=[early_stopping_callback_on_error, history_callback,
early_stopping_callback_on_improving],
callbacks=[
early_stopping_callback_on_error,
history_callback,
early_stopping_callback_on_improving,
VerboseEveryNEpochs(every_n_epochs=verbose_every_n_epochs),
],
verbose=verbose)
ae_trainned = ae
ae_pred = ae_trainned.predict(cl_train)
@ -538,4 +565,4 @@ def ire_plot(title, IRE_test, IREth, ae_name):
plt.gcf().savefig('out/IRE_' + title + ae_name + '.png')
plt.show()
return
return

Загрузка…
Отмена
Сохранить