准备工作
首先需要安装最新版的 PyTorch,本文使用的版本是 1.3.0。
其次需要安装 Android Studio 进行 Android 开发。
模型格式转换
为了能够在 Android 上使用我们的深度学习模型,需要将其转换为 TorchScript 格式。这个过程非常简单。下面的代码将预训练的 MobileNetV2 模型转换为 TorchScript 格式:
1
2
3
4
5
6
7
8
9
10
|
# Export a pretrained MobileNetV2 to a TorchScript file.
import torch
from torchvision.models import mobilenet_v2

# Load the pretrained network and put it in eval mode before tracing.
net = mobilenet_v2(pretrained=True)
net.eval()

# Trace the network with one random example input of shape (1, 3, 224, 224)
# and write the traced module to disk.
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(net, example)
traced.save("mobilenet-v2.pt")
|
上述代码会将转换好的模型存为文件 “mobilenet-v2.pt”。
1. 创建 Android 项目,添加 PyTorch Mobile
首先用 Android Studio 创建一个名为 PytorchAndroid 的项目,然后打开 app 模块的 build.gradle 文件,在 dependencies 中添加 PyTorch Mobile 和 TorchVision Mobile:
1
2
|
// PyTorch Mobile runtime and its TorchVision helper library.
// Strings use plain ASCII quotes; typographic quotes are not valid Groovy string delimiters.
implementation 'org.pytorch:pytorch_android:1.3.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
|
文件示例如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
// Module-level build script for the PytorchAndroid application.
// All string literals use plain ASCII quotes; typographic quotes are not valid Groovy.
apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    defaultConfig {
        applicationId "com.johnolafenwa.pytorchandroid"
        minSdkVersion 21
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
}

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    // PyTorch Mobile runtime and the TorchVision bitmap-to-tensor helpers.
    implementation 'org.pytorch:pytorch_android:1.3.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
    implementation 'com.android.support:design:28.0.0'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
}
然后 Android Studio 会提醒进行同步,点击 “Sync Now” 会自动下载所需要的依赖包。
2. 将模型放到 assets 文件夹
按照下列步骤创建 assets 文件夹:New -> Folder -> Assets Folder。然后将 “mobilenet-v2.pt” 文件放到这个 assets 文件夹内。
3. 添加 ImageNet 标签
在 app 包内,创建名为 “Constants.java” 的 Java 文件,将这个文件里的内容复制粘贴进去。该文件需要定义后文 Classifier 用到的 Constants.IMAGENET_CLASSES 字符串数组,数组下标与模型输出的类别索引一一对应。
4. 添加分类器
在 app 包内,创建名为 “Classifier.java” 的 Java 文件,放入下列代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
package com.johnolafenwa.pytorchandroid;
import android.graphics.Bitmap;
import org.pytorch.Tensor;
import org.pytorch.Module;
import org.pytorch.IValue;
import org.pytorch.torchvision.TensorImageUtils;
/**
 * Loads a TorchScript module and maps a bitmap to one of the labels in
 * {@code Constants.IMAGENET_CLASSES}.
 */
public class Classifier {
    /** Module loaded from the path passed to the constructor. */
    Module model;
    /** Per-channel mean handed to the bitmap-to-tensor conversion. */
    float[] mean = {0.485f, 0.456f, 0.406f};
    /** Per-channel standard deviation handed to the bitmap-to-tensor conversion. */
    float[] std = {0.229f, 0.224f, 0.225f};

    /**
     * @param modelPath file-system path of the TorchScript model to load
     */
    public Classifier(String modelPath) {
        model = Module.load(modelPath);
    }

    /**
     * Replaces the normalisation constants used by {@link #preprocess}.
     * The arrays are stored by reference, not copied.
     */
    public void setMeanAndStd(float[] mean, float[] std) {
        this.mean = mean;
        this.std = std;
    }

    /**
     * Scales the bitmap to {@code size x size} (without filtering) and converts it
     * to a float tensor normalised with the current mean and std.
     */
    public Tensor preprocess(Bitmap bitmap, int size) {
        bitmap = Bitmap.createScaledBitmap(bitmap, size, size, false);
        return TensorImageUtils.bitmapToFloat32Tensor(bitmap, this.mean, this.std);
    }

    /**
     * Returns the index of the largest value in {@code inputs}; on ties the first
     * occurrence wins.
     *
     * <p>The running maximum starts at negative infinity rather than zero, so an
     * array whose values are all zero or negative still yields a valid index.
     *
     * @return the index of the maximum, or -1 if the array is empty or contains no
     *     value greater than negative infinity (for example only NaN entries)
     */
    public int argMax(float[] inputs) {
        int maxIndex = -1;
        float maxValue = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i] > maxValue) {
                maxIndex = i;
                maxValue = inputs[i];
            }
        }
        return maxIndex;
    }

    /**
     * Runs the model on the bitmap and returns the label of the highest-scoring class.
     */
    public String predict(Bitmap bitmap) {
        // The input is always resized to 224 x 224 before inference.
        Tensor tensor = preprocess(bitmap, 224);
        IValue inputs = IValue.from(tensor);
        Tensor outputs = model.forward(inputs).toTensor();
        float[] scores = outputs.getDataAsFloatArray();
        int classIndex = argMax(scores);
        return Constants.IMAGENET_CLASSES[classIndex];
    }
}
|
这个是我们整个项目的核心文件。其中 preprocess 函数接收一张 bitmap 图像,然后调整大小,做标准化处理,再把处理后的图像返回为 Tensor 格式以备模型使用。argMax 函数返回最大值所在的 index。predict 函数接收一张 bitmap 图像,将其处理为 Tensor,放入模型得到预测结果。
5. 添加工具辅助类
创建文件 “Utils.java” 然后放入下列代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
package com.johnolafenwa.pytorchandroid;
import android.content.Context;
import android.util.Log;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
public class Utils {
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e("pytorchandroid", "Error process asset " + assetName + " to file path");
}
return null;
}
}
|
6. 添加 Main Activity
创建文件 “MainActivity.java” 然后放入下列代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
package com.johnolafenwa.pytorchandroid;
import android.content.Intent;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import java.io.File;
/**
 * Entry screen: starts the camera capture intent and forwards the returned
 * thumbnail together with its predicted label to {@link Result}.
 */
public class MainActivity extends AppCompatActivity {
    /** Request code used to recognise the camera result in onActivityResult. */
    int cameraRequestCode = 1;
    Classifier classifier;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Toolbar toolbar = findViewById(R.id.toolbar);
        setSupportActionBar(toolbar);

        // The model is copied out of the assets so it can be loaded from a file path.
        classifier = new Classifier(Utils.assetFilePath(this, "mobilenet-v2.pt"));

        Button capture = findViewById(R.id.capture);
        capture.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                Intent cameraIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
                startActivityForResult(cameraIntent, cameraRequestCode);
            }
        });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        // Let the framework dispatch the result (e.g. to fragments) as well.
        super.onActivityResult(requestCode, resultCode, data);
        if (requestCode != cameraRequestCode || resultCode != RESULT_OK) {
            return;
        }

        // The intent, its extras or the "data" entry may be missing; bail out instead of crashing.
        Bundle extras = (data == null) ? null : data.getExtras();
        Object thumbnail = (extras == null) ? null : extras.get("data");
        if (!(thumbnail instanceof Bitmap)) {
            Log.e("pytorchandroid", "Camera result did not contain a bitmap");
            return;
        }

        Bitmap imageBitmap = (Bitmap) thumbnail;
        String pred = classifier.predict(imageBitmap);

        Intent resultView = new Intent(this, Result.class);
        resultView.putExtra("imagedata", extras);
        resultView.putExtra("pred", pred);
        startActivity(resultView);
    }
}
|
文件 “activity_main.xml” 应该长这样:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
<?xml version="1.0" encoding="utf-8"?>
<!-- Layout set by MainActivity: an app bar containing the toolbar, followed by the included content_main layout. -->
<android.support.design.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity"
>
<!-- App bar wrapping the toolbar that MainActivity looks up as R.id.toolbar. -->
<android.support.design.widget.AppBarLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:theme="@style/AppTheme.AppBarOverlay">
<android.support.v7.widget.Toolbar
android:id="@+id/toolbar"
android:layout_width="match_parent"
android:layout_height="?attr/actionBarSize"
android:background="?attr/colorPrimary"
app:popupTheme="@style/AppTheme.PopupOverlay" />
</android.support.design.widget.AppBarLayout>
<!-- The screen body is defined in content_main.xml. -->
<include layout="@layout/content_main" />
</android.support.design.widget.CoordinatorLayout>
|
添加文件 “content_main.xml”:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
<?xml version="1.0" encoding="utf-8"?>
<!-- Body of the main screen: a single button constrained to all four edges of the parent. -->
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    app:layout_behavior="@string/appbar_scrolling_view_behavior"
    tools:context=".MainActivity"
    tools:showIn="@layout/activity_main">

    <!-- Looked up by MainActivity as R.id.capture. The text size is given in sp
         (not dp) so that it follows the user's font-scale setting. -->
    <Button
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:id="@+id/capture"
        android:text="Take A Picture"
        android:textColor="#ffffff"
        android:textSize="26sp"
        android:background="#83D5C4"
        android:padding="5dp"
        android:fontFamily="cursive"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toEndOf="parent" />
</android.support.constraint.ConstraintLayout>
|
上面代码主要做的是,点击按钮后,调用外部摄像头拍摄,得到 bitmap 图像后调用分类器得到预测结果。
7. 添加 Result Activity
创建一个 Basic Activity 文件 “Result.java” 然后放入下列代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
package com.johnolafenwa.pytorchandroid;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar;
import android.widget.ImageView;
import android.widget.TextView;
/**
 * Displays the bitmap and the predicted label that were passed in through the
 * launching intent's "imagedata" and "pred" extras.
 */
public class Result extends AppCompatActivity {
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_result);
        setSupportActionBar((Toolbar) findViewById(R.id.toolbar));

        // Unpack what the launching activity sent along.
        Bundle imageData = getIntent().getBundleExtra("imagedata");
        Bitmap photo = (Bitmap) imageData.get("data");
        String label = getIntent().getStringExtra("pred");

        // Show the picture, then its predicted label.
        ImageView photoView = findViewById(R.id.image);
        photoView.setImageBitmap(photo);
        TextView labelView = findViewById(R.id.label);
        labelView.setText(label);
    }
}
|
文件 “activity_result.xml” 如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
<?xml version="1.0" encoding="utf-8"?>
<!-- Layout set by the Result activity: an app bar containing the toolbar, followed by the included content_result layout. -->
<android.support.design.widget.CoordinatorLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".Result">
<!-- App bar wrapping the toolbar that Result looks up as R.id.toolbar. -->
<android.support.design.widget.AppBarLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:theme="@style/AppTheme.AppBarOverlay">
<android.support.v7.widget.Toolbar
android:id="@+id/toolbar"
android:layout_width="match_parent"
android:layout_height="?attr/actionBarSize"
android:background="?attr/colorPrimary"
app:popupTheme="@style/AppTheme.PopupOverlay" />
</android.support.design.widget.AppBarLayout>
<!-- The screen body is defined in content_result.xml. -->
<include layout="@layout/content_result" />
</android.support.design.widget.CoordinatorLayout>
|
文件 “content_result.xml” 如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
<?xml version="1.0" encoding="utf-8"?>
<!-- Body of the result screen: the image with the predicted label underneath it. -->
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
app:layout_behavior="@string/appbar_scrolling_view_behavior"
tools:context=".Result"
tools:showIn="@layout/activity_result">
<!-- Result replaces the src drawable with the bitmap received from the launching intent. -->
<ImageView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:adjustViewBounds="true"
android:src="@drawable/ic_launcher_background"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintBottom_toBottomOf="parent"
android:id="@+id/image"
/>
<!-- Result overwrites the text with the predicted label; the view is constrained below the image. -->
<!-- NOTE(review): textSize is given in pt; Android guidance prefers sp for text. Confirm the intended size before changing. -->
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Hello World"
android:id="@+id/label"
android:textSize="16pt"
app:layout_constraintStart_toStartOf="@id/image"
app:layout_constraintEnd_toEndOf="@id/image"
app:layout_constraintTop_toBottomOf="@id/image"
/>
</android.support.constraint.ConstraintLayout>
|
然后差不多就完成了!
编译项目
接下来就是编译和运行自己的 Android 应用了,如下图: