
tc-resnet's Introduction

Temporal Convolution for Real-time Keyword Spotting on Mobile Devices

Abstract

Keyword spotting (KWS) plays a critical role in enabling speech-based user interactions on smart devices. Recent developments in the field of deep learning have led to wide adoption of convolutional neural networks (CNNs) in KWS systems due to their exceptional accuracy and robustness. The main challenge faced by KWS systems is the trade-off between high accuracy and low latency. Unfortunately, there has been little quantitative analysis of the actual latency of KWS models on mobile devices. This is especially concerning since conventional convolution-based KWS approaches are known to require a large number of operations to attain an adequate level of performance.

In this paper, we propose a temporal convolution for real-time KWS on mobile devices. Unlike most 2D convolution-based KWS approaches, which require a deep architecture to fully capture both low- and high-frequency domains, we exploit temporal convolutions with a compact ResNet architecture. On the Google Speech Commands Dataset, we achieve a speedup of more than 385x on a Google Pixel 1 while surpassing the accuracy of the state-of-the-art model. In addition, we release the implementation of the proposed and baseline models, including an end-to-end pipeline for training the models and evaluating them on mobile devices.
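The cost argument behind temporal convolution can be illustrated with a rough multiply-accumulate (MAC) count. The sketch below treats an MFCC input of 49 frames by 40 coefficients either as a one-channel 49x40 image (2D convolution) or as a 49-step sequence whose 40 coefficients are channels (temporal convolution). The layer sizes (16 output channels, 3x3, 3x1 and 9x1 kernels, stride 1, "same" padding) are illustrative assumptions rather than the exact TC-ResNet configuration, and MAC counts are only a proxy for on-device latency.

```python
def conv_macs(out_h, out_w, k_h, k_w, c_in, c_out):
    """MACs of a stride-1 convolution: one k_h*k_w*c_in dot product per output value."""
    return out_h * out_w * k_h * k_w * c_in * c_out


FRAMES, COEFFS, C = 49, 40, 16  # hypothetical input and channel sizes

# First layer. 2D conv: the MFCC map is a 49x40 image with 1 channel.
first_2d = conv_macs(FRAMES, COEFFS, 3, 3, 1, C)
# Temporal conv: the same map is a 49x1 "image" with 40 channels.
first_tc = conv_macs(FRAMES, 1, 3, 1, COEFFS, C)

# A later C->C layer. The 2D feature map keeps its 40-wide frequency axis,
# while the temporal feature map has width 1.
deep_2d = conv_macs(FRAMES, COEFFS, 3, 3, C, C)
deep_tc = conv_macs(FRAMES, 1, 9, 1, C, C)

if __name__ == "__main__":
    print(first_2d, first_tc, first_2d / first_tc)  # 282240 94080 3.0
    print(deep_2d, deep_tc, deep_2d / deep_tc)      # 4515840 112896 40.0
```

Even with a wider 9-frame kernel, the temporal layer in this example is 40x cheaper, because it never slides over the frequency axis.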

Requirements

  • Python 3.6+
  • TensorFlow 1.13.1

Installation

git clone https://github.com/hyperconnect/TC-ResNet.git
cd TC-ResNet
pip3 install -r requirements/py36-[gpu|cpu].txt   (choose py36-gpu.txt or py36-cpu.txt)

Dataset

To evaluate the proposed and baseline models, we use the Google Speech Commands Dataset.

Google Speech Commands Dataset

Follow the instructions in speech_commands_dataset/.

How to run

Scripts to reproduce the training and evaluation procedures discussed in the paper are located in scripts/commands/. After training a model, you can generate a .tflite file by following the instructions below.

To train the TCResNet8Model-1.0 model, run:

./scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh

To freeze the trained model checkpoint into a .pb file, run:

python freeze.py --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX --output_name output/softmax --output_type softmax --preprocess_method no_preprocessing --height 49 --width 40 --channels 1 --num_classes 12 TCResNet8Model --width_multiplier 1.0

To convert the .pb file into a .tflite file, run:

tflite_convert --graph_def_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.tflite --inference_type=FLOAT --inference_input_type=FLOAT --input_arrays=input --output_arrays=output/softmax --allow_custom_ops

As shown in the commands above, you need to set the height, width, model, and model-specific arguments (e.g. width_multiplier) correctly. For more information, refer to scripts/commands/.
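Before moving to a device, the converted model can be smoke-tested on a desktop with TensorFlow's tf.lite.Interpreter. This is a minimal sketch, assuming the .tflite file has a single float input (fed here with zeros of whatever shape the interpreter reports, e.g. [1, 49, 40, 1]) and a single softmax output; the model path is a placeholder and top_k is a hypothetical helper.

```python
def top_k(probs, k=3):
    """Indices of the k largest probabilities, highest first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


def smoke_test_tflite(model_path):
    # Imported lazily so top_k can be used without TensorFlow installed.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    # Zero features only check that the graph runs; feed real MFCCs in practice.
    dummy = np.zeros(inp["shape"], dtype=inp["dtype"])
    interpreter.set_tensor(inp["index"], dummy)
    interpreter.invoke()
    probs = interpreter.get_tensor(out["index"])[0]
    return top_k(list(probs))


if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:  # usage: python this_script.py path/to/TCResNet8Model-XXX.tflite
        print(smoke_test_tflite(sys.argv[1]))
```

If this runs on the desktop but the benchmark fails on the device, the problem is more likely in the device setup than in the conversion.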

Benchmark tool

Android Debug Bridge (adb) is required to run the Android benchmark tool (model/tflite_tools/run_benchmark.sh). adb is part of the Android SDK Platform Tools; download the Platform Tools and follow their installation instructions.

1. Connect an Android device to your computer

2. Check that the connection is established

Run the following command:

adb devices

You should see output similar to the one below (your device ID will differ):

List of devices attached
FA77M0304573	device
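If you script the benchmark, this check can be automated. The Python sketch below (the helper names are hypothetical; it assumes adb is on your PATH) keeps only devices whose state is exactly "device", so that entries such as "unauthorized" or "offline" are not mistaken for a usable connection.

```python
import subprocess


def ready_devices(adb_output):
    """Serials listed by `adb devices` whose state is exactly "device"."""
    serials = []
    for line in adb_output.splitlines():
        parts = line.split()
        # Header and daemon status lines never have "device" as their second word.
        if len(parts) >= 2 and parts[1] == "device":
            serials.append(parts[0])
    return serials


def adb_ready_devices():
    # Python 3.6-compatible subprocess call (no capture_output/text arguments).
    result = subprocess.run(["adb", "devices"], stdout=subprocess.PIPE,
                            universal_newlines=True, check=True)
    return ready_devices(result.stdout)
```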

3. Run benchmark

Go to model/tflite_tools, place the TF Lite model you want to benchmark there (e.g. mobilenet_v1_1.0_224.tflite), and run the following command. You can pass the optional cpu_mask parameter to set the CPU affinity.

./run_benchmark.sh TCResNet_14Model-1.5.tflite [cpu_mask]
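cpu_mask is presumably a CPU affinity bitmask in the usual taskset convention, where bit i selects CPU i (so 3 = 0b11 would mean CPUs 0 and 1). This is an assumption about how run_benchmark.sh uses the value, so check the script before relying on it; whether it expects hex or decimal is also not stated (for a small mask such as 3 the two coincide). The hypothetical helpers below convert between masks and core lists.

```python
def cores_from_mask(mask):
    """CPU indices selected by an affinity bitmask (bit i -> CPU i)."""
    return [i for i in range(mask.bit_length()) if (mask >> i) & 1]


def mask_from_cores(cores):
    """Hex string of the bitmask selecting the given CPU indices."""
    mask = 0
    for core in cores:
        mask |= 1 << core
    return format(mask, "x")


if __name__ == "__main__":
    print(cores_from_mask(3))             # [0, 1]
    print(mask_from_cores([4, 5, 6, 7]))  # f0
```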

If everything goes well, you should see output similar to the one below. The key measurement of this benchmark is the avg=5701.96 field: the average inference latency in microseconds.

./run_benchmark.sh TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3
benchmark_model_r1.13_official: 1 file pushed. 22.1 MB/s (1265528 bytes in 0.055s)
TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite: 1 file pushed. 25.0 MB/s (1217136 bytes in 0.046s)
>>> run_benchmark_summary TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3
TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite > count=50 first=5734 curr=5801 min=4847 max=6516 avg=5701.96 std=210
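When comparing many models, it can help to extract these numbers programmatically. A minimal sketch, assuming the summary line has the key=value layout shown above (the function name is hypothetical):

```python
import re

# key=value pairs such as "count=50" or "avg=5701.96"
_STAT = re.compile(r"(\w+)=([0-9.]+)")


def parse_summary(line):
    """Parse a run_benchmark summary line into {stat_name: float}."""
    return {key: float(value) for key, value in _STAT.findall(line)}


if __name__ == "__main__":
    line = ("TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite > count=50 "
            "first=5734 curr=5801 min=4847 max=6516 avg=5701.96 std=210")
    stats = parse_summary(line)
    print(stats["avg"] / 1000.0)  # average latency converted to milliseconds
```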

License

Apache License 2.0

tc-resnet's People

Contributors

beomjunshin-ben, justin-hpcnt, manipopopo


tc-resnet's Issues

Regarding the Google Speech Commands dataset

Hi, thanks for the great code. I've tried to reproduce the results, but I found two confusing points:

  1. The original dataset does not seem to include a _silence_ folder, so I could not find _silence_/721f767c_nohash_2.wav, which is listed in test.txt.

  2. The dataset has 64k samples, but the code uses only about 29k. Why is that? Were the results in your paper produced from the data in the files listed below?

$ cat test.txt  | wc -l
3081
$ cat train.txt | wc -l
22246
$ cat valid.txt | wc -l
3093

Looking forward to your reply,

Sincerely,
Bo
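To see how the entries in these split files are distributed across classes (the three counts above sum to 28,420), a short script can count them per label folder. This sketch assumes each line is a relative path whose first component is the label, as in _silence_/721f767c_nohash_2.wav; it does not explain the 64k vs 29k difference, it only makes the split easier to inspect.

```python
import collections


def count_by_label(lines):
    """Count split-file entries per label, taken as the first path component."""
    counts = collections.Counter()
    for line in lines:
        path = line.strip().replace("\\", "/")
        if path:
            counts[path.split("/", 1)[0]] += 1
    return counts


if __name__ == "__main__":
    import sys
    for name in sys.argv[1:]:  # e.g. train.txt valid.txt test.txt
        with open(name) as f:
            counts = count_by_label(f)
        print(name, sum(counts.values()), dict(counts.most_common()))
```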

Desktop (GPU) real-time inference code needed

I have converted the checkpoint to a frozen graph, and now I want to run inference with the frozen graph in real time on a GPU instead of on Android. Could you share an inference script for this?

Thanks.
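A minimal sketch of desktop inference from the frozen .pb, not an official script from the repository. It assumes the tensor names "input:0" and "output/softmax:0" (taken from the freeze.py and tflite_convert commands in the README) and a feature array shaped like the graph input, e.g. [1, 49, 40, 1]. For real-time use, the key point is to create the session once and reuse it; mean_latency_ms is a hypothetical timing helper.

```python
import time


class FrozenGraphRunner:
    """Keeps one TF session open so repeated calls avoid per-call setup cost.

    Tensor names "input:0" and "output/softmax:0" are assumptions taken from
    the freeze.py / tflite_convert commands in the README.
    """

    def __init__(self, pb_path, input_name="input:0",
                 output_name="output/softmax:0"):
        import tensorflow as tf  # lazy import so the timing helper works without TF
        graph_def = tf.compat.v1.GraphDef()
        with open(pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def, name="")
        self._sess = tf.compat.v1.Session(graph=graph)
        self._input = input_name
        self._output = output_name

    def __call__(self, features):
        """features: array shaped like the graph input, e.g. [1, 49, 40, 1]."""
        return self._sess.run(self._output, feed_dict={self._input: features})

    def close(self):
        self._sess.close()


def mean_latency_ms(fn, runs=50):
    """Average wall-clock time of fn() in milliseconds over `runs` calls."""
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) * 1000.0 / runs
```

Usage, with a placeholder path: runner = FrozenGraphRunner("path/to/TCResNet8Model-XXX.pb"), then mean_latency_ms(lambda: runner(features)). With a GPU build of TensorFlow, supported ops are placed on the GPU by default.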

ValueError: not enough values to unpack (expected 2, got 0)

I tried to run the code as described, changing only some of the argument values, and I get the following error:

Traceback (most recent call last):
  File "/home/gauri/TC-ResNet/evaluate_audio.py", line 87, in <module>
    main(args)
  File "/home/gauri/TC-ResNet/evaluate_audio.py", line 28, in main
    is_training,
  File "/home/gauri/TC-ResNet/datasets/audio_data_wrapper.py", line 17, in __init__
    self.setup_dataset(self.placeholders)
  File "/home/gauri/TC-ResNet/datasets/data_wrapper_base.py", line 70, in setup_dataset
    dataset = dataset.map(self._parse_function, num_parallel_calls=self.args.num_threads)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1584, in map
    self, map_func, num_parallel_calls, preserve_cardinality=False))
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2771, in __init__
    input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2737, in __init__
    map_func, self._transformation_name(), dataset=input_dataset)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2124, in __init__
    self._function.add_to_graph(ops.get_default_graph())
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/framework/function.py", line 490, in add_to_graph
    self._create_definition_if_needed()
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/framework/function.py", line 341, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/framework/function.py", line 355, in _create_definition_if_needed_impl
    whitelisted_stateful_ops=self._whitelisted_stateful_ops)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/framework/function.py", line 883, in func_graph_from_py_func
    outputs = func(*func_graph.inputs)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2099, in tf_data_structured_function_wrapper
    ret = func(*nested_args)
  File "/home/gauri/TC-ResNet/datasets/audio_data_wrapper.py", line 53, in _parse_function
    background_max_volume=self.background_max_volume,
  File "/home/gauri/TC-ResNet/datasets/audio_data_wrapper.py", line 35, in augment_audio
    return aug_fn(filename, desired_samples, file_format, sample_rate, **kwargs)
  File "/home/gauri/TC-ResNet/datasets/augmentation_factory.py", line 187, in anchored_slice_or_pad
    audio = _mix_background(audio, desired_samples, is_silent=is_silent, **kwargs)
  File "/home/gauri/TC-ResNet/datasets/augmentation_factory.py", line 66, in _mix_background
    }, exclusive=True)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 4092, in case
    strict=strict)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3968, in _case_helper
    pred_fn_pairs, exclusive, name, allow_python_preds)
  File "/home/gauri/anaconda2/envs/edgeml_gpu/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3933, in _case_verify_and_canonicalize_args
    predicates, actions = zip(*pred_fn_pairs)
ValueError: not enough values to unpack (expected 2, got 0)

Can anybody help me solve the above error?
Thanks in advance.
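The last frame shows what fails: tf.case unpacks its (predicate, function) pairs with zip(*pred_fn_pairs), and that unpacking raises exactly this ValueError when the list of pairs is empty. The plain-Python reproduction below shows the behaviour. Because the call comes from _mix_background, one plausible but unconfirmed cause is that no background-noise clips were found, leaving nothing to build the cases from. require_wavs is a hypothetical pre-flight check for that, and the directory you pass to it is an assumption about your dataset layout.

```python
import glob
import os


def unpack_case_pairs(pred_fn_pairs):
    # Mirrors the failing line in control_flow_ops: zip(*pairs) yields nothing
    # when pairs is empty, so unpacking into two names raises ValueError.
    predicates, actions = zip(*pred_fn_pairs)
    return predicates, actions


def require_wavs(directory):
    """Hypothetical pre-flight check: fail early if no .wav files are found."""
    wavs = sorted(glob.glob(os.path.join(directory, "*.wav")))
    if not wavs:
        raise FileNotFoundError("no .wav files found in %r" % directory)
    return wavs
```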

How to get MFCC features without using TensorFlow?

Thank you for your great work.
I trained a TCResNet8 model and converted it to a TFLite model successfully.
However, I do not know how to use the TFLite model without using TensorFlow to extract the MFCC features of the wav files.
I tried to extract MFCCs with 'python_speech_features', but the extracted features are totally different for the same input wav.
Could you give me some tips on this?
Thanks a lot.
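The model only sees whatever features the training pipeline produced, so a non-TensorFlow front end has to reproduce every step and parameter of that pipeline. As one illustration of where libraries diverge, python_speech_features' mfcc defaults include pre-emphasis (0.97), a rectangular window, 26 filters, 13 coefficients, liftering, and replacing the first coefficient with the log frame energy, none of which a TensorFlow pipeline necessarily uses. The stdlib-only sketch below spells out the usual MFCC stages for one frame (window, power spectrum, triangular mel filterbank, log, DCT-II) so that each parameter is explicit. It is not TensorFlow's MFCC implementation; all defaults except the 16 kHz sample rate (a 512-point FFT, 40 filters, 40 coefficients, 20 to 4000 Hz, log offset 1e-6, periodic Hann window) are assumptions to be tuned against features dumped from the TensorFlow pipeline.

```python
import math


def hz_to_mel(f):
    return 1127.0 * math.log1p(f / 700.0)  # HTK-style mel scale


def mel_to_hz(m):
    return 700.0 * math.expm1(m / 1127.0)


def power_spectrum(frame, n_fft):
    """|DFT|^2 of a periodic-Hann-windowed frame zero-padded to n_fft (naive DFT).

    Assumes len(frame) <= n_fft.
    """
    n = len(frame)
    x = [s * (0.5 - 0.5 * math.cos(2 * math.pi * i / n)) for i, s in enumerate(frame)]
    x += [0.0] * (n_fft - n)
    spec = []
    for k in range(n_fft // 2 + 1):
        real = sum(v * math.cos(2 * math.pi * k * t / n_fft) for t, v in enumerate(x))
        imag = sum(v * math.sin(2 * math.pi * k * t / n_fft) for t, v in enumerate(x))
        spec.append(real * real + imag * imag)
    return spec


def mel_filterbank(num_filters, n_fft, sample_rate, f_min, f_max):
    """Triangular filters spaced evenly on the mel scale, one weight row per filter."""
    lo, hi = hz_to_mel(f_min), hz_to_mel(f_max)
    edges = [mel_to_hz(lo + i * (hi - lo) / (num_filters + 1)) for i in range(num_filters + 2)]
    bin_hz = [k * sample_rate / n_fft for k in range(n_fft // 2 + 1)]
    rows = []
    for j in range(1, num_filters + 1):
        left, center, right = edges[j - 1], edges[j], edges[j + 1]
        row = []
        for f in bin_hz:
            if left < f <= center:
                row.append((f - left) / (center - left))
            elif center < f < right:
                row.append((right - f) / (right - center))
            else:
                row.append(0.0)
        rows.append(row)
    return rows


def dct_ii(values, num_coeffs):
    """Orthonormal DCT-II, keeping the first num_coeffs coefficients."""
    n = len(values)
    out = []
    for k in range(num_coeffs):
        scale = math.sqrt((1.0 if k == 0 else 2.0) / n)
        out.append(scale * sum(v * math.cos(math.pi * k * (2 * i + 1) / (2 * n))
                               for i, v in enumerate(values)))
    return out


def mfcc_frame(frame, sample_rate=16000, n_fft=512, num_filters=40, num_coeffs=40,
               f_min=20.0, f_max=4000.0, log_offset=1e-6):
    """MFCCs of one frame; every default except sample_rate is an assumption."""
    spec = power_spectrum(frame, n_fft)
    bank = mel_filterbank(num_filters, n_fft, sample_rate, f_min, f_max)
    energies = [sum(w * p for w, p in zip(row, spec)) for row in bank]
    return dct_ii([math.log(e + log_offset) for e in energies], num_coeffs)
```

A 30 ms frame at 16 kHz is 480 samples; comparing this output frame by frame with the TensorFlow features is the practical way to find which parameter differs.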

No such file or directory when training

Hi, I ran into the error tensorflow/core/framework/op_kernel.cc:1401] OP_REQUIRES failed at whole_file_read_ops.cc:114 : Not found: NewRandomAccessFile failed to Create/Open: D:\user\dataset\KWS\TC-ResNet\google_speech_commands\splitted_data\train\right\61a0d340_nohash_0.wav, but the file does exist at that absolute path.
I used the download_and_split.sh script to split the data, and I also use absolute paths.
Has anyone run into this problem before?
Thanks

[INFO|tf_utils.py:56] 2023-08-08 16:19:53,935 > >>    TCResNet8/block2/conv2_1/BatchNorm/beta:0 <dtype: 'float32_ref'> : [48], 48 ... 63280 (is_trainable: True)
[INFO|tf_utils.py:56] 2023-08-08 16:19:53,936 > >>    TCResNet8/block2/conv2_1/BatchNorm/moving_mean:0 <dtype: 'float32_ref'> : [48], 48 ... 63328 (is_trainable: False)
[INFO|tf_utils.py:56] 2023-08-08 16:19:53,936 > >>    TCResNet8/block2/conv2_1/BatchNorm/moving_variance:0 <dtype: 'float32_ref'> : [48], 48 ... 63376 (is_trainable: False)
[INFO|tf_utils.py:56] 2023-08-08 16:19:53,936 > >>    TCResNet8/fc/weights:0 <dtype: 'float32_ref'> : [1, 1, 48, 22], 1056 ... 64432 (is_trainable: True)
[INFO|tf_utils.py:56] 2023-08-08 16:19:53,936 > >>    TCResNet8/fc2/weights:0 <dtype: 'float32_ref'> : [1, 1, 48, 2], 96 ... 64528 (is_trainable: True)
[INFO|tf_utils.py:61] 2023-08-08 16:19:53,936 > >> End of showing all variables // Number of variables: 52, Number of trainable variables : 32, Total prod + sum of shape: 64528 (63872 trainable)
[INFO|tf_utils.py:230] 2023-08-08 16:19:53,937 > self.args.checkpoint_path updated:  -> None
[INFO|trainer.py:190] 2023-08-08 16:19:53,964 > Use MomentumOptimizer
[INFO|trainer.py:156] 2023-08-08 16:19:54,604 > Initialize global / local variables
[INFO|summaries.py:123] 2023-08-08 16:19:54,604 > Write summaries into : D://user/dataset/KWS/TC-ResNet/work/v1/TCResNet8Model-1.0/no_processing.001_mom_l1
[INFO|base.py:127] 2023-08-08 16:19:55,080 > MAPMetricOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > AccuracyMetricOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > Top5AccuracyMetricOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > ClassificationReportMetricOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > LossesMetricOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > WavSummaryOp is added.
[INFO|base.py:127] 2023-08-08 16:19:55,081 > LearningRateSummaryOp is added.
[INFO|trainer.py:167] 2023-08-08 16:19:55,213 > Watch Validation Through TensorBoard !
[INFO|trainer.py:169] 2023-08-08 16:19:55,225 > --checkpoint_path D:\user\dataset\KWS\TC-ResNet\work\v1\TCResNet8Model-1.0\no_processing.001_mom_l1
[INFO|trainer.py:368] 2023-08-08 16:19:55,225 > Training started
2023-08-08 16:19:55.861665: W tensorflow/core/framework/op_kernel.cc:1401] OP_REQUIRES failed at whole_file_read_ops.cc:114 : Not found: NewRandomAccessFile failed to Create/Open: D:\user\dataset\KWS\TC-ResNet\google_speech_commands\splitted_data\train\no\73dda36a_nohash_1.wav : The system cannot find the file specified.
; No such file or directory
2023-08-08 16:19:55.861718: W tensorflow/core/framework/op_kernel.cc:1401] OP_REQUIRES failed at whole_file_read_ops.cc:114 : Not found: NewRandomAccessFile failed to Create/Open: D:\user\dataset\KWS\TC-ResNet\google_speech_commands\splitted_data\train\off\70a00e98_nohash_4.wav : The system cannot find the file specified.
; No such file or directory
2023-08-08 16:19:55.861764: W tensorflow/core/framework/op_kernel.cc:1401] OP_REQUIRES failed at whole_file_read_ops.cc:114 : Not found: NewRandomAccessFile failed to Create/Open: D:\user\dataset\KWS\TC-ResNet\google_speech_commands\splitted_data\train\right\61a0d340_nohash_0.wav : The system cannot find the file specified.
; No such file or directory
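To separate a path problem from a TensorFlow problem, it can help to check from plain Python whether the exact paths in the error exist. The log mixes separator styles (D://user/... in one line, D:\user\... in others), so the sketch below normalises each path before checking and prints both forms with repr(), which also reveals invisible characters such as a trailing carriage return. The function name is hypothetical; pass it the paths from the log or from your split lists.

```python
import os


def missing_files(paths):
    """Return (raw, normalised) pairs for paths that are not existing files.

    On Windows, os.path.normpath also turns "/" into "\\" and collapses
    doubled separators such as "D://user".
    """
    missing = []
    for raw in paths:
        norm = os.path.normpath(raw)
        if not os.path.isfile(norm):
            missing.append((raw, norm))
    return missing


if __name__ == "__main__":
    import sys
    for raw, norm in missing_files(sys.argv[1:]):
        print("missing:", repr(raw), "->", repr(norm))
```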

Can't find file in evaluation

Hi, I ran the code and got an error.

I couldn't find where this file is created. Does it need to be created manually?

Sincerely,
Chen

Why "--height 49"?

  • When generating the .pb file, we use this command:
    python freeze.py --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-30000 --output_name softmax --preprocess_method no_preprocessing --height 49 --width 40 --channels 1 TCResNet8Model --width_multiplier 1.0
  • So what does "--height 49" mean?
  • Does it mean the MFCCs are extracted with "--window_size_ms 40 --window_stride_ms 20"?
  • But for training, validation, and testing, "--window_size_ms 30 --window_stride_ms 10" is used.
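The arithmetic behind the question can be checked directly. Assuming 1-second clips sampled at 16 kHz (the Google Speech Commands format) and the frame-count convention of TensorFlow's speech_commands example (1 + (samples - window) // stride, with no padding), a 40 ms window with a 20 ms stride gives 49 frames, while 30 ms / 10 ms gives 98. This only shows which settings are consistent with --height 49; it does not establish which settings the repository actually used.

```python
def num_frames(clip_ms, window_ms, stride_ms, sample_rate=16000):
    """Number of full analysis windows in a clip (no padding)."""
    samples = sample_rate * clip_ms // 1000
    window = sample_rate * window_ms // 1000
    stride = sample_rate * stride_ms // 1000
    if samples < window:
        return 0
    return 1 + (samples - window) // stride


if __name__ == "__main__":
    print(num_frames(1000, 40, 20))  # 49
    print(num_frames(1000, 30, 10))  # 98
```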

Could this code detect keywords in a sentence?

I am new to automatic speech recognition. I have a task to detect keywords in real time in a real application, for example a speech command such as "open the door". I wonder whether this code can handle this problem.
My only concern is that this model is trained and tested only on special wav files, where "special" means each file contains a single word rather than a real sentence.
I will read the paper again to understand it better.
Thanks
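A common way to apply a model trained on 1-second single-word clips to continuous speech is to slide a 1-second window over the audio stream, classify each window, and smooth the scores over a few consecutive windows before thresholding. The sketch below shows only that windowing and smoothing logic. classify is a stand-in for whatever inference call you use (for example a TFLite interpreter), and the window, hop, smoothing length, and threshold values are illustrative assumptions, not settings from the paper or repository.

```python
from collections import deque


def sliding_windows(samples, window=16000, hop=8000):
    """Yield (start, chunk) for each full window; 16000 samples = 1 s at 16 kHz."""
    for start in range(0, len(samples) - window + 1, hop):
        yield start, samples[start:start + window]


def detect(samples, classify, keyword_index, threshold=0.8, smooth=3,
           window=16000, hop=8000):
    """Start offsets of windows whose smoothed keyword score reaches threshold.

    classify(chunk) must return a sequence of class probabilities.
    """
    recent = deque(maxlen=smooth)  # scores of the last `smooth` windows
    hits = []
    for start, chunk in sliding_windows(samples, window, hop):
        recent.append(classify(chunk)[keyword_index])
        if sum(recent) / len(recent) >= threshold:
            hits.append(start)
    return hits
```

Smoothing trades a little latency for fewer false alarms on single noisy windows; multi-word commands such as "open the door" would additionally need the model to be trained on those keywords.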

The command for generating the .pb file is wrong

The correct command is:

python freeze.py --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-30000 --output_name softmax  --preprocess_method no_preprocessing --height 49 --width 40 --channels 1 TCResNet8Model --width_multiplier 1.0 
