图标识别
import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils.js';
// Base URL of the local static server that hosts the model files.
const MODEL_PATH = 'http://127.0.0.1:8080';
// Labels indexed by the classifier's output unit.
// NOTE(review): this order must match the label order used when logo-model was trained — confirm.
const BRAND_CLASSES = ['android', 'apple', 'windows'];

/*
 * On DOM ready: load MobileNet, truncate it to a feature extractor, load the
 * logo classifier on top of it, and expose two page-level globals:
 *   window.predict(file) - classify an uploaded image and alert the label
 *   window.download()    - save the logo classifier via the browser
 * Both are only defined after the models finish loading.
 */
$(async () => {
  // Load the external pretrained MobileNet.
  const mobilenet = await tf.loadLayersModel(`${MODEL_PATH}/mobilenet/web_model/model.json`);
  // Print every layer to the console (useful for choosing the truncation layer).
  mobilenet.summary();

  // Cut MobileNet at 'conv_pw_13_relu' so it emits that layer's activations
  // instead of the original classification output.
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });

  // Load the logo classifier that consumes the truncated MobileNet's output.
  const model = await tf.loadLayersModel(`${MODEL_PATH}/mobilenet/logo-model.json`);

  window.predict = async (file) => {
    // An empty selection (e.g. a cancelled file dialog) yields undefined.
    if (!file) {
      return;
    }
    // Turn the uploaded file into a 224x224 <img> and show it on the page.
    const img = await file2img(file);
    document.body.appendChild(img);

    // tidy() frees every intermediate tensor; only the argMax result escapes.
    const classIndex = tf.tidy(() => {
      const x = img2x(img);
      // Feed the image through the truncated MobileNet, then the classifier.
      const features = truncatedMobilenet.predict(x);
      return model.predict(features).argMax(1);
    });

    let index;
    try {
      // Async read avoids blocking the UI thread (unlike dataSync()).
      [index] = await classIndex.data();
    } finally {
      // Release the last tensor so repeated predictions do not leak memory.
      classIndex.dispose();
    }

    // Deferred to a later task, presumably so the appended image can render
    // before the blocking alert() appears.
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };

  // The page's "save model" button calls download(); without this definition
  // clicking it throws a ReferenceError. Saves via the browser downloads handler.
  window.download = async () => {
    await model.save('downloads://logo-model');
  };
});
HTML 部分
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document</title>
</head>
<body>
    <div>图标识别</div>
    <!-- predict() and download() are globals assigned by t6.js once its models have loaded. -->
    <input type="file" onchange="predict(this.files[0])">
    <button onclick="download()">保存模型</button>
    <!-- Scripts belong inside <body>; content after </body> is invalid markup. -->
    <script src="./t6.js"></script>
</body>
</html>
utils.js
import * as tf from '@tensorflow/tfjs';
//转换图片格式tensor↓↓↓↓↓↓↓↓↓↓↓
/**
 * Convert an image source into a batched, normalised tensor of shape [1, 224, 224, 3].
 *
 * Pixels are read with tf.browser.fromPixels (3 channels), cast to float32,
 * resized to 224x224 and mapped from [0, 255] to [-1, 1].
 *
 * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|ImageData} imgEl - image to convert
 * @returns {tf.Tensor4D} the input tensor; the caller owns it and must dispose it
 *   (or call this inside tf.tidy).
 */
export function img2x(imgEl) {
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imgEl).cast('float32');
    // fromPixels uses the element's own width/height, so a hard reshape to
    // 224x224 would throw for any other size. Bilinear resize handles any size
    // and is an identity when the image is already 224x224.
    const resized = tf.image.resizeBilinear(pixels, [224, 224]);
    const halfRange = 255 / 2;
    // Map [0, 255] -> [-1, 1], then add the leading batch dimension.
    return resized.sub(halfRange).div(halfRange).expandDims(0);
  });
}
/**
 * Read a user-supplied file into an <img> element sized 224x224.
 *
 * @param {File|Blob} f - the file chosen by the user
 * @returns {Promise<HTMLImageElement>} resolves once the image has loaded;
 *   rejects if the file cannot be read or is not a decodable image
 *   (without the error handlers the promise would never settle).
 */
export function file2img(f) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = (e) => {
      const img = document.createElement('img');
      // Attach handlers before assigning src so no load/error event is missed.
      img.onload = () => resolve(img);
      img.onerror = () => reject(new Error('file2img: file is not a decodable image'));
      img.width = 224;
      img.height = 224;
      img.src = e.target.result;
    };
    reader.onerror = () => reject(reader.error ?? new Error('file2img: failed to read file'));
    reader.readAsDataURL(f);
  });
}
//转换图片格式tensor↑↑↑↑↑↑↑↑↑↑↑
执行结果