原文链接:https://*.com/questions/59027150/keras-training-freezes-during-fit-generator
一般来说,我们可以使用Keras
包中的fit
函数进行模型的训练,其参数如下:
Model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose="auto",
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
其中x
和validation_data
可以是事先加载好的数据组成的tuple
——(inputs, targets)
,也可以是根据Keras相关API构建的Data Generator(包括ImageDataGenerator
、keras.utils.Sequence
等),在训练的过程中,这些数据会按设定好的batch_size被喂给模型,从而完成train
和evaluate
。
当我们在使用generator向模型中输入数据的时候,在部分高版本的Keras(>2.0.0)
中可能会出现第一个epoch训练结束,但是evalute过程不结束,表现为第一个epoch卡住的情况。
根据相关资料和笔者自身经验,强烈建议在调用fit
函数时,显式地指出step_per_epoch
和validation_steps
的值,从而解决epoch卡住无法结束的问题