DJL: Hands-On Deep Learning for Java Developers - Image Classification

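In an earlier article, "Implementing Classification with Softmax in Code" (《使用Softmax进行分类代码实现》), we classified images with softmax regression. Today we use DJL's high-level API to implement a multilayer perceptron (MLP) more concisely.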

Preparing the dataset

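private static RandomAccessDataset getDataset(Dataset.Usage usage)
    throws IOException {
    Mnist mnist =
        Mnist.builder()
        .optUsage(usage)        // which split to load: TRAIN or TEST
        .setSampling(32, true)  // batch size 32, random sampling
        .optLimit(64)           // cap the dataset at 64 samples to keep the demo fast
        .build();
    // Download (if needed) and prepare the data, reporting progress on the console
    mnist.prepare(new ProgressBar());
    return mnist;
}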

// Training set
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);

// Validation set
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
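
To make the effect of setSampling(32, true) and optLimit(64) concrete, here is a minimal plain-Java sketch of random batch sampling. It is not DJL's internal sampler: the class name BatchSamplingSketch, the batches helper, and the fixed seed are hypothetical and exist only for illustration. It shows that 64 samples with a batch size of 32 yield two batches per epoch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class BatchSamplingSketch {

    // Splits the indices 0..limit-1 into batches of at most batchSize,
    // shuffling them first when random is true.
    // Illustration only: DJL's own sampler may behave differently.
    static List<List<Integer>> batches(int limit, int batchSize, boolean random, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < limit; i++) {
            indices.add(i);
        }
        if (random) {
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < limit; start += batchSize) {
            int end = Math.min(start + batchSize, limit);
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        // Same numbers as the dataset above: 64 samples, batch size 32
        List<List<Integer>> b = batches(64, 32, true, 42L);
        System.out.println("batches per epoch: " + b.size());
        System.out.println("first batch size: " + b.get(0).size());
    }
}
```

With such a small limit each epoch is very short, which suits a quick demo. Raising or removing the limit would train on more of MNIST.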
        

Defining the model

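// MLP with a flattened image as input, one output per class,
// and two hidden layers of 128 and 64 units
Block block = new Mlp(
    Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
    Mnist.NUM_CLASSES,
    new int[] {128, 64});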
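
As a rough sanity check on the size of this network, the sketch below counts the trainable parameters of a fully connected network with the same layer widths. It assumes 28 x 28 MNIST images, 10 classes, and that every layer is a plain linear layer with a bias term. The class MlpShapeSketch and the method parameterCount are hypothetical names and not part of DJL.

```java
public class MlpShapeSketch {

    // Counts weights and biases of a fully connected network
    // with the given input size, hidden layer widths, and output size.
    static long parameterCount(int input, int[] hidden, int output) {
        long total = 0;
        int previous = input;
        for (int width : hidden) {
            total += (long) previous * width + width; // weights + biases
            previous = width;
        }
        total += (long) previous * output + output;   // output layer
        return total;
    }

    public static void main(String[] args) {
        // Assumed values: 28 x 28 input pixels, hidden layers {128, 64}, 10 classes
        long count = parameterCount(28 * 28, new int[] {128, 64}, 10);
        System.out.println(count); // prints 109386
    }
}
```

Under these assumptions the three layers contribute 100480, 8256, and 650 parameters respectively.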

Initializing the trainer

private static DefaultTrainingConfig setupTrainingConfig() {
    String outputDir = "build/model";
    // Listener that saves the model to outputDir
    SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
    // Before saving, record the validation accuracy and loss as model properties
    listener.setSaveModelCallback(
        trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model = trainer.getModel();
            float accuracy = result.getValidateEvaluation("Accuracy");
            model.setProperty("Accuracy", String.format("%.5f", accuracy));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });

    // Softmax cross-entropy loss, accuracy as the evaluation metric,
    // default logging listeners plus the save listener defined above
    return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        .addEvaluator(new Accuracy())
        .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
        .addTrainingListeners(listener);
}
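
To show what the loss and the evaluator in this configuration measure, here is a minimal plain-Java sketch of softmax cross-entropy for a single sample and of accuracy over a small batch. It is an illustration under simplifying assumptions, not DJL's implementation: SoftmaxLossSketch and its methods are hypothetical names, and the sketch works on raw double arrays rather than NDArray.

```java
public class SoftmaxLossSketch {

    // Cross-entropy of softmax(logits) against the true class index:
    // -log(softmax(logits)[label]) = log(sum_j exp(logits[j])) - logits[label].
    // Subtracting the maximum logit first keeps exp() from overflowing.
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sumExp = 0.0;
        for (double v : logits) {
            sumExp += Math.exp(v - max);
        }
        return Math.log(sumExp) + max - logits[label];
    }

    // Index of the largest value, i.e. the predicted class.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of samples whose predicted class equals the label.
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(logits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 1.0}, {0.0, 3.0}, {5.0, 1.0}};
        int[] labels = {0, 1, 1};
        System.out.println("loss of first sample: " + softmaxCrossEntropy(logits[0], labels[0]));
        System.out.println("accuracy: " + accuracy(logits, labels));
    }
}
```

In the example batch, two of the three predictions match their labels, so the accuracy is 2/3.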

Training and saving the model

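// Build the training configuration defined above
DefaultTrainingConfig config = setupTrainingConfig();

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    try (Trainer trainer = model.newTrainer(config)) {
        trainer.setMetrics(new Metrics());
        // Input shape: a batch of one flattened 28 x 28 image
        Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
        trainer.initialize(inputShape);
        // Train for 15 epochs, validating on validateSet
        EasyTrain.fit(trainer, 15, trainingSet, validateSet);

        // Save the model
        model.save(Paths.get("build/model"), "mlp");
        return trainer.getTrainingResult();
    }
}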

Summary

After 15 epochs of training, the trained model is written to the build/model directory. In a follow-up article we will use this trained model to run image classification predictions.
