TensorFlow.js 基本使用：截断模型、引入外部模型（七）

图标识别

import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils.js';

// Base URL that the model.json files below are fetched from.
const MODEL_PATH = 'http://127.0.0.1:8080';

// Human-readable labels, looked up by the argMax index of the classifier output.
// NOTE(review): order presumably must match the label order used when the
// logo model was trained — confirm.
const BRAND_CLASSES = [
  'android',
  'apple',
  'windows',
];

$(async () => {
  // Load the pre-trained MobileNet and the logo classifier head. The two
  // downloads are independent, so fetch them in parallel.
  const [mobilenet, model] = await Promise.all([
    tf.loadLayersModel(`${MODEL_PATH}/mobilenet/web_model/model.json`),
    tf.loadLayersModel(`${MODEL_PATH}/mobilenet/logo-model.json`),
  ]);

  // Print every layer of MobileNet (useful for choosing the truncation point).
  mobilenet.summary();

  // Truncate MobileNet at 'conv_pw_13_relu': the new model shares MobileNet's
  // inputs but outputs that layer's activations, which are fed to the head.
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });

  /**
   * Classify an uploaded image and alert the predicted brand.
   * Called from the page's file input (`onchange="predict(this.files[0])"`).
   *
   * @param {File|undefined} file - the selected file; ignored when absent.
   */
  window.predict = async (file) => {
    // onchange can fire with no file (e.g. the selection was cleared).
    if (!file) {
      return;
    }

    const img = await file2img(file);
    document.body.appendChild(img);

    // Everything runs inside tidy and only a plain number escapes. The input
    // tensor, the MobileNet features, the logits and the argMax result are
    // therefore all released instead of leaking on every prediction.
    const index = tf.tidy(() => {
      const x = img2x(img);
      // Stage 1: truncated MobileNet turns the image into feature activations.
      const features = truncatedMobilenet.predict(x);
      // Stage 2: the logo head classifies those features.
      const pred = model.predict(features);
      return pred.argMax(1).dataSync()[0];
    });

    // Defer the blocking alert to a later task, presumably so the appended
    // image can render before the dialog appears.
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };

  /**
   * Save the logo classifier through the browser's download mechanism.
   * Called by the page's "保存模型" button (`onclick="download()"`), which
   * otherwise referenced an undefined function.
   */
  window.download = async () => {
    await model.save('downloads://logo-model');
  };
});

HTML 部分

<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
</head>
<body>
  <div>图标识别</div>
  <!-- predict() is expected to be attached to window by the script below. -->
  <input type="file" accept="image/*" onchange="predict(this.files[0])">
  <button onclick="download()">保存模型</button>
  <!-- Scripts belong inside <body>; content after </body> is a parse error. -->
  <script src="./t6.js"></script>
</body>
</html>

utils.js

import * as tf from '@tensorflow/tfjs';

//转换图片格式tensor↓↓↓↓↓↓↓↓↓↓↓
/**
 * Convert an image element into a normalized single-image input batch.
 *
 * Pixels are read as 3-channel RGB and cast to float32. They are resized to
 * `size`×`size` when the source has a different height/width, scaled from
 * [0, 255] to [-1, 1], and given a leading batch dimension of 1.
 *
 * @param {HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|ImageData} imgEl - source image.
 * @param {number} [size=224] - target height and width in pixels.
 * @returns {tf.Tensor4D} tensor of shape [1, size, size, 3]. The caller owns
 *   it and must dispose it (or call this inside tf.tidy).
 */
export function img2x(imgEl, size = 224) {
  return tf.tidy(() => {
    let pixels = tf.browser.fromPixels(imgEl).cast('float32');

    // The models take a fixed size×size input. Resize anything else instead
    // of failing in a hard-coded reshape.
    const [height, width] = pixels.shape;
    if (height !== size || width !== size) {
      pixels = tf.image.resizeBilinear(pixels, [size, size]);
    }

    return pixels
      .sub(255 / 2)
      .div(255 / 2)
      .expandDims(0);
  });
}

/**
 * Read a user-selected file and resolve with an <img> element displaying it.
 * The element's width and height are set to 224×224 pixels.
 *
 * @param {Blob} f - the file to read (e.g. `input.files[0]`).
 * @returns {Promise<HTMLImageElement>} resolves once the image has loaded.
 *   Rejects if the file cannot be read or cannot be decoded as an image,
 *   instead of leaving the promise pending forever.
 */
export function file2img(f) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();

    reader.onload = (e) => {
      const img = document.createElement('img');
      // Attach handlers before assigning src so no load/error event is missed.
      img.onload = () => resolve(img);
      img.onerror = () => reject(new Error('file2img: the selected file could not be decoded as an image'));
      img.width = 224;
      img.height = 224;
      img.src = e.target.result;
    };
    reader.onerror = () => reject(reader.error || new Error('file2img: failed to read the selected file'));

    reader.readAsDataURL(f);
  });
}
//转换图片格式tensor↑↑↑↑↑↑↑↑↑↑↑

执行结果

tensorflow.js基本使用 截断模型、引入外部模型(七) 

 

上一篇:Business Model - Organizations And Organizations Theory


下一篇:pytorch训练和验证