/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tvm.android.androidcamerademo;
import android.annotation.SuppressLint;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.media.Image;
import android.os.AsyncTask;
import android.os.Bundle;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ArrayAdapter;
import android.widget.ListView;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.appcompat.widget.AppCompatTextView;
import androidx.camera.core.CameraSelector;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.lifecycle.ProcessCameraProvider;
import androidx.camera.view.PreviewView;
import androidx.core.content.ContextCompat;
import androidx.fragment.app.Fragment;
import androidx.renderscript.Allocation;
import androidx.renderscript.Element;
import androidx.renderscript.RenderScript;
import androidx.renderscript.Script;
import androidx.renderscript.Type;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.NDArray;
import org.apache.tvm.TVMContext;
import org.apache.tvm.TVMType;
import org.apache.tvm.TVMValue;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
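/**
* Fragment that streams camera frames through CameraX, converts them from YUV to RGB
* with RenderScript, and classifies them with an ahead-of-time compiled TVM graph
* runtime module. The top-5 ImageNet labels and per-stage timings are shown on screen.
*/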
public class Camera2BasicFragment extends Fragment {
private static final String TAG = Camera2BasicFragment.class.getSimpleName();
// TVM constants
private static final int OUTPUT_INDEX = 0;
private static final int IMG_CHANNEL = 3;
private static final boolean EXE_GPU = false;
private static final int MODEL_INPUT_SIZE = 224;
private static final String MODEL_CL_LIB_FILE = "deploy_lib_opencl.so";
private static final String MODEL_CPU_LIB_FILE = "deploy_lib_cpu.so";
private static final String MODEL_GRAPH_FILE = "deploy_graph.json";
private static final String MODEL_PARAM_FILE = "deploy_param.params";
private static final String MODEL_LABEL_FILE = "image_net_labels.json";
private static final String MODELS = "models";
private static String INPUT_NAME = "input_1";
private static String[] models;
private static String mCurModel = "";
private final float[] mCHW = new float[MODEL_INPUT_SIZE * MODEL_INPUT_SIZE * IMG_CHANNEL];
private final float[] mCHW2 = new float[MODEL_INPUT_SIZE * MODEL_INPUT_SIZE * IMG_CHANNEL];
private final Semaphore isProcessingDone = new Semaphore(1);
private final ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
3,
3,
1,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>()
);
// RenderScript is created per-fragment here just for the demo; in production, create it once in onCreate() and reuse it.
private RenderScript rs;
private ScriptC_yuv420888 mYuv420;
private boolean mRunClassifier = false;
private AppCompatTextView mResultView;
private AppCompatTextView mInfoView;
private ListView mModelView;
private AssetManager assetManager;
private Module graphRuntimeModule;
private JSONObject labels;
private ListenableFuture<ProcessCameraProvider> cameraProviderFuture;
private PreviewView previewView;
private ImageAnalysis imageAnalysis;
static Camera2BasicFragment newInstance() {
return new Camera2BasicFragment();
}
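/**
* Build a matrix that maps a source frame onto a destination frame, optionally
* rotating by a multiple of 90 degrees. With maintainAspectRatio the larger scale
* factor is used, so the destination is filled completely and part of the source
* may fall outside it.
*
* @param srcWidth source frame width in pixels.
* @param srcHeight source frame height in pixels.
* @param dstWidth destination frame width in pixels.
* @param dstHeight destination frame height in pixels.
* @param applyRotation rotation in degrees; expected to be a multiple of 90.
* @param maintainAspectRatio whether to preserve the source aspect ratio while scaling.
* @return Matrix transforming source coordinates into destination coordinates.
*/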
private static Matrix getTransformationMatrix(
final int srcWidth,
final int srcHeight,
final int dstWidth,
final int dstHeight,
final int applyRotation,
final boolean maintainAspectRatio) {
final Matrix matrix = new Matrix();
if (applyRotation != 0) {
if (applyRotation % 90 != 0) {
Log.w(TAG, "Rotation of %d % 90 != 0 " + applyRotation);
}
// Translate so center of image is at origin.
matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);
// Rotate around origin.
matrix.postRotate(applyRotation);
}
// Account for the already applied rotation, if any, and then determine how
// much scaling is needed for each axis.
final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;
final int inWidth = transpose ? srcHeight : srcWidth;
final int inHeight = transpose ? srcWidth : srcHeight;
// Apply scaling if necessary.
if (inWidth != dstWidth || inHeight != dstHeight) {
final float scaleFactorX = dstWidth / (float) inWidth;
final float scaleFactorY = dstHeight / (float) inHeight;
if (maintainAspectRatio) {
// Scale by minimum factor so that dst is filled completely while
// maintaining the aspect ratio. Some image may fall off the edge.
final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
matrix.postScale(scaleFactor, scaleFactor);
} else {
// Scale exactly to fill dst from src.
matrix.postScale(scaleFactorX, scaleFactorY);
}
}
if (applyRotation != 0) {
// Translate back from origin centered reference to destination frame.
matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
}
return matrix;
}
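/**
* List the model directories bundled under assets/models, or null if the asset
* folder cannot be read.
*/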
private String[] getModels() {
String[] models;
try {
models = getActivity().getAssets().list(MODELS);
} catch (IOException e) {
return null;
}
return models;
}
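/**
* Run a single forward pass on the current graph runtime module: feed the planar
* CHW float input, invoke run, then rank all 1000 output scores through a priority
* queue to extract the top-5 labels.
*
* @param chw normalized image data in CHW layout, length 3 * 224 * 224.
* @return String[5] of "score : label" entries.
*/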
@SuppressLint("DefaultLocale")
private String[] inference(float[] chw) {
NDArray inputNdArray = NDArray.empty(new long[]{1, IMG_CHANNEL, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE}, new TVMType("float32"));
inputNdArray.copyFrom(chw);
Function setInputFunc = graphRuntimeModule.getFunction("set_input");
setInputFunc.pushArg(INPUT_NAME).pushArg(inputNdArray).invoke();
// release tvm local variables
inputNdArray.release();
setInputFunc.release();
// get the function from the module (run it)
Function runFunc = graphRuntimeModule.getFunction("run");
runFunc.invoke();
// release tvm local variables
runFunc.release();
// get the function from the module (get output data)
NDArray outputNdArray = NDArray.empty(new long[]{1, 1000}, new TVMType("float32"));
Function getOutputFunc = graphRuntimeModule.getFunction("get_output");
getOutputFunc.pushArg(OUTPUT_INDEX).pushArg(outputNdArray).invoke();
float[] output = outputNdArray.asFloatArray();
// release tvm local variables
outputNdArray.release();
getOutputFunc.release();
if (null != output) {
String[] results = new String[5];
// top-5
PriorityQueue<Integer> pq = new PriorityQueue<>(1000, (Integer idx1, Integer idx2) -> output[idx1] > output[idx2] ? -1 : 1);
// display the result from extracted output data
for (int j = 0; j < output.length; ++j) {
pq.add(j);
}
for (int l = 0; l < 5; l++) {
//noinspection ConstantConditions
int idx = pq.poll();
if (idx < labels.length()) {
try {
results[l] = String.format("%.2f", output[idx]) + " : " + labels.getString(Integer.toString(idx));
} catch (JSONException e) {
Log.e(TAG, "index out of range", e);
}
} else {
results[l] = "???: unknown";
}
}
return results;
}
return new String[5];
}
private void updateActiveModel() {
Log.i(TAG, "updating active model...");
new LoadModelAsyncTask().execute();
}
@Override
public void onViewCreated(final View view, Bundle savedInstanceState) {
mResultView = view.findViewById(R.id.resultTextView);
mInfoView = view.findViewById(R.id.infoTextView);
mModelView = view.findViewById(R.id.modelListView);
if (assetManager == null) {
assetManager = getActivity().getAssets();
}
mModelView.setChoiceMode(ListView.CHOICE_MODE_SINGLE);
models = getModels();
ArrayAdapter<String> modelAdapter =
new ArrayAdapter<>(
getContext(), R.layout.listview_row, R.id.listview_row_text, models);
mModelView.setAdapter(modelAdapter);
mModelView.setItemChecked(0, true);
mModelView.setOnItemClickListener(
(parent, view1, position, id) -> updateActiveModel());
new LoadModelAsyncTask().execute();
}
@Override
public void onActivityCreated(Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
}
@Override
public void onDestroy() {
// release tvm local variables
if (null != graphRuntimeModule)
graphRuntimeModule.release();
super.onDestroy();
}
/**
* Read file from assets and return byte array.
*
* @param assets The asset manager to be used to load assets.
* @param fileName The path of the file to read.
* @return byte[] file content
* @throws IOException if the file cannot be completely read.
*/
private byte[] getBytesFromFile(AssetManager assets, String fileName) throws IOException {
InputStream is = assets.open(fileName);
// available() is used as the expected length; the read loop below verifies it
int length = is.available();
byte[] bytes = new byte[length];
// Read in the bytes
int offset = 0;
int numRead;
try {
while (offset < bytes.length
&& (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) {
offset += numRead;
}
} finally {
is.close();
}
// Ensure all the bytes have been read in
if (offset < bytes.length) {
throw new IOException("Could not completely read file " + fileName);
}
return bytes;
}
/**
* Get a fresh application cache path where the compiled model library can be placed.
*
* @param fileName library file name.
* @return String full path for the library file inside a fresh temporary directory
* @throws IOException if the temporary directory cannot be created.
*/
private String getTempLibFilePath(String fileName) throws IOException {
File tempDir = File.createTempFile("tvm4j_demo_", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
return (tempDir + File.separator + fileName);
}
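/**
* Convert a YUV_420_888 camera image to an ARGB_8888 Bitmap using the yuv420888
* RenderScript kernel. The Y plane is bound as a 2-D allocation (yRowStride wide),
* the U and V planes as 1-D allocations; row and pixel strides are handed to the
* script so it can address the subsampled chroma samples.
*/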
private Bitmap YUV_420_888_toRGB(Image image, int width, int height) {
// Get the three image planes
Image.Plane[] planes = image.getPlanes();
ByteBuffer buffer = planes[0].getBuffer();
byte[] y = new byte[buffer.remaining()];
buffer.get(y);
buffer = planes[1].getBuffer();
byte[] u = new byte[buffer.remaining()];
buffer.get(u);
buffer = planes[2].getBuffer();
byte[] v = new byte[buffer.remaining()];
buffer.get(v);
int yRowStride = planes[0].getRowStride();
int uvRowStride = planes[1].getRowStride();
int uvPixelStride = planes[1].getPixelStride();
// Y,U,V are defined as global allocations, the out-Allocation is the Bitmap.
// Note also that uAlloc and vAlloc are 1-dimensional while yAlloc is 2-dimensional.
Type.Builder typeUcharY = new Type.Builder(rs, Element.U8(rs));
typeUcharY.setX(yRowStride).setY(height);
Allocation yAlloc = Allocation.createTyped(rs, typeUcharY.create());
yAlloc.copyFrom(y);
mYuv420.set_ypsIn(yAlloc);
Type.Builder typeUcharUV = new Type.Builder(rs, Element.U8(rs));
// note that the u and v buffers each have size:
// ((width / 2) * pixelStride + padding) * (height / 2)
// = rowStride * (height / 2)
typeUcharUV.setX(u.length);
Allocation uAlloc = Allocation.createTyped(rs, typeUcharUV.create());
uAlloc.copyFrom(u);
mYuv420.set_uIn(uAlloc);
Allocation vAlloc = Allocation.createTyped(rs, typeUcharUV.create());
vAlloc.copyFrom(v);
mYuv420.set_vIn(vAlloc);
// handover parameters
mYuv420.set_picWidth(width);
mYuv420.set_uvRowStride(uvRowStride);
mYuv420.set_uvPixelStride(uvPixelStride);
Bitmap outBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
Allocation outAlloc = Allocation.createFromBitmap(rs, outBitmap, Allocation.MipmapControl.MIPMAP_NONE, Allocation.USAGE_SCRIPT);
Script.LaunchOptions lo = new Script.LaunchOptions();
lo.setX(0, width); // ignore the Y plane's padding zone, i.e. the columns between width and yRowStride
lo.setY(0, height);
mYuv420.forEach_doConvert(outAlloc, lo);
outAlloc.copyTo(outBitmap);
return outBitmap;
}
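/**
* Turn an analyzer frame into model input: convert YUV to RGB, scale and crop to
* 224x224, normalize each channel to [0, 1], and transpose HWC to CHW.
*
* @return float[] CHW input buffer, or null if the frame carried no image.
*/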
private float[] getFrame(ImageProxy imageProxy) {
@SuppressLint("UnsafeExperimentalUsageError")
Image image = imageProxy.getImage();
// the underlying image may be null if the frame has already been closed
if (image == null) {
return null;
}
Bitmap imageBitmap = YUV_420_888_toRGB(image, image.getWidth(), image.getHeight());
imageProxy.close();
// centre-crop the input image to the model input size
Bitmap cropImageBitmap = Bitmap.createBitmap(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, Bitmap.Config.ARGB_8888);
Matrix frameToCropTransform = getTransformationMatrix(imageBitmap.getWidth(), imageBitmap.getHeight(),
MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 0, true);
Canvas canvas = new Canvas(cropImageBitmap);
canvas.drawBitmap(imageBitmap, frameToCropTransform, null);
// image pixel int values
int[] pixelValues = new int[MODEL_INPUT_SIZE * MODEL_INPUT_SIZE];
// normalize each 0-255 ARGB pixel into float RGB values in [0, 1]
cropImageBitmap.getPixels(pixelValues, 0, MODEL_INPUT_SIZE, 0, 0, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE);
for (int j = 0; j < pixelValues.length; ++j) {
mCHW2[j * 3 + 0] = ((pixelValues[j] >> 16) & 0xFF) / 255.0f;
mCHW2[j * 3 + 1] = ((pixelValues[j] >> 8) & 0xFF) / 255.0f;
mCHW2[j * 3 + 2] = (pixelValues[j] & 0xFF) / 255.0f;
}
// transpose the interleaved HWC data into the planar CHW layout the model expects
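// Worked example of the index arithmetic below (using the 224x224 RGB sizes above):
// the green value (k = 1) of the pixel at row l = 2, column m = 5 is read from
// src_index = 1 + 3 * 5 + 3 * 224 * 2 = 1360 in the interleaved buffer and written to
// dst_index = 5 + 224 * 2 + 224 * 224 * 1 = 50629 in the planar buffer.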
for (int k = 0; k < IMG_CHANNEL; ++k) {
for (int l = 0; l < MODEL_INPUT_SIZE; ++l) {
for (int m = 0; m < MODEL_INPUT_SIZE; ++m) {
int dst_index = m + MODEL_INPUT_SIZE * l + MODEL_INPUT_SIZE * MODEL_INPUT_SIZE * k;
int src_index = k + IMG_CHANNEL * m + IMG_CHANNEL * MODEL_INPUT_SIZE * l;
mCHW[dst_index] = mCHW2[src_index];
}
}
}
return mCHW;
}
@Override
public void onCreate(@Nullable Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
cameraProviderFuture = ProcessCameraProvider.getInstance(getActivity());
}
@SuppressLint({"RestrictedApi", "UnsafeExperimentalUsageError"})
@Nullable
@Override
public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
View v = inflater.inflate(R.layout.fragment_camera2_basic, container, false);
previewView = v.findViewById(R.id.textureView);
rs = RenderScript.create(getActivity());
mYuv420 = new ScriptC_yuv420888(rs);
cameraProviderFuture.addListener(() -> {
try {
ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
bindPreview(cameraProvider);
} catch (ExecutionException | InterruptedException e) {
// No errors need to be handled for this Future. This should never be reached
}
}, ContextCompat.getMainExecutor(getActivity()));
imageAnalysis = new ImageAnalysis.Builder()
.setTargetResolution(new Size(224, 224))
.setMaxResolution(new Size(300, 300))
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
.build();
imageAnalysis.setAnalyzer(threadPoolExecutor, image -> {
Log.e(TAG, "w: " + image.getWidth() + " h: " + image.getHeight());
if (mRunClassifier && isProcessingDone.tryAcquire()) {
long t1 = SystemClock.uptimeMillis();
float[] chw = getFrame(image);
if (chw != null) {
long t2 = SystemClock.uptimeMillis();
String[] results = inference(chw);
long t3 = SystemClock.uptimeMillis();
StringBuilder msgBuilder = new StringBuilder();
for (int l = 1; l < 5; l++) {
msgBuilder.append(results[l]).append("\n");
}
String msg = msgBuilder.toString();
msg += "getFrame(): " + (t2 - t1) + "ms" + "\n";
msg += "inference(): " + (t3 - t2) + "ms" + "\n";
String finalMsg = msg;
this.getActivity().runOnUiThread(() -> {
mResultView.setText(String.format("model: %s \n %s", mCurModel, results[0]));
mInfoView.setText(finalMsg);
});
}
isProcessingDone.release();
}
image.close();
});
return v;
}
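/**
* Bind the preview and image-analysis use cases to the fragment lifecycle,
* selecting the back-facing camera.
*/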
private void bindPreview(@NonNull ProcessCameraProvider cameraProvider) {
@SuppressLint("RestrictedApi") Preview preview = new Preview.Builder()
.setMaxResolution(new Size(800, 800))
.setTargetName("Preview")
.build();
preview.setSurfaceProvider(previewView.getPreviewSurfaceProvider());
CameraSelector cameraSelector =
new CameraSelector.Builder()
.requireLensFacing(CameraSelector.LENS_FACING_BACK)
.build();
cameraProvider.bindToLifecycle(this, cameraSelector, preview, imageAnalysis);
}
@Override
public void onDestroyView() {
threadPoolExecutor.shutdownNow();
super.onDestroyView();
}
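/**
* Map a bundled model name to the input tensor name its graph expects. Only
* mobilenet_v2 and resnet18_v1 are known; any other model must set INPUT_NAME
* explicitly.
*/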
private void setInputName(String modelName) {
if (modelName.equals("mobilenet_v2")) {
INPUT_NAME = "input_1";
} else if (modelName.equals("resnet18_v1")) {
INPUT_NAME = "data";
} else {
throw new RuntimeException("Unknown model " + modelName + "; set INPUT_NAME for it explicitly here.");
}
}
/*
Load a precompiled model into the TVM graph runtime and initialize the classifier.
*/
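// Expected asset layout per model, derived from the file-name constants above:
// assets/models/<model_name>/deploy_graph.json
// assets/models/<model_name>/deploy_lib_cpu.so (or deploy_lib_opencl.so when EXE_GPU is true)
// assets/models/<model_name>/deploy_param.params
// assets/models/<model_name>/image_net_labels.json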
private class LoadModelAsyncTask extends AsyncTask<Void, Void, Integer> {
@Override
protected Integer doInBackground(Void... args) {
mRunClassifier = false;
// resolve the selected model and its expected input tensor name
int modelIndex = mModelView.getCheckedItemPosition();
setInputName(models[modelIndex]);
String model = MODELS + "/" + models[modelIndex];
String labelFilename = MODEL_LABEL_FILE;
Log.i(TAG, "Reading labels from: " + model + "/" + labelFilename);
try {
labels = new JSONObject(new String(getBytesFromFile(assetManager, model + "/" + labelFilename)));
} catch (IOException | JSONException e) {
Log.e(TAG, "Problem reading labels name file!", e);
return -1; // failure
}
// load json graph
String modelGraph;
String graphFilename = MODEL_GRAPH_FILE;
Log.i(TAG, "Reading json graph from: " + model + "/" + graphFilename);
try {
modelGraph = new String(getBytesFromFile(assetManager, model + "/" + graphFilename));
} catch (IOException e) {
Log.e(TAG, "Problem reading json graph file!", e);
return -1; // failure
}
// copy the compiled model library into the application cache folder
String libCacheFilePath;
String libFilename = EXE_GPU ? MODEL_CL_LIB_FILE : MODEL_CPU_LIB_FILE;
Log.i(TAG, "Uploading compiled function to cache folder");
try {
libCacheFilePath = getTempLibFilePath(libFilename);
byte[] modelLibByte = getBytesFromFile(assetManager, model + "/" + libFilename);
FileOutputStream fos = new FileOutputStream(libCacheFilePath);
fos.write(modelLibByte);
fos.close();
} catch (IOException e) {
Log.e(TAG, "Problem uploading compiled function!", e);
return -1; // failure
}
// load parameters
byte[] modelParams;
try {
modelParams = getBytesFromFile(assetManager, model + "/" + MODEL_PARAM_FILE);
} catch (IOException e) {
Log.e(TAG, "Problem reading params file!", e);
return -1; // failure
}
Log.i(TAG, "creating java tvm context...");
// create java tvm context
TVMContext tvmCtx = EXE_GPU ? TVMContext.opencl() : TVMContext.cpu();
Log.i(TAG, "loading compiled functions...");
Log.i(TAG, libCacheFilePath);
// tvm module for compiled functions
Module modelLib = Module.load(libCacheFilePath);
// get global function module for graph runtime
Log.i(TAG, "getting graph runtime create handle...");
Function runtimeCreFun = Function.getFunction("tvm.graph_runtime.create");
Log.i(TAG, "creating graph runtime...");
Log.i(TAG, "ctx type: " + tvmCtx.deviceType);
Log.i(TAG, "ctx id: " + tvmCtx.deviceId);
TVMValue runtimeCreFunRes = runtimeCreFun.pushArg(modelGraph)
.pushArg(modelLib)
.pushArg(tvmCtx.deviceType)
.pushArg(tvmCtx.deviceId)
.invoke();
Log.i(TAG, "as module...");
graphRuntimeModule = runtimeCreFunRes.asModule();
Log.i(TAG, "getting graph runtime load params handle...");
// get the function from the module (load parameters)
Function loadParamFunc = graphRuntimeModule.getFunction("load_params");
Log.i(TAG, "loading params...");
loadParamFunc.pushArg(modelParams).invoke();
// release tvm local variables
modelLib.release();
loadParamFunc.release();
runtimeCreFun.release();
mCurModel = model;
mRunClassifier = true;
return 0; // success
}
}
}