Tensorflow

tensorflow 콜백함수: ModelCheckpoint

카카오그래놀라 2020. 11. 17. 20:05

tf.keras.callbacks.ModelCheckpoint

tensorflow, 케라스 콜백함수
ModelCheckpoint

모델을 저장할 때 사용되는 콜백함수입니다.

www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', options=None, **kwargs
)

 

인자 설명

인자 설명
filepath 모델을 저장할 경로를 입력합니다.
추가 설명으로 만약 monitor가 val_loss일 때,
모델 경로를 '{epoch:02d}-{val_loss:.5f}.h5' 라고 입력하면, 에폭-해당에폭에서의 val_loss.h5로 모델이 저장됩니다. 예: 01-0.39121.h5 로 저장됩니다.

monitor 모델을 저장할 때, 기준이 되는 값을 지정합니다.
예를 들어, validation set의 loss가 가장 작을 때 저장하고 싶으면 'val_loss'를 입력하고
만약 train set의 loss가 가장 작을 때 모델을 저장하고 싶으면 'loss'를 입력합니다.
이 외에도 다양한 값들을 기준으로 삼을 수 있습니다.

verbose 0, 1

1일 경우 모델이 저장 될 때, '저장되었습니다' 라고 화면에 표시되고,
0일 경우 화면에 표시되는 것 없이 그냥 바로 모델이 저장됩니다.

save_best_only True, False

True 인 경우, monitor 되고 있는 값을 기준으로 가장 좋은 값으로 모델이 저장됩니다.
False인 경우, 매 에폭마다 모델이 filepath{epoch}으로 저장됩니다. (model0, model1, model2....)

save_weights_only True, False

True인 경우, 모델의 weights만 저장됩니다.
False인 경우, 모델 레이어 및 weights 모두 저장됩니다.

mode 'auto', 'min', 'max'

val_acc 인 경우, 정확도이기 때문에 클수록 좋습니다. 따라서 이때는 max를 입력해줘야합니다.
만약 val_loss 인 경우, loss 값이기 때문에 값이 작을수록 좋습니다. 따라서 이때는 min을 입력해줘야합니다.
auto로 할 경우, 모델이 알아서 min, max를 판단하여 모델을 저장합니다.

save_freq 'epoch' 또는 integer(정수형 숫자)

'epoch'을 사용할 경우, 매 에폭마다 모델이 저장됩니다.
integer을 사용할 경우, 숫자만큼의 배치를 진행되면 모델이 저장됩니다.
예를 들어 숫자 8을 입력하면, 8번째 배치가 train 된 이후, 16번째 배치가 train 된 이후 ..... 모델이 저장됩니다.

options tf.train.CheckpointOptions를 옵션으로 줄 수 있습니다. 분산환경에서 다른 디렉토리에 모델을 저장하고 싶을 경우 사용합니다. 자세한 내용은 아래 링크를 참조해주세요.
www.tensorflow.org/api_docs/python/tf/train/CheckpointOptions

 

 

간단한 사용 예시는 아래 링크를 참조해주세요!

텐서플로우 콜백 함수(tensorflow callback)

 

텐서플로우 콜백 함수(tensorflow callback)

https://keras.io/api/callbacks/ Keras documentation: Callbacks API Callbacks API A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch..

deep-deep-deep.tistory.com