Chapter 26 Getting Started with TensorFlow
Download and Installation
Ubuntu
sudo apt-get install -y libpng12-dev libfreetype6 python-numpy python-scipy ipython python-matplotlib build-essential cmake pkg-config libtiff5-dev libjpeg-dev libjasper-dev libgtk2.0-dev libavcodec-dev libavformat-dev libswscale-dev libv4l-dev swig zip python-sklearn python-wheel
1. Install pip and virtualenv:
$ sudo apt-get install python-pip python-dev python-virtualenv
2. Create a virtualenv:
$ virtualenv --system-site-packages ~/tensorflow
3. Activate the virtualenv environment:
$ source ~/tensorflow/bin/activate
The prompt changes to:
(tensorflow)$
4. Install TensorFlow:
(tensorflow)$ pip install --upgrade tensorflow # for Python 2.7
(tensorflow)$ pip3 install --upgrade tensorflow # for Python 3.n
Upgrading to 1.2
Python 2.7:
(envtensorflow)gsc@X250:~/envtensorflow/lib$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp27-none-linux_x86_64.whl
Python 3 (system-wide; the cp34 wheel targets Python 3.4):
gsc@X250:~/envtensorflow/lib$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl
[sudo] password for gsc:
Downloading/unpacking https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl
Downloading tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl (34.5MB): 34.5MB downloaded
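KWS (keyword spotting, i.e. wake-word detection)
Training the model:
The code can be downloaded from GitHub: kws code base
Google's official tutorial is a good starting point: English original
python tensorflow/examples/speech_commands/train.py --wanted_words=house
After the dataset has been downloaded, training prints a log like the following: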
I0730 16:53:44.766740 55030 train.py:176] Training from step: 1
I0730 16:53:47.289078 55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571
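The step lines have a fixed layout, so they can be scraped to follow accuracy and cross entropy during a long run. The following is a minimal C sketch that assumes exactly the line format shown above (other TensorFlow versions may log differently); StepMetrics and parse_step_line are hypothetical names, not part of TensorFlow.

```c
#include <stdio.h>
#include <string.h>

/* Hypothetical container for one "Step #" line of train.py output. */
typedef struct {
  int step;
  double rate;          /* learning rate */
  double accuracy;      /* percent, e.g. 7.0 for "7.0%" */
  double cross_entropy;
} StepMetrics;

/* Parses a train.py step line; the format is an assumption taken from the log above.
 * Returns 1 on success, 0 if the line is not a step line. */
int parse_step_line(const char* line, StepMetrics* out) {
  const char* p = strstr(line, "Step #");
  if (p == NULL) return 0;
  /* "%%" matches the literal '%' that follows the accuracy value. */
  int n = sscanf(p, "Step #%d: rate %lf, accuracy %lf%%, cross entropy %lf",
                 &out->step, &out->rate, &out->accuracy, &out->cross_entropy);
  return n == 4;
}
```

Lines such as "Training from step: 1" do not contain "Step #", so they are skipped and only the per-step metrics are collected.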
Freezing the graph into a .pb file:
python tensorflow/examples/speech_commands/freeze.py --start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 --output_file=/tmp/my_frozen_graph.pb --wanted_words=house
After the .pb file is generated, copy it into the assets directory of the android directory, which is an Android Studio project. The files used by my project are conv_actions_house_labels.txt and my_house_frozen_graph.pb.
For details on training RNNs with TensorFlow, see the TensorFlow RNN example.
Installing the Android APK
The code base above already contains a prebuilt APK, which can be installed with the adb install command.
Android TensorFlow API
private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_house_labels.txt";
private static final String MODEL_FILENAME = "file:///android_asset/my_house_frozen_graph.pb";
private TensorFlowInferenceInterface inferenceInterface;
// Load the TensorFlow model.
inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);
// Feed the inputs, run the model and fetch the output scores.
inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
inferenceInterface.run(outputScoresNames);
inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
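fetch() fills outputScores with one score per line of the labels file, in the same order, so the recognized word is the label with the highest score. A minimal C sketch of that lookup, assuming the labels file for --wanted_words=house contains _silence_, _unknown_ and house (train.py puts the two special labels before the wanted words); kLabels, argmax_score and top_label are hypothetical names.

```c
#include <stddef.h>

/* Assumed contents of conv_actions_house_labels.txt, one label per line,
 * in the order train.py wrote them. */
static const char* const kLabels[] = {"_silence_", "_unknown_", "house"};
static const size_t kNumLabels = sizeof(kLabels) / sizeof(kLabels[0]);

/* Index of the highest score; scores[i] belongs to kLabels[i]. */
size_t argmax_score(const float* scores, size_t n) {
  size_t best = 0;
  for (size_t i = 1; i < n; ++i) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

/* Label with the highest score in an outputScores-style array of kNumLabels floats. */
const char* top_label(const float* scores) {
  return kLabels[argmax_score(scores, kNumLabels)];
}
```

A real app would usually also apply a score threshold before reporting "house"; this sketch only shows how the output indices map back to label lines.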
Model creation and loading (the TensorFlowInferenceInterface constructor):
public TensorFlowInferenceInterface(AssetManager var1, String var2) {
  this.prepareNativeRuntime();
  this.modelName = var2;
  this.g = new Graph();
  this.sess = new Session(this.g);
  this.runner = this.sess.runner();
  boolean var3 = var2.startsWith("file:///android_asset/");
  Object var4 = null;
  try {
    String var5 = var3 ? var2.split("file:///android_asset/")[1] : var2;
    var4 = var1.open(var5);
  } catch (IOException var9) {
    if (var3) {
      throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var9);
    }
    try {
      var4 = new FileInputStream(var2);
    } catch (IOException var8) {
      throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var9);
    }
  }
  try {
    this.loadGraph((InputStream) var4, this.g);
    ((InputStream) var4).close();
    Log.i("TensorFlowInferenceInterface", "Successfully loaded model from \'" + var2 + "\'");
  } catch (IOException var7) {
    throw new RuntimeException("Failed to load model from \'" + var2 + "\'", var7);
  }
}
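The constructor's path handling works like this: a name beginning with file:///android_asset/ has that prefix stripped and is opened only from the APK's assets (a failure is fatal), while any other name is first tried as an asset and then as a regular file. The same decision, sketched in C with hypothetical names (LoadStrategy and resolve_model_name are not part of any TensorFlow API):

```c
#include <string.h>

#define ANDROID_ASSET_PREFIX "file:///android_asset/"

/* Hypothetical description of where the model is looked up. */
typedef enum {
  LOAD_ASSET_ONLY,      /* prefixed name: asset lookup only, failure is fatal */
  LOAD_ASSET_THEN_FILE  /* bare name: try the asset, fall back to the filesystem */
} LoadStrategy;

/* Mirrors the constructor: returns the name that would be passed to
 * AssetManager.open() and stores the fallback strategy in *strategy. */
const char* resolve_model_name(const char* model, LoadStrategy* strategy) {
  size_t prefix_len = strlen(ANDROID_ASSET_PREFIX);
  if (strncmp(model, ANDROID_ASSET_PREFIX, prefix_len) == 0) {
    *strategy = LOAD_ASSET_ONLY;
    return model + prefix_len;  /* same result as split(prefix)[1] for normal names */
  }
  *strategy = LOAD_ASSET_THEN_FILE;
  return model;
}
```

This is why MODEL_FILENAME above uses the file:///android_asset/ prefix: the .pb file copied into assets is then found directly.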
C API
#include <stdio.h>
#include <stdlib.h>
#include <tensorflow/c/c_api.h>

TF_Buffer* read_file(const char* file);

// Deallocator called by TF_DeleteBuffer to release the bytes read by read_file().
void free_buffer(void* data, size_t length) {
  (void)length;
  free(data);
}

int main() {
  // Graph definition from unzipped https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
  // which is used in the Go, Java and Android examples
  TF_Buffer* graph_def = read_file("tensorflow_inception_graph.pb");
  if (graph_def == NULL) {
    fprintf(stderr, "ERROR: Unable to read tensorflow_inception_graph.pb\n");
    return 1;
  }
  TF_Graph* graph = TF_NewGraph();

  // Import graph_def into graph
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, graph_def, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);  // the serialized GraphDef is no longer needed
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "ERROR: Unable to import graph %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    TF_DeleteGraph(graph);
    return 1;
  }
  fprintf(stdout, "Successfully imported graph\n");
  TF_DeleteStatus(status);

  // Use the graph
  TF_DeleteGraph(graph);
  return 0;
}

// Reads a whole file into a TF_Buffer; returns NULL on any failure.
TF_Buffer* read_file(const char* file) {
  FILE* f = fopen(file, "rb");
  if (f == NULL) return NULL;
  fseek(f, 0, SEEK_END);
  long fsize = ftell(f);
  fseek(f, 0, SEEK_SET);  // same as rewind(f);
  if (fsize <= 0) {
    fclose(f);
    return NULL;
  }
  void* data = malloc(fsize);
  if (data == NULL || fread(data, fsize, 1, f) != 1) {
    free(data);
    fclose(f);
    return NULL;
  }
  fclose(f);

  TF_Buffer* buf = TF_NewBuffer();
  buf->data = data;
  buf->length = fsize;
  buf->data_deallocator = free_buffer;
  return buf;
}
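TensorFlow model files
A TensorFlow checkpoint mainly consists of three files:
.meta, .index and .data
They are split into three files because TensorFlow stores the computation graph structure and the variable values separately. The .meta file describes the computation graph structure; the .data file holds the variable values, and the .index file records where each variable is stored in the .data file.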
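All three files share one checkpoint prefix, such as /tmp/speech_commands_train/conv.ckpt-18000 (the --start_checkpoint passed to freeze.py above). A C sketch that builds their names from the prefix, assuming a single-shard checkpoint whose data file is named <prefix>.data-00000-of-00001, the usual TF 1.x naming (multi-shard checkpoints have several data files); checkpoint_files is a hypothetical helper.

```c
#include <stdio.h>
#include <string.h>

#define DATA_SUFFIX ".data-00000-of-00001"  /* assumed single-shard data file */

/* Hypothetical helper: writes the .meta, .index and .data file names for a
 * checkpoint prefix into three buffers of `size` bytes each.
 * Returns 0 on success, -1 if the buffers are too small. */
int checkpoint_files(const char* prefix, char* meta, char* index, char* data,
                     size_t size) {
  /* DATA_SUFFIX is the longest suffix; sizeof counts its terminating NUL. */
  if (strlen(prefix) + sizeof DATA_SUFFIX > size) return -1;
  snprintf(meta, size, "%s.meta", prefix);
  snprintf(index, size, "%s.index", prefix);
  snprintf(data, size, "%s%s", prefix, DATA_SUFFIX);
  return 0;
}
```

Tools such as freeze.py take only the prefix and locate the individual files themselves, which is why the --start_checkpoint value has no extension.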
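The freeze_graph.py script takes a GraphDef file (.pb or .pbtxt) plus a checkpoint (.meta, .index, .data files) and produces a "frozen" model. Freezing converts the variables learned during training into the constants needed for deployment and saves the result in Google's protocol buffers format. Protocol buffers have an advantage over XML in both memory footprint and processing speed.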
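The freezing process is roughly as follows:
1. Build and train the model in a tf.Graph, called g_1.
2. Use Session.run() to fetch the checkpoint's variable values as numpy arrays.
3. In a new tf.Graph, called g_2, use tf.constant() to create constants whose values come from step 2.
4. Use tf.import_graph_def() to copy the nodes of g_1 into g_2, with the input_map argument replacing the variables of g_1 by the tf.constant() tensors.
5. Use g_2.as_graph_def() to obtain the protocol buffer representation of the graph.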
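To make the idea concrete without depending on TensorFlow, here is a toy C model of freezing. Its types are hypothetical and unrelated to the real GraphDef format: each node is either a variable, whose value lives in a separate checkpoint table, a constant that carries its own value, or some other op. Freezing looks up each variable's saved value (as in step 2) and rewrites the node as a constant holding that value (as in steps 3 and 4), so the resulting graph no longer needs the checkpoint.

```c
#include <stddef.h>
#include <string.h>

/* Toy node kinds; NODE_OP stands for inputs, matmuls and other ops. */
typedef enum { NODE_VARIABLE, NODE_CONSTANT, NODE_OP } NodeKind;

typedef struct {
  const char* name;
  NodeKind kind;
  float value;  /* meaningful only for NODE_CONSTANT */
} Node;

/* One saved variable value, standing in for the .index/.data files. */
typedef struct {
  const char* name;
  float value;
} CheckpointEntry;

static const CheckpointEntry* lookup(const CheckpointEntry* ckpt, size_t n,
                                     const char* name) {
  for (size_t j = 0; j < n; ++j)
    if (strcmp(ckpt[j].name, name) == 0) return &ckpt[j];
  return NULL;
}

/* Replaces every variable node with a constant holding its checkpoint value.
 * Returns -1 (leaving the graph unchanged) if any variable has no saved value. */
int freeze_toy_graph(Node* nodes, size_t n_nodes, const CheckpointEntry* ckpt,
                     size_t n_ckpt) {
  for (size_t i = 0; i < n_nodes; ++i)
    if (nodes[i].kind == NODE_VARIABLE &&
        lookup(ckpt, n_ckpt, nodes[i].name) == NULL)
      return -1;
  for (size_t i = 0; i < n_nodes; ++i) {
    if (nodes[i].kind != NODE_VARIABLE) continue;
    nodes[i].kind = NODE_CONSTANT;
    nodes[i].value = lookup(ckpt, n_ckpt, nodes[i].name)->value;
  }
  return 0;
}
```

Checking every variable before rewriting anything keeps the graph consistent when the checkpoint is incomplete, which mirrors freezing failing outright rather than producing a half-frozen model.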