📌  相关文章
📜  如何使用 TensorFlow 为 Android 创建自定义模型?

📅  最后修改于: 2022-05-13 01:55:27.227000             🧑  作者: Mango

如何使用 TensorFlow 为 Android 创建自定义模型?

Tensorflow 是一个用于机器学习的开源库。在android中,我们的计算能力和资源都是有限的。因此,我们正在使用 TensorFlow light,它专门设计用于在功率有限的设备上运行。在这篇文章中,我们将看到一个名为 iris 数据集的分类示例。该数据集包含 3 个类,每个类 50 个实例,其中每个类指的是鸢尾植物的类型。

属性信息:

  1. 萼片长度厘米
  2. 萼片宽度厘米
  3. 花瓣长度厘米
  4. 花瓣宽度厘米

根据输入中给出的信息,我们将预测植物是Iris SetosaIris Versicolour还是Iris Virginica 。您可以参考此链接了解更多信息。

分步实施

第1步:

从此 ( https://archive.ics.uci.edu/ml/machine-learning-databases/iris/ ) 链接下载 iris 数据集(文件名:iris.data )。



第2步:

在 Jupyter 笔记本中创建一个名为 iris 的新Python文件。将 iris.data 文件放在 iris.ipynb 所在的同一目录中。将以下代码复制到 Jupyter 笔记本文件中。

iris.ipynb
Python
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical
 
# reading the csb into data frame
df = pd.read_csv('iris.data')
 
# specifying the columns valus into x and y variable
# iloc range based selecting 0 to 4 (4) values
X = df.iloc[:, :4].values
y = df.iloc[:, 4].values
 
# normalizing labels
le = LabelEncoder()
 
# performing fit and transform data on y
y = le.fit_transform(y)
 
y = to_categorical(y)
 
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
 
model = Sequential()
 
# input layer
# passing number neurons =64
# relu activation
# shape of neuron 4
model.add(Dense(64, activation='relu', input_shape=[4]))
 
# processing layer
# adding another denser layer of size 64
model.add(Dense(64))
 
# creating 3 output neuron
model.add(Dense(3, activation='softmax'))
 
 
# compiling model
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['acc'])
 
# training the model for fixed number of iterations (epoches)
model.fit(X, y, epochs=200)
 
from tensorflow import lite
converter = lite.TFLiteConverter.from_keras_model(model)
 
tfmodel = converter.convert()
 
open('iris.tflite', 'wb').write(tfmodel)


XML


 
    
 
        
           
            
            
 
            
 
            
 
            
 
            
            
            


Kotlin
import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.view.View
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import com.example.gfgtfdemo.ml.Iris
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.ByteBuffer
 
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
 
        // getting the object edit texts
        var ed1: EditText = findViewById(R.id.tf1);
        var ed2: EditText = findViewById(R.id.tf2);
        var ed3: EditText = findViewById(R.id.tf3);
        var ed4: EditText = findViewById(R.id.tf4);
       
        // getting the object of result textview
        var txtView: TextView = findViewById(R.id.textView);
        var b: Button = findViewById


第 3 步:

执行open('iris.tflite','wb').write(tfmodel)行后,将在 iris.data 所在的同一目录中创建一个名为iris.tflite的新文件。

A)打开 Android Studio。创建一个新的 kotlin-android 项目。 (您可以参考此处创建项目)。

B)右键单击应用程序 > 新建 > 其他 > TensorFlow Lite 模型



C)单击文件夹图标。

D)导航到 iris.tflite 文件

E)点击确定

F)单击完成后,您的模型将如下所示。 (加载可能需要一些时间)。

将代码复制粘贴到MainActivity.kt中某个按钮的点击监听器中(如下图)。

步骤 5:为预测创建 XML 布局



导航到app > res > layout > activity_main.xml并将以下代码添加到该文件中。下面是activity_main.xml文件的代码。

XML



 
    
 
        
           
            
            
 
            
 
            
 
            
 
            
            
            


步骤 6:使用MainActivity.kt 文件

转到MainActivity.kt文件并参考以下代码。下面是MainActivity.kt文件的代码。代码中添加了注释以更详细地理解代码。

科特林

import androidx.appcompat.app.AppCompatActivity
import android.os.Bundle
import android.view.View
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import com.example.gfgtfdemo.ml.Iris
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.ByteBuffer
 
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
 
        // getting the object edit texts
        var ed1: EditText = findViewById(R.id.tf1);
        var ed2: EditText = findViewById(R.id.tf2);
        var ed3: EditText = findViewById(R.id.tf3);
        var ed4: EditText = findViewById(R.id.tf4);
       
        // getting the object of result textview
        var txtView: TextView = findViewById(R.id.textView);
        var b: Button = findViewById


输出:

你可以从这里下载这个项目。