Chapter 26 Getting Started with tensorflow

Download and installation

Ubuntu

sudo apt-get install -y libpng12-dev libfreetype6 python-numpy python-scipy ipython python-matplotlib build-essential cmake pkg-config libtiff5-dev libjpeg-dev libjasper-dev libgtk2.0-dev libavcodec-dev libavformat-dev libswscale-dev libv4l-dev swig zip python-sklearn python-wheel
  1. Install pip and virtualenv

    $ sudo apt-get install python-pip python-dev python-virtualenv

2. Create a virtualenv

    $ virtualenv --system-site-packages ~/tensorflow

3. Activate the virtualenv

    $ source ~/tensorflow/bin/activate

The prompt changes to:

    (tensorflow)$

4. Install tensorflow

    (tensorflow)$ pip install --upgrade tensorflow   # for Python 2.7
    (tensorflow)$ pip3 install --upgrade tensorflow  # for Python 3.n

Upgrading to 1.2

Python 2.7 (inside the virtualenv):

(envtensorflow)gsc@X250:~/envtensorflow/lib$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp27-none-linux_x86_64.whl

Python 3 (system-wide; the cp34 wheel targets Python 3.4):

gsc@X250:~/envtensorflow/lib$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl
[sudo] password for gsc:
Downloading/unpacking https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl
  Downloading tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl (34.5MB): 34.5MB downloaded
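To confirm which tensorflow version an interpreter actually imports after installing or upgrading, a small check like the following can help. It is a minimal sketch assuming Python 3.4 or later (`importlib.util.find_spec` does not exist in Python 2.7); the function name `installed_tensorflow_version` is hypothetical. After the 1.2 upgrade above, running it with `python3` should report a 1.2.x version.

```python
import importlib
import importlib.util


def installed_tensorflow_version():
    """Return tensorflow's version string, or None if it is not importable here."""
    # find_spec returns None when no 'tensorflow' package is on the import path
    if importlib.util.find_spec("tensorflow") is None:
        return None
    tf = importlib.import_module("tensorflow")
    return tf.__version__


if __name__ == "__main__":
    version = installed_tensorflow_version()
    if version is None:
        print("tensorflow is not installed for this interpreter")
    else:
        print("tensorflow", version)
```

Running the check inside and outside the activated virtualenv shows which of the two installs each interpreter picks up.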

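KWS (keyword spotting / wake-word recognition)

Training the model:

The code can be downloaded from GitHub: kws code base

The official Google tutorial is a good starting point (English original).

python tensorflow/examples/speech_commands/train.py --wanted_words=house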

After the download completes, the run produces log output like the following:

I0730 16:53:44.766740   55030 train.py:176] Training from step: 1
I0730 16:53:47.289078   55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571
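To follow training progress programmatically, the step lines shown above can be parsed. This is a minimal sketch that assumes the exact `Step #N: rate R, accuracy A%, cross entropy X` wording in the log above; the helper `parse_step_line` is hypothetical, and lines in any other format (for example validation reports) are simply skipped.

```python
import re

# Matches per-step lines such as:
# "... train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571"
STEP_RE = re.compile(
    r"Step #(?P<step>\d+): rate (?P<rate>[\d.]+), "
    r"accuracy (?P<accuracy>[\d.]+)%, cross entropy (?P<xent>[\d.]+)"
)


def parse_step_line(line):
    """Return a dict of step metrics, or None for lines that are not step reports."""
    m = STEP_RE.search(line)
    if m is None:
        return None
    return {
        "step": int(m.group("step")),
        "rate": float(m.group("rate")),
        "accuracy": float(m.group("accuracy")),  # in percent
        "cross_entropy": float(m.group("xent")),
    }


if __name__ == "__main__":
    log = [
        "I0730 16:53:44.766740   55030 train.py:176] Training from step: 1",
        "I0730 16:53:47.289078   55030 train.py:217] Step #1: rate 0.001000, "
        "accuracy 7.0%, cross entropy 2.611571",
    ]
    # keep only the lines that are step reports
    metrics = [m for m in map(parse_step_line, log) if m is not None]
    print(metrics)
```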

Freezing the .pb file:

python tensorflow/examples/speech_commands/freeze.py --start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 --out_file=/tmp/my_frozen_graph.pb --wanted_words=house

After the .pb file is generated, move it into the assets directory of the Android app; the android directory is an Android Studio project. The files my project uses are conv_actions_house_labels.txt and my_house_frozen_graph.pb.
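The `--start_checkpoint` value above is a checkpoint prefix rather than a single file: each save typically writes several files sharing that prefix (for example `conv.ckpt-18000.meta`, `conv.ckpt-18000.index` and `conv.ckpt-18000.data-00000-of-00001`). The sketch below picks the newest complete prefix from a directory listing. It is a minimal sketch assuming that naming scheme and requiring all three file types; the function `latest_checkpoint_prefix` is hypothetical. In tensorflow 1.x, `tf.train.latest_checkpoint(checkpoint_dir)` does the same job by reading the `checkpoint` state file.

```python
import re

# "<prefix>-<step>.index", e.g. "conv.ckpt-18000.index"
INDEX_RE = re.compile(r"^(?P<prefix>.+-(?P<step>\d+))\.index$")


def latest_checkpoint_prefix(filenames):
    """Return the highest-step prefix that has .index, .meta and .data-* files, or None."""
    names = set(filenames)
    best = None  # (step, prefix)
    for name in names:
        m = INDEX_RE.match(name)
        if m is None:
            continue
        prefix, step = m.group("prefix"), int(m.group("step"))
        has_meta = prefix + ".meta" in names
        has_data = any(n.startswith(prefix + ".data-") for n in names)
        if has_meta and has_data and (best is None or step > best[0]):
            best = (step, prefix)
    return None if best is None else best[1]


if __name__ == "__main__":
    import os

    train_dir = "/tmp/speech_commands_train"
    if os.path.isdir(train_dir):
        prefix = latest_checkpoint_prefix(os.listdir(train_dir))
        if prefix is not None:
            # value to pass as --start_checkpoint
            print(os.path.join(train_dir, prefix))
```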

For details on tensorflow RNN training, see the tensorflow RNN example.

Installing the Android APK

The code base above already includes a prebuilt APK, which can be installed with the adb install command.

Android tensorflow API

  // Fields of the activity
  private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_house_labels.txt";
  private static final String MODEL_FILENAME = "file:///android_asset/my_house_frozen_graph.pb";
  private TensorFlowInferenceInterface inferenceInterface;

  // Load the TensorFlow model (done once, e.g. in onCreate()).
  inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);

  // Run the model on each recorded audio window.
  // SAMPLE_RATE_NAME, INPUT_DATA_NAME, OUTPUT_SCORES_NAME, RECORDING_LENGTH and the
  // buffers are defined elsewhere in the activity.
  inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
  inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
  inferenceInterface.run(outputScoresNames);
  inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
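The `fetch` call fills `outputScores` with one score per line of the labels file, in the same order (assuming the labels file was written by the same training run). Turning that into a recognized word can be sketched as below; `load_labels` and `top_label` are hypothetical helpers, and the example labels `_silence_`, `_unknown_`, `house` are an assumption about what the labels file of a `--wanted_words=house` model contains. The Android speech demo additionally averages scores over a time window before reporting a word; this sketch only takes the single highest score.

```python
def load_labels(text):
    """Split the labels file contents into one label per non-empty line."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def top_label(labels, scores):
    """Return (label, score) for the highest-scoring output index."""
    if len(labels) != len(scores):
        raise ValueError("expected exactly one score per label")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best], scores[best]


if __name__ == "__main__":
    # hypothetical labels file contents and one frame of output scores
    labels = load_labels("_silence_\n_unknown_\nhouse\n")
    print(top_label(labels, [0.1, 0.2, 0.7]))
```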

Model creation and loading: the TensorFlowInferenceInterface constructor

public TensorFlowInferenceInterface(AssetManager var1, String var2) {
        this.prepareNativeRuntime();
        this.modelName = var2;
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();
        boolean var3 = var2.startsWith("file:///android_asset/");
        Object var4 = null;

        try {
            String var5 = var3?var2.split("file:///android_asset/")[1]:var2;
            var4 = var1.open(var5);
        } catch (IOException var9) {
            if(var3) {
                throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var9);
            }

            try {
                var4 = new FileInputStream(var2);
            } catch (IOException var8) {
                throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var9);
            }
        }

        try {
            this.loadGraph((InputStream)var4, this.g);
            ((InputStream)var4).close();
            Log.i("TensorFlowInferenceInterface", "Successfully loaded model from \'" + var2 + "\'");
        } catch (IOException var7) {
            throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var7);
        }
    }
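The constructor above tries model sources in a fixed order. The same lookup logic can be sketched independently of Android, with `asset_exists` and `file_exists` as hypothetical stand-ins for `AssetManager.open` and `new FileInputStream` succeeding; the function name `resolve_model_source` is also hypothetical.

```python
ASSET_PREFIX = "file:///android_asset/"


def resolve_model_source(model_name, asset_exists, file_exists):
    """Mirror the constructor's lookup order.

    Returns ("asset", asset_path) or ("file", path); raises RuntimeError otherwise.
    """
    is_asset_uri = model_name.startswith(ASSET_PREFIX)
    # the prefix is stripped for asset URIs; plain names are tried as assets as-is
    asset_path = model_name[len(ASSET_PREFIX):] if is_asset_uri else model_name
    if asset_exists(asset_path):
        return ("asset", asset_path)
    if is_asset_uri:
        # an explicit asset URI never falls back to the filesystem
        raise RuntimeError("Failed to load model from '%s'" % model_name)
    if file_exists(model_name):
        return ("file", model_name)
    raise RuntimeError("Failed to load model from '%s'" % model_name)
```

In particular, a `file:///android_asset/` name never falls back to the filesystem, while a plain path is tried as an asset first and only then opened as a regular file.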

C API

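#include <stdio.h>
#include <stdlib.h>
#include <tensorflow/c/c_api.h>

TF_Buffer* read_file(const char* file);

// Deallocator stored in TF_Buffer; TF_DeleteBuffer calls it to free the file data.
void free_buffer(void* data, size_t length) {
  (void)length;  // unused, but required by the deallocator signature
  free(data);
}

int main() {
  // Graph definition from unzipped https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
  // which is used in the Go, Java and Android examples
  TF_Buffer* graph_def = read_file("tensorflow_inception_graph.pb");
  if (graph_def == NULL) {
    fprintf(stderr, "ERROR: Unable to read tensorflow_inception_graph.pb\n");
    return 1;
  }
  TF_Graph* graph = TF_NewGraph();

  // Import graph_def into graph
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, graph_def, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);  // the serialized buffer is not needed after the import
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "ERROR: Unable to import graph: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    TF_DeleteGraph(graph);
    return 1;
  }
  fprintf(stdout, "Successfully imported graph\n");
  TF_DeleteStatus(status);

  // Use the graph
  TF_DeleteGraph(graph);
  return 0;
}

// Reads a whole file into a TF_Buffer; returns NULL on any I/O error.
TF_Buffer* read_file(const char* file) {
  FILE* f = fopen(file, "rb");
  if (f == NULL) {
    return NULL;
  }
  fseek(f, 0, SEEK_END);
  long fsize = ftell(f);
  fseek(f, 0, SEEK_SET);  // same as rewind(f);
  if (fsize <= 0) {  // ftell failed, or the file is empty
    fclose(f);
    return NULL;
  }

  void* data = malloc(fsize);
  if (data == NULL || fread(data, fsize, 1, f) != 1) {  // out of memory or short read
    free(data);
    fclose(f);
    return NULL;
  }
  fclose(f);

  TF_Buffer* buf = TF_NewBuffer();
  buf->data = data;
  buf->length = fsize;
  buf->data_deallocator = free_buffer;
  return buf;
}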

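tensorflow model files

tensorflow produces three main model files:

.meta, .index and .data

They are split into three files because tensorflow stores the graph structure and the variable values separately: the .meta file describes the computation graph, while the .index and .data files hold the variable values. The freeze_graph.py script takes a GraphDef (.pb or .pbtxt) file plus a checkpoint (.meta, .index, .data) and produces a "frozen" model. Freezing converts the variables used during training into the constants needed for deployment, and saves the result in Google's protocol buffers format, which is more compact and faster to process than XML. Freezing works roughly as follows:

1. Build and train the model in a tf.Graph, called g_1.
2. Use Session.run() to fetch the checkpoint's variable values as numpy arrays.
3. In a new tf.Graph, g_2, create constants with tf.constant(), using the values obtained in step 2.
4. Use tf.import_graph_def() to copy the nodes of g_1 into g_2, with the input_map argument mapping g_1's variables to the tf.constant() tensors.
5. Call g_2.as_graph_def() to get the protocol buffer representation of the graph.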
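The freezing steps above can be illustrated without tensorflow on a toy graph. This sketch is a deliberate simplification: the graph is a plain dict (node name to op and input names) rather than a GraphDef protocol buffer, the checkpoint is a dict of already-read values (step 2), and instead of tf.import_graph_def() with input_map it rewrites variable nodes in place, which is the same idea tensorflow 1.x implements in `tf.graph_util.convert_variables_to_constants`. Step 5 (serialization) is not modeled, and `freeze` is a hypothetical name.

```python
from copy import deepcopy

# Op names tensorflow 1.x uses for variables; treated here as the nodes to freeze.
VARIABLE_OPS = {"Variable", "VariableV2"}


def freeze(graph, checkpoint, output_names):
    """Return a frozen copy of `graph` containing only what `output_names` need.

    graph:      dict node name -> {"op": str, "inputs": [node names]}
    checkpoint: dict variable name -> value (the values read in step 2)
    """
    # Walk backwards from the outputs; training-only nodes (savers, optimizers) drop out.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(graph[name]["inputs"])

    frozen = {}
    for name, node in graph.items():  # keep the original node order
        if name not in needed:
            continue
        if node["op"] in VARIABLE_OPS:
            # step 3: the variable becomes a constant holding its checkpoint value
            frozen[name] = {"op": "Const", "inputs": [], "value": deepcopy(checkpoint[name])}
        else:
            # step 4: every other node is copied, still wired to the same input names
            frozen[name] = deepcopy(node)
    return frozen
```

Because each constant keeps the variable's name, downstream nodes need no rewiring; the input_map approach in step 4 achieves the same by remapping inputs during the import.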
