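In an earlier article, 《使用Softmax进行分类代码实现》 ("Implementing Classification with Softmax"), we classified images with softmax regression. This time we use DJL's high-level API to implement a multilayer perceptron (MLP) more concisely.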
Preparing the dataset
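```java
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;

private static RandomAccessDataset getDataset(Dataset.Usage usage)
        throws IOException {
    Mnist mnist =
            Mnist.builder()
                    .optUsage(usage)
                    // mini-batches of 32 samples, shuffled
                    .setSampling(32, true)
                    // keep only the first 64 samples (a quick smoke test);
                    // remove this line to train on the full dataset
                    .optLimit(64)
                    .build();
    // download the data if needed and load it, showing a progress bar
    mnist.prepare(new ProgressBar());
    return mnist;
}

// training set
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
// validation set (the MNIST test split)
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
```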
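`setSampling(32, true)` tells DJL to serve the data in shuffled mini-batches of 32. The plain-Java sketch below imitates that idea. The class `ShuffledBatches` and its method are hypothetical and not part of the DJL API. It shuffles the sample indices, then cuts them into consecutive batches, and the last batch may be smaller. With `optLimit(64)`, that works out to 64 / 32 = 2 batches per epoch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of shuffled mini-batch sampling (not DJL code).
public class ShuffledBatches {

    /** Shuffles indices 0..size-1 and splits them into batches of at most batchSize. */
    public static List<List<Integer>> batches(int size, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        // random order, in the spirit of setSampling(batchSize, true)
        Collections.shuffle(indices, new Random(seed));

        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < size; start += batchSize) {
            int end = Math.min(start + batchSize, size);
            // the last batch is smaller when size is not a multiple of batchSize
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> epoch = batches(64, 32, 42L);
        System.out.println("batches per epoch: " + epoch.size()); // prints "batches per epoch: 2"
    }
}
```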
Defining the model
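```java
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.nn.Block;

// 784 input pixels -> hidden layers of 128 and 64 units (ReLU) -> 10 classes
Block block = new Mlp(
        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
        Mnist.NUM_CLASSES,
        new int[] {128, 64});
```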
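`Mlp` stacks fully connected layers, 784 → 128 → 64 → 10, with ReLU after each hidden layer (the default activation for this constructor). A fully connected layer with `in` inputs and `out` outputs has `in * out` weights plus `out` biases. For this model that gives 100,480 + 8,256 + 650 = 109,386 trainable parameters. The helper class `MlpSketch` below is hypothetical and not DJL code. It performs that count and shows what a single layer computes, y = ReLU(Wx + b).

```java
// Hypothetical helpers illustrating what the Mlp block computes (not DJL code).
public class MlpSketch {

    /** Trainable parameters: weights plus biases of every fully connected layer. */
    public static long parameterCount(int input, int[] hidden, int output) {
        long total = 0;
        int in = input;
        for (int units : hidden) {
            total += (long) in * units + units; // weights + biases of a hidden layer
            in = units;
        }
        total += (long) in * output + output; // output layer (no activation)
        return total;
    }

    /** One fully connected layer y = W x + b, optionally followed by ReLU. */
    public static double[] dense(double[][] w, double[] b, double[] x, boolean relu) {
        double[] y = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            y[i] = relu ? Math.max(0.0, sum) : sum;
        }
        return y;
    }

    public static void main(String[] args) {
        // 28 * 28 = 784 inputs, hidden layers {128, 64}, 10 classes
        System.out.println(parameterCount(784, new int[] {128, 64}, 10)); // prints 109386
        // tiny example: 2 inputs -> 2 units with ReLU
        double[] y = dense(new double[][] {{1, -1}, {-2, 0}}, new double[] {0, 1},
                new double[] {3, 1}, true);
        System.out.println(y[0] + " " + y[1]); // prints "2.0 0.0"
    }
}
```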
Setting up the trainer
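```java
import ai.djl.Model;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;

private static DefaultTrainingConfig setupTrainingConfig() {
    String outputDir = "build/model";
    // saves the model to outputDir during training
    SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
    // when the model is saved, record the validation accuracy and loss as model properties
    listener.setSaveModelCallback(
            trainer -> {
                TrainingResult result = trainer.getTrainingResult();
                Model model = trainer.getModel();
                float accuracy = result.getValidateEvaluation("Accuracy");
                model.setProperty("Accuracy", String.format("%.5f", accuracy));
                model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
            });
    // softmax cross-entropy loss, accuracy metric, default logging plus the save listener
    return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .addEvaluator(new Accuracy())
            .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
            .addTrainingListeners(listener);
}
```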
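`Loss.softmaxCrossEntropyLoss()` turns the 10 raw output scores into probabilities with softmax and takes the negative log-probability of the true class. The `Accuracy` evaluator counts how often the highest-scoring class matches the label. The sketch below spells out both calculations in plain Java. The class `LossSketch` is hypothetical and does not reproduce DJL's implementation. It subtracts the maximum score before exponentiating, a standard trick for numerical stability.

```java
// Hypothetical illustration of softmax cross-entropy and accuracy (not DJL's implementation).
public class LossSketch {

    /** -log(softmax(logits)[label]), computed in a numerically stable way. */
    public static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        // log softmax of the true class = (z_label - max) - log(sum of exp(z - max))
        return -((logits[label] - max) - Math.log(sumExp));
    }

    /** Index of the largest score, i.e. the predicted class. */
    public static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of samples whose predicted class equals the label. */
    public static double accuracy(double[][] batchLogits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(batchLogits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        // equal scores for 2 classes -> probability 0.5 -> loss = ln 2
        System.out.println(softmaxCrossEntropy(new double[] {0.0, 0.0}, 1));
        double[][] logits = {{2.0, 0.1}, {0.3, 1.5}, {0.9, 0.2}};
        System.out.println(accuracy(logits, new int[] {0, 1, 1})); // prints 0.6666666666666666
    }
}
```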
Training and saving the model
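```java
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import java.nio.file.Paths;

// runs inside a method that returns TrainingResult
DefaultTrainingConfig config = setupTrainingConfig();
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    try (Trainer trainer = model.newTrainer(config)) {
        trainer.setMetrics(new Metrics());
        // input shape: batch size 1, 28 * 28 = 784 pixels
        Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
        trainer.initialize(inputShape);
        // train for 15 epochs, evaluating on the validation set
        EasyTrain.fit(trainer, 15, trainingSet, validateSet);
        // save the model
        model.save(Paths.get("build/model"), "mlp");
        return trainer.getTrainingResult();
    }
}
```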
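`EasyTrain.fit` hides the usual mini-batch loop. In each epoch it runs every training batch through a forward pass, the loss, a backward pass and a parameter update, and then it evaluates on the validation set. The sketch below shows only that control flow. The `Step` interface and the `fit` method are hypothetical stand-ins, not DJL API. With `optLimit(64)` and a batch size of 32, 15 epochs amount to 15 × 2 = 30 parameter updates.

```java
// Hypothetical sketch of the control flow behind a fit(...) call (not DJL code).
public class TrainLoopSketch {

    /** Stand-in for the real work done on one batch or one validation pass. */
    public interface Step {
        void run(int epoch, int batch);
    }

    /** Runs numEpoch epochs and returns the total number of parameter updates. */
    public static int fit(int numEpoch, int batchesPerEpoch, Step trainBatch, Step validate) {
        int updates = 0;
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (int batch = 0; batch < batchesPerEpoch; batch++) {
                // real trainer: forward pass, loss, backward pass, then optimizer step
                trainBatch.run(epoch, batch);
                updates++;
            }
            // real trainer: evaluate loss and accuracy on the validation set
            validate.run(epoch, -1);
        }
        return updates;
    }

    public static void main(String[] args) {
        int updates = fit(15, 2, (e, b) -> { },
                (e, b) -> System.out.println("epoch " + (e + 1) + " done"));
        System.out.println("parameter updates: " + updates); // prints "parameter updates: 30"
    }
}
```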
Summary
After 15 epochs of training, the trained model is saved to the build/model directory. In a follow-up article, we will use this model to classify images.