Tensorflow

텐서플로우 콜백 함수(tensorflow callback)

카카오그래놀라 2020. 5. 16. 14:33

https://keras.io/api/callbacks/

 

Keras documentation: Callbacks API

Callbacks API A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc). You can use callbacks to: Write TensorBoard logs after every batch of training to moni

keras.io

 

콜백함수가 필요한 이유: 모델이 학습을 시작하면 학습이 완료될 때까지 사람이 할 수 있는게 없습니다. 따라서 이를 해결하고자 존재하는 것이 콜백함수입니다. 예를 들어, 학습되는 과정 사이에 학습률을 변화시키거나 val_loss가 개선되지 않으면 학습을 멈추게 하는 등의 작업을 할 수 있습니다.

 

from tensorflow.keras import Sequential, Input
from tensorflow.keras.layers import Dense, Flatten
from tf.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger


model= Sequential()
model.add(Input(shape=(250, 250, 3))) # 250x250 RGB images
model.add(layers.Conv2D(32, 5, strides=2, activation="relu"))
model.add(layers.Conv2D(32, 3, activation="relu"))
model.add(Flatten())
model.add(Dense(10), activation="softmax")

# 콜백 함수
es = EarlyStopping(patience=20)
mc = ModelCheckpoint("your_path/file_name.h5", save_best_only=True)
rlr = ReduceLROnPlateau(factor=0.1, patience=5)
csvlogger = CSVLogger("your_path/file_name.log")

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

# fit 안에 위의 콜백 함수를 넣어주면 됩니다.
model.fit(x_train, y_train, epochs=20, batch_size=128, callbacks=[es, mc, rlr, csvlogger])

 

 

자세한 설명은 아래 링크를 참조해주세요!

EarlyStopping

 

tensorflow 콜백함수: EarlyStopping

tensorflow, 케라스 콜백함수 ModelCheckpoint 모델을 저장할 때 사용되는 콜백함수입니다. www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping tf.keras.callbacks.EarlyStopping( monitor='..

deep-deep-deep.tistory.com

ModelCheckpoint

 

Tensorflow 콜백함수: ModelCheckpoint

Tensorflow 콜백함수: ModelCheckpoint tf.keras.callbacks.ModelCheckpoint( filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', save_freq='epoch', op..

deep-deep-deep.tistory.com

ReduceLROnPlateau

 

Tensorflow 콜백함수: ReduceLROnPlateau

Tensorflow, 케라스 콜백함수 ReduceLROnPlateau 모델의 개선이 없을 경우, Learning Rate를 조절해 모델의 개선을 유도하는 콜백함수입니다. www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROn..

deep-deep-deep.tistory.com