pinto0309 / tensorflow-bin

Prebuilt binary with Tensorflow Lite enabled. For RaspberryPi / Jetson Nano. Support for custom operations in MediaPipe. XNNPACK, XNNPACK Multi-Threads, FlexDelegate.

Home Page: https://qiita.com/PINTO

License: Apache License 2.0

Shell 100.00%
tensorflow raspberrypi wheel pip tensorflowlite python debian raspbian aarch64 armhf

tensorflow-bin's Introduction

Tensorflow-bin

Older versions of Wheel files can be obtained from the Previous version download script (GoogleDrive).

Prebuilt binary with Tensorflow Lite enabled, for RaspberryPi. Since the 64-bit OS for RaspberryPi has been officially released, I have stopped building Wheels for armhf. If you need a Wheel for armhf, please use TensorflowLite-bin.

  • Support for Flex Delegate.
  • Support for XNNPACK.
  • Support for XNNPACK half-precision inference (see "Half-precision Inference Doubles On-Device Inference Performance").

Python API packages

| Device | OS | Distribution | Architecture | Python ver | Note |
|---|---|---|---|---|---|
| RaspberryPi3/4 | Raspbian/Debian | Stretch | armhf / armv7l | 3.5.3 | 32bit, glibc2.24 |
| RaspberryPi3/4 | Raspbian/Debian | Buster | armhf / armv7l | 3.7.3 / 2.7.16 | 32bit, glibc2.28 |
| RaspberryPi3/4 | RaspberryPiOS/Debian | Buster | aarch64 / armv8 | 3.7.3 | 64bit, glibc2.28 |
| RaspberryPi3/4 | Ubuntu 18.04 | Bionic | aarch64 / armv8 | 3.6.9 | 64bit, glibc2.27 |
| RaspberryPi3/4 | Ubuntu 20.04 | Focal | aarch64 / armv8 | 3.8.2 | 64bit, glibc2.31 |
| RaspberryPi3/4,PiZero | Ubuntu 21.04/Debian/RaspberryPiOS | Hirsute/Bullseye | aarch64 / armv8 | 3.9.x | 64bit, glibc2.33/glibc2.31 |
| RaspberryPi3/4 | Ubuntu 22.04 | Jammy | aarch64 / armv8 | 3.10.x | 64bit, glibc2.35 |
| RaspberryPi4/5,PiZero | Debian/RaspberryPiOS | Bookworm | aarch64 / armv8 | 3.11.x | 64bit, glibc2.36 |
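
To find which row of the table applies to a given device, it can help to print the Python version, CPU architecture and C library version of the running system. The following is a minimal sketch using only the Python standard library; the function name `describe_environment` is made up for this example, and `platform.libc_ver()` may return empty strings on systems where it cannot detect the C library.

```python
import platform
import sys


def describe_environment():
    """Collect the values that the table columns refer to."""
    libc_name, libc_version = platform.libc_ver()
    return {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "architecture": platform.machine(),  # e.g. aarch64 or armv7l
        "bits": platform.architecture()[0],  # e.g. 64bit or 32bit
        "libc": "{} {}".format(libc_name, libc_version).strip(),
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print("{:>12}: {}".format(key, value))
```

Compare the printed values against the Architecture, Python ver and Note columns before choosing a Wheel.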

Minimal configuration stand-alone installer for Tensorflow Lite. https://github.com/PINTO0309/TensorflowLite-bin.git

Binary type

Python 2.x / 3.x + Tensorflow v1.15.0

.whl 4Threads Note
tensorflow-1.15.0-cp35-cp35m-linux_armv7l.whl Raspbian/Debian Stretch, glibc 2.24
tensorflow-1.15.0-cp27-cp27mu-linux_armv7l.whl Raspbian/Debian Buster, glibc 2.28
tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl Raspbian/Debian Buster, glibc 2.28
tensorflow-1.15.0-cp37-cp37m-linux_aarch64.whl Debian Buster, glibc 2.28

Python 3.x + Tensorflow v2

*FD = FlexDelegate, **XP = XNNPACK Float16 boost, ***MP = MediaPipe CustomOP, ****NP = Numpy

.whl FD XP MP NP Note
tensorflow-2.15.0.post1-cp39-none-linux_aarch64.whl 1.26 Ubuntu 21.04 glibc 2.33, Debian Bullseye glibc 2.31
tensorflow-2.15.0.post1-cp310-none-linux_aarch64.whl 1.26 Ubuntu 22.04 glibc 2.35
tensorflow-2.15.0.post1-cp311-none-linux_aarch64.whl 1.26 Debian Bookworm glibc 2.36
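
The Wheel file names above follow the usual wheel naming layout (distribution, version, Python tag, ABI tag, platform tag). As a rough sketch, the helper below splits such a name and checks whether its `cpXY` tag matches the running interpreter. The function names are hypothetical, and the parser assumes a name without the optional build tag, which is the case for the files listed here.

```python
import sys


def parse_wheel_name(filename):
    """Split a wheel file name into its five components (no build tag)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: " + filename)
    parts = filename[:-len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError("unexpected wheel name layout: " + filename)
    keys = ("distribution", "version", "python_tag", "abi_tag", "platform_tag")
    return dict(zip(keys, parts))


def matches_running_python(filename):
    """True if the wheel's cpXY tag equals the running interpreter's version."""
    tag = "cp{}{}".format(sys.version_info.major, sys.version_info.minor)
    return parse_wheel_name(filename)["python_tag"] == tag


if __name__ == "__main__":
    name = "tensorflow-2.15.0.post1-cp311-none-linux_aarch64.whl"
    print(parse_wheel_name(name))
    print("matches this interpreter:", matches_running_python(name))
```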

【Appendix】 C Library + Tensorflow v1.x.x / v2.x.x

The behavior is unconfirmed because I do not have C language implementation skills. See the official tutorial on Tensorflow C binding generation.

Appx1. C-API build procedure: Native build procedure of Tensorflow v2.0.0 C API for RaspberryPi / arm64 devices (armhf / aarch64)

Appx2. C-API Usage

$ wget https://raw.githubusercontent.com/PINTO0309/Tensorflow-bin/main/C-library/2.2.0-armhf/install-buster.sh
$ chmod +x install-buster.sh
$ ./install-buster.sh

| Version | Binary | Note |
|---|---|---|
| v1.15.0 | C-library/1.15.0-armhf/install-buster.sh | Raspbian/Debian Buster, glibc 2.28 |
| v1.15.0 | C-library/1.15.0-aarch64/install-buster.sh | Raspbian/Debian Buster, glibc 2.28 |
| v2.2.0 | C-library/2.2.0-armhf/install-buster.sh | Raspbian/Debian Buster, glibc 2.28 |
| v2.3.0 | C-library/2.3.0-aarch64/install-buster.sh | RaspberryPiOS/Raspbian/Debian Buster, glibc 2.28 |

Usage

Example of Python 3.x + Tensorflow v1 series

$ sudo apt-get install -y \
    libhdf5-dev libc-ares-dev libeigen3-dev gcc gfortran \
    libgfortran5 libatlas3-base libatlas-base-dev \
    libopenblas-dev libopenblas-base libblas-dev \
    liblapack-dev cython3 openmpi-bin libopenmpi-dev \
    libatlas-base-dev python3-dev
$ sudo pip3 install pip --upgrade
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo pip3 install pybind11
$ pip3 install -U --user six wheel mock
$ sudo pip3 uninstall tensorflow
$ wget "https://raw.githubusercontent.com/PINTO0309/Tensorflow-bin/master/previous_versions/download_tensorflow-1.15.0-cp37-cp37m-linux_armv7l.sh"
$ chmod +x download_tensorflow-1.15.0-cp37-cp37m-linux_armv7l.sh
$ ./download_tensorflow-1.15.0-cp37-cp37m-linux_armv7l.sh
$ sudo pip3 install tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl

Example of Python 3.x + Tensorflow v2 series

##### Bullseye, Ubuntu22.04
sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
    libhdf5-dev \
    unzip \
    pkg-config \
    python3-pip \
    cmake \
    make \
    git \
    python-is-python3 \
    wget \
    patchelf && \
pip install -U pip && \
pip install numpy==1.26.2 && \
pip install keras_applications==1.0.8 --no-deps && \
pip install keras_preprocessing==1.1.2 --no-deps && \
pip install h5py==3.6.0 && \
pip install pybind11==2.9.2 && \
pip install packaging && \
pip install protobuf==3.20.3 && \
pip install six wheel mock gdown
##### Bookworm
sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
    libhdf5-dev \
    unzip \
    pkg-config \
    python3-pip \
    cmake \
    make \
    git \
    python-is-python3 \
    wget \
    patchelf && \
pip install -U pip --break-system-packages && \
pip install numpy==1.26.2 --break-system-packages && \
pip install keras_applications==1.0.8 --no-deps --break-system-packages && \
pip install keras_preprocessing==1.1.2 --no-deps --break-system-packages && \
pip install h5py==3.10.0 --break-system-packages && \
pip install pybind11==2.9.2 --break-system-packages && \
pip install packaging --break-system-packages && \
pip install protobuf==3.20.3 --break-system-packages && \
pip install six wheel mock gdown --break-system-packages
pip uninstall tensorflow

TFVER=2.15.0.post1

# Set PYVER to match your Python version (choose exactly one)
PYVER=39
# PYVER=310
# PYVER=311

ARCH=`python -c 'import platform; print(platform.machine())'`
echo CPU ARCH: ${ARCH}

pip install \
--no-cache-dir \
https://github.com/PINTO0309/Tensorflow-bin/releases/download/v${TFVER}/tensorflow-${TFVER}-cp${PYVER}-none-linux_${ARCH}.whl
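
The same URL can be assembled from Python, which avoids setting `PYVER` and `ARCH` by hand. This is only a sketch: `URL_TEMPLATE` and `wheel_url` are names invented for the example, the URL pattern is copied from the pip command above, and a Wheel exists only for the Python versions and architectures listed in the table (the function does not check that the file is actually published).

```python
import platform
import sys

# Release URL pattern copied from the pip command above.
URL_TEMPLATE = (
    "https://github.com/PINTO0309/Tensorflow-bin/releases/download/"
    "v{tfver}/tensorflow-{tfver}-cp{pyver}-none-linux_{arch}.whl"
)


def wheel_url(tfver="2.15.0.post1", pyver=None, arch=None):
    """Build the download URL; defaults are taken from the running system."""
    if pyver is None:
        pyver = "{}{}".format(sys.version_info.major, sys.version_info.minor)
    if arch is None:
        arch = platform.machine()
    return URL_TEMPLATE.format(tfver=tfver, pyver=pyver, arch=arch)


if __name__ == "__main__":
    print(wheel_url())
```

The printed URL can then be passed to `pip install --no-cache-dir`.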

Operation check

Example of Python 3.x series

$ python -c 'import tensorflow as tf;print(tf.__version__)'
2.15.0.post1
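
Importing Tensorflow on a RaspberryPi can take several seconds. If only the installed version number is needed, it can be read from the package metadata instead. A minimal sketch, assuming Python 3.8 or later (for `importlib.metadata`); the helper name `installed_version` is hypothetical.

```python
from importlib import metadata


def installed_version(distribution):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("tensorflow")
    print("tensorflow:", version if version else "not installed")
```

This only confirms that pip registered the package; the `import tensorflow` check above remains the real test that the binary loads on the device.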

Sample of MultiThread x4

  • Preparation of test environment
$ cd ~;mkdir test
$ curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp > ~/test/grace_hopper.bmp
$ curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C ~/test mobilenet_v1_1.0_224/labels.txt
$ mv ~/test/mobilenet_v1_1.0_224/labels.txt ~/test/
$ curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C ~/test
$ cp tensorflow/tensorflow/contrib/lite/examples/python/label_image.py ~/test
[Sample Code] label_image.py
import argparse
import numpy as np
import time

from PIL import Image

# Tensorflow -v1.12.0
#from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper

# Tensorflow v1.13.0+, v2.x.x
from tensorflow.lite.python import interpreter as interpreter_wrapper

def load_labels(filename):
  # Read one label per line; the file is closed when the block exits
  with open(filename, 'r') as input_file:
    return [l.strip() for l in input_file]

if __name__ == "__main__":
  floating_model = False
  parser = argparse.ArgumentParser()
  parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \
    help="image to be classified")
  parser.add_argument("-m", "--model_file", \
    default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \
    help=".tflite model to be executed")
  parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \
    help="name of file containing labels")
  parser.add_argument("--input_mean", default=127.5, type=float, help="input_mean")
  parser.add_argument("--input_std", default=127.5, type=float, \
    help="input standard deviation")
  parser.add_argument("--num_threads", default=1, type=int, help="number of threads")
  args = parser.parse_args()

  ### Tensorflow -v2.2.0
  #interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
  ### Tensorflow v2.3.0+
  interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file, num_threads=int(args.num_threads))

  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # check the type of the input tensor
  if input_details[0]['dtype'] == np.float32:
    floating_model = True
  # NxHxWxC, H:1, W:2
  height = input_details[0]['shape'][1]
  width = input_details[0]['shape'][2]
  img = Image.open(args.image)
  img = img.resize((width, height))
  # add N dim
  input_data = np.expand_dims(img, axis=0)
  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  ### Tensorflow -v2.2.0
  #interpreter.set_num_threads(int(args.num_threads))
  interpreter.set_tensor(input_details[0]['index'], input_data)

  start_time = time.time()
  interpreter.invoke()
  stop_time = time.time()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
    else:
      print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])

  print("time: ", stop_time - start_time)
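
For a quantized model the script divides each raw uint8 output by 255 before printing, so a printed score such as 0.415686 is what a raw value of 106 produces. The sketch below reproduces the top-k selection and scaling in plain Python without Tensorflow. The raw values are hypothetical, `top_k_scores` is a name invented for this example, and ties may be ordered differently than with NumPy's `argsort`.

```python
def top_k_scores(raw_scores, k=5, quantized=True):
    """Return (index, score) pairs for the k largest raw outputs."""
    order = sorted(range(len(raw_scores)),
                   key=lambda i: raw_scores[i], reverse=True)
    # Quantized (uint8) outputs are scaled into [0, 1]; float outputs are kept
    scale = 255.0 if quantized else 1.0
    return [(i, raw_scores[i] / scale) for i in order[:k]]


if __name__ == "__main__":
    # Hypothetical uint8 outputs for six classes
    raw = [9, 106, 0, 90, 15, 9]
    for index, score in top_k_scores(raw):
        print("{0:08.6f}: class {1}".format(score, index))
```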

  • Run test
$ cd ~/test
$ python3 label_image.py \
--num_threads 1 \
--image grace_hopper.bmp \
--model_file mobilenet_v1_1.0_224_quant.tflite \
--label_file labels.txt

0.415686: 653:military uniform
0.352941: 907:Windsor tie
0.058824: 668:mortarboard
0.035294: 458:bow tie, bow-tie, bowtie
0.035294: 835:suit, suit of clothes
time:  0.4152982234954834
$ cd ~/test
$ python3 label_image.py \
--num_threads 4 \
--image grace_hopper.bmp \
--model_file mobilenet_v1_1.0_224_quant.tflite \
--label_file labels.txt

0.415686: 653:military uniform
0.352941: 907:Windsor tie
0.058824: 668:mortarboard
0.035294: 458:bow tie, bow-tie, bowtie
0.035294: 835:suit, suit of clothes
time:  0.1647195816040039
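
The two timings above can be turned into a speedup figure. The sketch below uses the measured values from the two runs; the function names are made up for this example. With these numbers, four threads come out at roughly 2.5 times the single-thread speed, which is well short of a 4x ideal.

```python
def speedup(baseline_seconds, parallel_seconds):
    """How many times faster the parallel run is than the baseline."""
    return baseline_seconds / parallel_seconds


def parallel_efficiency(baseline_seconds, parallel_seconds, threads):
    """Speedup divided by thread count; 1.0 would be perfect scaling."""
    return speedup(baseline_seconds, parallel_seconds) / threads


if __name__ == "__main__":
    t1 = 0.4152982234954834  # --num_threads 1
    t4 = 0.1647195816040039  # --num_threads 4
    print("speedup: {:.2f}x".format(speedup(t1, t4)))
    print("efficiency: {:.0%}".format(parallel_efficiency(t1, t4, 4)))
```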

Sample of MultiThread x4 - Real-time inference with a USB camera


Build Parameter

Tensorflow v1.11.0

Python2.x - Bazel 0.17.2

$ sudo apt-get install -y openmpi-bin libopenmpi-dev libhdf5-dev

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout v1.11.0
$ ./configure

Please specify the location of python. [Default is /usr/bin/python]:


Found possible Python library paths:
  /usr/local/lib/python2.7/dist-packages
  /usr/local/lib
  /home/pi/tensorflow/tensorflow/contrib/lite/tools/make/gen/rpi_armv7l/lib
  /usr/lib/python2.7/dist-packages
  /opt/movidius/caffe/python
Please input the desired Python library path to use.  Default is [/usr/local/lib/python2.7/dist-packages]

Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: n
No jemalloc as malloc support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
No Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n
No Hadoop File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]: n
No Amazon AWS Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n
No Apache Kafka Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with XLA JIT support? [y/N]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with GDR support? [y/N]: n
No GDR support will be enabled for TensorFlow.

Do you wish to build TensorFlow with VERBS support? [y/N]: n
No VERBS support will be enabled for TensorFlow.

Do you wish to build TensorFlow with nGraph support? [y/N]: n
No nGraph support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
$ sudo bazel build --config opt --local_resources 1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo pip2 install /tmp/tensorflow_pkg/tensorflow-1.11.0-cp27-cp27mu-linux_armv7l.whl

Python3.x - Bazel 0.17.2 + ZRAM + PythonAPI(MultiThread) (Feb 23, 2019, compilation work completed)

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile
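
Before starting a long build it is worth confirming that the larger swap is actually active. A minimal sketch that parses the `SwapTotal` line of `/proc/meminfo` (Linux only); the helper name `swap_total_mib` is hypothetical.

```python
def swap_total_mib(meminfo_text):
    """Return total swap in MiB from /proc/meminfo text, or None if absent."""
    for line in meminfo_text.splitlines():
        if line.startswith("SwapTotal:"):
            # The value is reported in kB, e.g. "SwapTotal:  2097148 kB"
            return int(line.split()[1]) // 1024
    return None


if __name__ == "__main__":
    with open("/proc/meminfo") as meminfo:
        print("swap total (MiB):", swap_total_mib(meminfo.read()))
```

A value close to 2048 suggests the swap file setting took effect; after the ZRAM script is installed and the device rebooted, the total may be higher because ZRAM adds its own swap devices.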

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout v1.11.0

Modify the program with reference to the following.

tensorflow/contrib/lite/examples/python/label_image.py
import argparse
import numpy as np
import time

from PIL import Image

from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper

def load_labels(filename):
  # Read one label per line; the file is closed when the block exits
  with open(filename, 'r') as input_file:
    return [l.strip() for l in input_file]

if __name__ == "__main__":
  floating_model = False
  parser = argparse.ArgumentParser()
  parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \
    help="image to be classified")
  parser.add_argument("-m", "--model_file", \
    default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \
    help=".tflite model to be executed")
  parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \
    help="name of file containing labels")
  parser.add_argument("--input_mean", default=127.5, type=float, help="input_mean")
  parser.add_argument("--input_std", default=127.5, type=float, \
    help="input standard deviation")
  parser.add_argument("--num_threads", default=1, type=int, help="number of threads")
  args = parser.parse_args()

  interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # check the type of the input tensor
  if input_details[0]['dtype'] == np.float32:
    floating_model = True
  # NxHxWxC, H:1, W:2
  height = input_details[0]['shape'][1]
  width = input_details[0]['shape'][2]
  img = Image.open(args.image)
  img = img.resize((width, height))
  # add N dim
  input_data = np.expand_dims(img, axis=0)
  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  interpreter.set_num_threads(int(args.num_threads))
  interpreter.set_tensor(input_details[0]['index'], input_data)

  start_time = time.time()
  interpreter.invoke()
  stop_time = time.time()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
    else:
      print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])

  print("time: ", stop_time - start_time)
tensorflow/contrib/lite/python/interpreter.py
# Add the following two lines at the end of the file

  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows (SetNumThreads is added)

PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
// Modify the middle of the class declaration as follows (SetNumThreads is added)

  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.

$ ./configure

Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /usr/local/lib/python3.5/dist-packages
  /opt/movidius/caffe/python
Please input the desired Python library path to use.  Default is [/usr/local/lib] /usr/local/lib/python3.5/dist-packages

Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: n
No jemalloc as malloc support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
No Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n
No Hadoop File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]: n
No Amazon AWS Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n
No Apache Kafka Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with XLA JIT support? [y/N]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with GDR support? [y/N]: n
No GDR support will be enabled for TensorFlow.

Do you wish to build TensorFlow with VERBS support? [y/N]: n
No VERBS support will be enabled for TensorFlow.

Do you wish to build TensorFlow with nGraph support? [y/N]: n
No nGraph support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
Configuration finished
$ sudo bazel build --config opt --local_resources 1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ sudo -s
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo pip3 install /tmp/tensorflow_pkg/tensorflow-1.11.0-cp35-cp35m-linux_armv7l.whl

Python3.x + jemalloc + MPI + MultiThread [C++ Only]

Edit tensorflow/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc Line139 / Line140, Line261.

  MPIRendezvousMgr* mgr =
      reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
- mgr->QueueRequest(parsed.FullKey().ToString(), step_id_,
-                   std::move(request_call), rendezvous_call);
+ mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
+                   rendezvous_call);
}
 MPIRemoteRendezvous::~MPIRemoteRendezvous() {}


        std::function<MPISendTensorCall*()> res = std::bind(
            send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
-       SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
+       SendQueueEntry req(string(parsed.FullKey()), std::move(res));
         this->QueueSendRequest(req);

Edit tensorflow/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h Line74

  void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id,
            const bool is_dead) {
-   mRes_.set_key(parsed.FullKey().ToString());
+   mRes_.set_key(string(parsed.FullKey()));
    mRes_.set_step_id(step_id);
    mRes_.mutable_response()->set_is_dead(is_dead);
    mRes_.mutable_response()->set_send_start_micros(

Edit tensorflow/tensorflow/contrib/lite/interpreter.cc Line127.

-  context_.recommended_num_threads = -1;
+  context_.recommended_num_threads = 4;
$ sudo apt-get install -y libhdf5-dev
$ sudo pip3 install keras_applications==1.0.4 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.2 --no-deps
$ sudo pip3 install h5py==2.8.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout v1.11.0
$ ./configure

Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /usr/local/lib/python3.5/dist-packages
  /opt/movidius/caffe/python
Please input the desired Python library path to use.  Default is [/usr/local/lib] /usr/local/lib/python3.5/dist-packages

Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: y
jemalloc as malloc support will be enabled for Tensorflow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
No Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n
No Hadoop File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]: n
No Amazon AWS Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n
No Apache Kafka Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with XLA JIT support? [y/N]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with GDR support? [y/N]: n
No GDR support will be enabled for TensorFlow.

Do you wish to build TensorFlow with VERBS support? [y/N]: n
No VERBS support will be enabled for TensorFlow.

Do you wish to build TensorFlow with nGraph support? [y/N]: n
No nGraph support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: y
MPI support will be enabled for Tensorflow.

Please specify the MPI toolkit folder. [Default is /usr]: /usr/lib/arm-linux-gnueabihf/openmpi

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
Configuration finished
$ sudo bazel build --config opt --local_resources 1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package

Python3.x + jemalloc + XLA JIT (Build impossible)

$ sudo apt-get install -y libhdf5-dev
$ sudo pip3 install keras_applications==1.0.4 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.2 --no-deps
$ sudo pip3 install h5py==2.8.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ JAVA_OPTIONS=-Xmx256M

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout v1.11.0
$ ./configure

Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /usr/local/lib/python3.5/dist-packages
  /opt/movidius/caffe/python
Please input the desired Python library path to use.  Default is [/usr/local/lib] /usr/local/lib/python3.5/dist-packages

Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]: y
jemalloc as malloc support will be enabled for Tensorflow.

Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
No Google Cloud Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Hadoop File System support? [Y/n]: n
No Hadoop File System support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]: n
No Amazon AWS Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]: n
No Apache Kafka Platform support will be enabled for TensorFlow.

Do you wish to build TensorFlow with XLA JIT support? [y/N]: y
XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with GDR support? [y/N]: n
No GDR support will be enabled for TensorFlow.

Do you wish to build TensorFlow with VERBS support? [y/N]: n
No VERBS support will be enabled for TensorFlow.

Do you wish to build TensorFlow with nGraph support? [y/N]: n
No nGraph support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
Configuration finished
$ sudo bazel build --config opt --local_resources 1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package

Python3.x + TX2 aarch64 - Bazel 0.18.1 (JetPack-L4T-3.3-linux-x64_b39)

- L4T R28.2.1(TX2 / TX2i)
- L4T R28.2(TX1)
- CUDA 9.0
- cuDNN 7.1.5
- TensorRT 4.0
- VisionWorks 1.6

References: tensorflow/tensorflow#21574 (comment), tensorflow/serving#832, https://docs.nvidia.com/deeplearning/sdk/nccl-archived/nccl_2213/nccl-install-guide/index.html

build --action_env PYTHON_BIN_PATH="/usr/bin/python3"
build --action_env PYTHON_LIB_PATH="/usr/local/lib/python3.5/dist-packages"
build --python_path="/usr/bin/python3"
build --define with_jemalloc=true
build:gcp --define with_gcp_support=true
build:hdfs --define with_hdfs_support=true
build:aws --define with_aws_support=true
build:kafka --define with_kafka_support=true
build:xla --define with_xla_support=true
build:gdr --define with_gdr_support=true
build:verbs --define with_verbs_support=true
build:ngraph --define with_ngraph_support=true
build --action_env TF_NEED_OPENCL_SYCL="0"
build --action_env TF_NEED_CUDA="1"
build --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-9.0"
build --action_env TF_CUDA_VERSION="9.0"
build --action_env CUDNN_INSTALL_PATH="/usr/lib/aarch64-linux-gnu"
build --action_env TF_CUDNN_VERSION="7"
build --action_env NCCL_INSTALL_PATH="/usr/local"
build --action_env TF_NCCL_VERSION="2"
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="3.5,7.0"
build --action_env LD_LIBRARY_PATH="/usr/local/cuda-9.0/lib64:../src/.libs"
build --action_env TF_CUDA_CLANG="0"
build --action_env GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
build --config=cuda
test --config=cuda
build --define grpc_no_ares=true
build:opt --copt=-march=native
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
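
The listing above looks like the bazelrc file that `./configure` generates (this is an assumption; the file name is not given here). To review the CUDA-related paths before a long build, the `--action_env` settings can be collected into a dictionary. A minimal sketch using the standard `shlex` module; the function name `action_env_settings` is hypothetical.

```python
import shlex


def action_env_settings(bazelrc_text):
    """Collect KEY=VALUE pairs from 'build --action_env KEY="VALUE"' lines."""
    settings = {}
    for line in bazelrc_text.splitlines():
        # shlex removes the surrounding quotes and skips '#' comments
        tokens = shlex.split(line, comments=True)
        if len(tokens) >= 3 and tokens[:2] == ["build", "--action_env"]:
            key, _, value = tokens[2].partition("=")
            settings[key] = value
    return settings


if __name__ == "__main__":
    sample = '\n'.join([
        'build --action_env TF_NEED_CUDA="1"',
        'build --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-9.0"',
        'build --config=cuda',
    ])
    for key, value in action_env_settings(sample).items():
        print(key, "=", value)
```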
$ sudo apt-get install -y libhdf5-dev
$ sudo pip3 install keras_applications==1.0.4 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.2 --no-deps
$ sudo pip3 install h5py==2.8.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ bazel build -c opt --config=cuda --local_resources 3072.0,4.0,1.0 --verbose_failures //tensorflow/tools/pip_package:build_pip_package
Tensorflow v1.12.0

============================================================

Tensorflow v1.12.0 - Bazel 0.18.1

============================================================

Python3.x (Nov 15, 2018 Under construction)

  • tensorflow/BUILD
config_setting(
    name = "no_aws_support",
    define_values = {"no_aws_support": "false"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_gcp_support",
    define_values = {"no_gcp_support": "false"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_hdfs_support",
    define_values = {"no_hdfs_support": "false"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_ignite_support",
    define_values = {"no_ignite_support": "false"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "no_kafka_support",
    define_values = {"no_kafka_support": "false"},
    visibility = ["//visibility:public"],
)
  • bazel.rc
# Options to disable default on features
build:noaws --define=no_aws_support=true
build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nokafka --define=no_kafka_support=true
build:noignite --define=no_ignite_support=true
  • configure.py
  #set_build_var(environ_cp, 'TF_NEED_IGNITE', 'Apache Ignite',
  #              'with_ignite_support', True, 'ignite')


  ## On Windows, we don't have MKL support and the build is always monolithic.
  ## So no need to print the following message.
  ## TODO(pcloudy): remove the following if check when they make sense on Windows
  #if not is_windows():
  #  print('Preconfigured Bazel build configs. You can use any of the below by '
  #        'adding "--config=<>" to your build command. See tools/bazel.rc for '
  #        'more details.')
  #  config_info_line('mkl', 'Build with MKL support.')
  #  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  #  config_info_line('gdr', 'Build with GDR support.')
  #  config_info_line('verbs', 'Build with libverbs support.')
  #  config_info_line('ngraph', 'Build with Intel nGraph support.')
  print('Preconfigured Bazel build configs. You can use any of the below by '
        'adding "--config=<>" to your build command. See .bazelrc for more '
        'details.')
  config_info_line('mkl', 'Build with MKL support.')
  config_info_line('monolithic', 'Config for mostly static monolithic build.')
  config_info_line('gdr', 'Build with GDR support.')
  config_info_line('verbs', 'Build with libverbs support.')
  config_info_line('ngraph', 'Build with Intel nGraph support.')

  print('Preconfigured Bazel build configs to DISABLE default on features:')
  config_info_line('noaws', 'Disable AWS S3 filesystem support.')
  config_info_line('nogcp', 'Disable GCP support.')
  config_info_line('nohdfs', 'Disable HDFS support.')
  config_info_line('noignite', 'Disable Apache Ignite support.')
  config_info_line('nokafka', 'Disable Apache Kafka support.')
  • tensorflow/contrib/BUILD
# Description:
#   contains parts of TensorFlow that are experimental or unstable and which are not supported.

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda")

py_library(
    name = "contrib_py",
    srcs = glob(
        ["**/*.py"],
        exclude = [
            "**/*_test.py",
        ],
    ),
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/contrib/all_reduce",
        "//tensorflow/contrib/batching:batch_py",
        "//tensorflow/contrib/bayesflow:bayesflow_py",
        "//tensorflow/contrib/boosted_trees:init_py",
        "//tensorflow/contrib/checkpoint/python:checkpoint",
        "//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
        "//tensorflow/contrib/coder:coder_py",
        "//tensorflow/contrib/compiler:compiler_py",
        "//tensorflow/contrib/compiler:xla",
        "//tensorflow/contrib/autograph",
        "//tensorflow/contrib/constrained_optimization",
        "//tensorflow/contrib/copy_graph:copy_graph_py",
        "//tensorflow/contrib/crf:crf_py",
        "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
        "//tensorflow/contrib/data",
        "//tensorflow/contrib/deprecated:deprecated_py",
        "//tensorflow/contrib/distribute:distribute",
        "//tensorflow/contrib/distributions:distributions_py",
        "//tensorflow/contrib/eager/python:tfe",
        "//tensorflow/contrib/estimator:estimator_py",
        "//tensorflow/contrib/factorization:factorization_py",
        "//tensorflow/contrib/feature_column:feature_column_py",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/gan",
        "//tensorflow/contrib/graph_editor:graph_editor_py",
        "//tensorflow/contrib/grid_rnn:grid_rnn_py",
        "//tensorflow/contrib/hadoop",
        "//tensorflow/contrib/hooks",
        "//tensorflow/contrib/image:distort_image_py",
        "//tensorflow/contrib/image:image_py",
        "//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
        "//tensorflow/contrib/input_pipeline:input_pipeline_py",
        "//tensorflow/contrib/integrate:integrate_py",
        "//tensorflow/contrib/keras",
        "//tensorflow/contrib/kernel_methods",
        "//tensorflow/contrib/labeled_tensor",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/learn",
        "//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
        "//tensorflow/contrib/libsvm",
        "//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
        "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
        "//tensorflow/contrib/lite/python:lite",
        "//tensorflow/contrib/lookup:lookup_py",
        "//tensorflow/contrib/losses:losses_py",
        "//tensorflow/contrib/losses:metric_learning_py",
        "//tensorflow/contrib/memory_stats:memory_stats_py",
        "//tensorflow/contrib/meta_graph_transform",
        "//tensorflow/contrib/metrics:metrics_py",
        "//tensorflow/contrib/mixed_precision:mixed_precision",
        "//tensorflow/contrib/model_pruning",
        "//tensorflow/contrib/nccl:nccl_py",
        "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
        "//tensorflow/contrib/nn:nn_py",
        "//tensorflow/contrib/opt:opt_py",
        "//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
        "//tensorflow/contrib/periodic_resample:init_py",
        "//tensorflow/contrib/predictor",
        "//tensorflow/contrib/proto",
        "//tensorflow/contrib/quantization:quantization_py",
        "//tensorflow/contrib/quantize:quantize_graph",
        "//tensorflow/contrib/receptive_field:receptive_field_py",
        "//tensorflow/contrib/recurrent:recurrent_py",
        "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
        "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
        "//tensorflow/contrib/resampler:resampler_py",
        "//tensorflow/contrib/rnn:rnn_py",
        "//tensorflow/contrib/rpc",
        "//tensorflow/contrib/saved_model:saved_model_py",
        "//tensorflow/contrib/seq2seq:seq2seq_py",
        "//tensorflow/contrib/signal:signal_py",
        "//tensorflow/contrib/slim",
        "//tensorflow/contrib/slim:nets",
        "//tensorflow/contrib/solvers:solvers_py",
        "//tensorflow/contrib/sparsemax:sparsemax_py",
        "//tensorflow/contrib/specs",
        "//tensorflow/contrib/staging",
        "//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
        "//tensorflow/contrib/stateless",
        "//tensorflow/contrib/summary:summary",
        "//tensorflow/contrib/tensor_forest:init_py",
        "//tensorflow/contrib/tensorboard",
        "//tensorflow/contrib/testing:testing_py",
        "//tensorflow/contrib/text:text_py",
        "//tensorflow/contrib/tfprof",
        "//tensorflow/contrib/timeseries",
        "//tensorflow/contrib/tpu",
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:util",
        "//tensorflow/python/estimator:estimator_py",
    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_kafka_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kafka",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kinesis",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "//tensorflow/contrib/fused_conv:fused_conv_py",
            "//tensorflow/contrib/tensorrt:init_py",
            "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/bigtable",
            "//tensorflow/contrib/cloud:cloud_py",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_ignite_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/ignite",
        ],
    }),
)

cc_library(
    name = "contrib_kernels",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
        "//tensorflow/contrib/coder:all_kernels",
        "//tensorflow/contrib/factorization/kernels:all_kernels",
        "//tensorflow/contrib/hadoop:dataset_kernels",
        "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
        "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
        "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels",
        "//tensorflow/contrib/rnn:all_kernels",
        "//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
        "//tensorflow/contrib/tensor_forest:model_ops_kernels",
        "//tensorflow/contrib/tensor_forest:stats_ops_kernels",
        "//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
        "//tensorflow/contrib/text:all_kernels",
    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
        "//tensorflow/contrib/nccl:nccl_kernels",
    ]) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_kafka_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kafka:dataset_kernels",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kinesis:dataset_kernels",
        ],
    }) + if_not_windows([
        "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
    ]),
)

cc_library(
    name = "contrib_ops_op_lib",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
        "//tensorflow/contrib/coder:all_ops",
        "//tensorflow/contrib/factorization:all_ops",
        "//tensorflow/contrib/framework:all_ops",
        "//tensorflow/contrib/hadoop:dataset_ops_op_lib",
        "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
        "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
        "//tensorflow/contrib/nccl:nccl_ops_op_lib",
        "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
        "//tensorflow/contrib/rnn:all_ops",
        "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
        "//tensorflow/contrib/tensor_forest:model_ops_op_lib",
        "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
        "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
        "//tensorflow/contrib/text:all_ops",
        "//tensorflow/contrib/tpu:all_ops",
    ] + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_kafka_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kafka:dataset_ops_op_lib",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
        ],
    }) + if_not_windows([
        "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
    ]) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_ignite_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/ignite:dataset_ops_op_lib",
        ],
    }),
)
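The `deps = [...] + select({...}) + select({...})` pattern used throughout the rules above can be illustrated with a small Python model. This is a sketch, not Bazel code: `resolve_select` is a hypothetical helper, the `active` set stands in for whichever conditions Bazel considers matched for a given build, and only the first entry of the `contrib_py` dependency list is reproduced.

```python
# Toy model of the `[...] + select({...})` pattern; not Bazel code.

DEFAULT = "//conditions:default"


def resolve_select(branches, active):
    """Return the list of the first branch whose condition is active, else the default.

    Real Bazel rejects ambiguous matches unless one condition specializes the
    others; this sketch simply takes the first hit.
    """
    for condition, deps in branches.items():
        if condition != DEFAULT and condition in active:
            return deps
    return branches[DEFAULT]


# Branches copied from the Kafka and Ignite select() blocks of contrib_py.
KAFKA = {
    "//tensorflow:android": [],
    "//tensorflow:ios": [],
    "//tensorflow:linux_s390x": [],
    "//tensorflow:windows": [],
    "//tensorflow:no_kafka_support": [],
    DEFAULT: ["//tensorflow/contrib/kafka"],
}
IGNITE = {
    "//tensorflow:android": [],
    "//tensorflow:ios": [],
    "//tensorflow:linux_s390x": [],
    "//tensorflow:windows": [],
    "//tensorflow:no_ignite_support": [],
    DEFAULT: ["//tensorflow/contrib/ignite"],
}

base = ["//tensorflow/contrib/all_reduce"]  # first entry of the unconditional list only
active = {"//tensorflow:no_kafka_support"}  # assumed: only the Kafka switch is matched

print(base + resolve_select(KAFKA, active) + resolve_select(IGNITE, active))
```

With only the Kafka condition active, the Kafka branch contributes an empty list while the Ignite block falls through to its default, so the Ignite dependency stays in the result.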
  • tensorflow/core/platform/default/build_config.bzl
# Platform-specific build configurations.

load("@protobuf_archive//:protobuf.bzl", "proto_gen")
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
load("//tensorflow:tensorflow.bzl", "if_windows")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
    "//third_party/mkl:build_defs.bzl",
    "if_mkl_ml",
)

# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
    tf_deps = []

    # If the package name is in shorthand form (ie: does not contain a ':'),
    # expand it to the full name.
    for dep in deps:
        tf_dep = dep

        if not ":" in dep:
            dep_pieces = dep.split("/")
            tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]

        tf_deps += [tf_dep + suffix]

    return tf_deps

# Modified from @cython//:Tools/rules.bzl
def pyx_library(
        name,
        deps = [],
        py_deps = [],
        srcs = [],
        **kwargs):
    """Compiles a group of .pyx / .pxd / .py files.

    First runs Cython to create .cpp files for each input .pyx or .py + .pxd
    pair. Then builds a shared object for each, passing "deps" to each cc_binary
    rule (includes Python headers by default). Finally, creates a py_library rule
    with the shared objects and any pure Python "srcs", with py_deps as its
    dependencies; the shared objects can be imported like normal Python files.

    Args:
      name: Name for the rule.
      deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
      py_deps: Pure Python dependencies of the final library.
      srcs: .py, .pyx, or .pxd files to either compile or pass through.
      **kwargs: Extra keyword arguments passed to the py_library.
    """

    # First filter out files that should be run compiled vs. passed through.
    py_srcs = []
    pyx_srcs = []
    pxd_srcs = []
    for src in srcs:
        if src.endswith(".pyx") or (src.endswith(".py") and
                                    src[:-3] + ".pxd" in srcs):
            pyx_srcs.append(src)
        elif src.endswith(".py"):
            py_srcs.append(src)
        else:
            pxd_srcs.append(src)
        if src.endswith("__init__.py"):
            pxd_srcs.append(src)

    # Invoke cython to produce the shared object libraries.
    for filename in pyx_srcs:
        native.genrule(
            name = filename + "_cython_translation",
            srcs = [filename],
            outs = [filename.split(".")[0] + ".cpp"],
            # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
            # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
            cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
            tools = ["@cython//:cython_binary"] + pxd_srcs,
        )

    shared_objects = []
    for src in pyx_srcs:
        stem = src.split(".")[0]
        shared_object_name = stem + ".so"
        native.cc_binary(
            name = shared_object_name,
            srcs = [stem + ".cpp"],
            deps = deps + ["//third_party/python_runtime:headers"],
            linkshared = 1,
        )
        shared_objects.append(shared_object_name)

    # Now create a py_library with these shared objects as data.
    native.py_library(
        name = name,
        srcs = py_srcs,
        deps = py_deps,
        srcs_version = "PY2AND3",
        data = shared_objects,
        **kwargs
    )

def _proto_cc_hdrs(srcs, use_grpc_plugin = False):
    ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
    if use_grpc_plugin:
        ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
    return ret

def _proto_cc_srcs(srcs, use_grpc_plugin = False):
    ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
    if use_grpc_plugin:
        ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
    return ret

def _proto_py_outs(srcs, use_grpc_plugin = False):
    ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
    if use_grpc_plugin:
        ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
    return ret

# Re-defined protocol buffer rule to allow building "header only" protocol
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
# containing select() statements.
def cc_proto_library(
        name,
        srcs = [],
        deps = [],
        cc_libs = [],
        include = None,
        protoc = "@protobuf_archive//:protoc",
        internal_bootstrap_hack = False,
        use_grpc_plugin = False,
        use_grpc_namespace = False,
        default_header = False,
        **kargs):
    """Bazel rule to create a C++ protobuf library from proto source files.

    Args:
      name: the name of the cc_proto_library.
      srcs: the .proto files of the cc_proto_library.
      deps: a list of dependency labels; must be cc_proto_library.
      cc_libs: a list of other cc_library targets depended by the generated
          cc_library.
      include: a string indicating the include path of the .proto files.
      protoc: the label of the protocol compiler to generate the sources.
      internal_bootstrap_hack: a flag indicating the cc_proto_library is used only
          for bootstrapping. When it is set to True, no files will be generated.
          The rule will simply be a provider for .proto files, so that other
          cc_proto_library can depend on it.
      use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
          when processing the proto files.
      default_header: Controls the naming of generated rules. If True, the `name`
          rule will be header-only, and an _impl rule will contain the
          implementation. Otherwise the header-only rule (name + "_headers_only")
          must be referred to explicitly.
      **kargs: other keyword arguments that are passed to cc_library.
    """

    includes = []
    if include != None:
        includes = [include]

    if internal_bootstrap_hack:
        # For pre-checked-in generated files, we add the internal_bootstrap_hack
        # which will skip the codegen action.
        proto_gen(
            name = name + "_genproto",
            srcs = srcs,
            includes = includes,
            protoc = protoc,
            visibility = ["//visibility:public"],
            deps = [s + "_genproto" for s in deps],
        )

        # An empty cc_library to make rule dependency consistent.
        native.cc_library(
            name = name,
            **kargs
        )
        return

    grpc_cpp_plugin = None
    plugin_options = []
    if use_grpc_plugin:
        grpc_cpp_plugin = "//external:grpc_cpp_plugin"
        if use_grpc_namespace:
            plugin_options = ["services_namespace=grpc"]

    gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
    gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
    outs = gen_srcs + gen_hdrs

    proto_gen(
        name = name + "_genproto",
        srcs = srcs,
        outs = outs,
        gen_cc = 1,
        includes = includes,
        plugin = grpc_cpp_plugin,
        plugin_language = "grpc",
        plugin_options = plugin_options,
        protoc = protoc,
        visibility = ["//visibility:public"],
        deps = [s + "_genproto" for s in deps],
    )

    if use_grpc_plugin:
        cc_libs += select({
            "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
            "//conditions:default": ["//external:grpc_lib"],
        })

    if default_header:
        header_only_name = name
        impl_name = name + "_impl"
    else:
        header_only_name = name + "_headers_only"
        impl_name = name

    native.cc_library(
        name = impl_name,
        srcs = gen_srcs,
        hdrs = gen_hdrs,
        deps = cc_libs + deps,
        includes = includes,
        **kargs
    )
    native.cc_library(
        name = header_only_name,
        deps = ["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
        hdrs = gen_hdrs,
        **kargs
    )

# Re-defined protocol buffer rule to bring in the change introduced in commit
# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
# which was not part of a stable protobuf release in 04/2018.
# TODO(jsimsa): Remove this once the protobuf dependency version is updated
# to include the above commit.
def py_proto_library(
        name,
        srcs = [],
        deps = [],
        py_libs = [],
        py_extra_srcs = [],
        include = None,
        default_runtime = "@protobuf_archive//:protobuf_python",
        protoc = "@protobuf_archive//:protoc",
        use_grpc_plugin = False,
        **kargs):
    """Bazel rule to create a Python protobuf library from proto source files

    NOTE: the rule is only an internal workaround to generate protos. The
    interface may change and the rule may be removed when bazel has introduced
    the native rule.

    Args:
      name: the name of the py_proto_library.
      srcs: the .proto files of the py_proto_library.
      deps: a list of dependency labels; must be py_proto_library.
      py_libs: a list of other py_library targets depended by the generated
          py_library.
      py_extra_srcs: extra source files that will be added to the output
          py_library. This attribute is used for internal bootstrapping.
      include: a string indicating the include path of the .proto files.
      default_runtime: the implicitly default runtime which will be depended on by
          the generated py_library target.
      protoc: the label of the protocol compiler to generate the sources.
      use_grpc_plugin: a flag to indicate whether to call the grpc Python plugin
          when processing the proto files.
      **kargs: other keyword arguments that are passed to py_library.
    """
    outs = _proto_py_outs(srcs, use_grpc_plugin)

    includes = []
    if include != None:
        includes = [include]

    grpc_python_plugin = None
    if use_grpc_plugin:
        grpc_python_plugin = "//external:grpc_python_plugin"
        # Note: Generated grpc code depends on Python grpc module. This dependency
        # is not explicitly listed in py_libs. Instead, host system is assumed to
        # have grpc installed.

    proto_gen(
        name = name + "_genproto",
        srcs = srcs,
        outs = outs,
        gen_py = 1,
        includes = includes,
        plugin = grpc_python_plugin,
        plugin_language = "grpc",
        protoc = protoc,
        visibility = ["//visibility:public"],
        deps = [s + "_genproto" for s in deps],
    )

    if default_runtime and not default_runtime in py_libs + deps:
        py_libs = py_libs + [default_runtime]

    native.py_library(
        name = name,
        srcs = outs + py_extra_srcs,
        deps = py_libs + deps,
        imports = includes,
        **kargs
    )

def tf_proto_library_cc(
        name,
        srcs = [],
        has_services = None,
        protodeps = [],
        visibility = [],
        testonly = 0,
        cc_libs = [],
        cc_stubby_versions = None,
        cc_grpc_version = None,
        j2objc_api_version = 1,
        cc_api_version = 2,
        dart_api_version = 2,
        java_api_version = 2,
        py_api_version = 2,
        js_api_version = 2,
        js_codegen = "jspb",
        default_header = False):
    js_codegen = js_codegen  # unused argument
    js_api_version = js_api_version  # unused argument
    native.filegroup(
        name = name + "_proto_srcs",
        srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
        testonly = testonly,
        visibility = visibility,
    )

    use_grpc_plugin = None
    if cc_grpc_version:
        use_grpc_plugin = True

    cc_deps = tf_deps(protodeps, "_cc")
    cc_name = name + "_cc"
    if not srcs:
        # This is a collection of sub-libraries. Build header-only and impl
        # libraries containing all the sources.
        proto_gen(
            name = cc_name + "_genproto",
            protoc = "@protobuf_archive//:protoc",
            visibility = ["//visibility:public"],
            deps = [s + "_genproto" for s in cc_deps],
        )
        native.cc_library(
            name = cc_name,
            deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] + if_static([name + "_cc_impl"]),
            testonly = testonly,
            visibility = visibility,
        )
        native.cc_library(
            name = cc_name + "_impl",
            deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
        )

        return

    cc_proto_library(
        name = cc_name,
        testonly = testonly,
        srcs = srcs,
        cc_libs = cc_libs + if_static(
            ["@protobuf_archive//:protobuf"],
            ["@protobuf_archive//:protobuf_headers"],
        ),
        copts = if_not_windows([
            "-Wno-unknown-warning-option",
            "-Wno-unused-but-set-variable",
            "-Wno-sign-compare",
        ]),
        default_header = default_header,
        protoc = "@protobuf_archive//:protoc",
        use_grpc_plugin = use_grpc_plugin,
        visibility = visibility,
        deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
    )

def tf_proto_library_py(
        name,
        srcs = [],
        protodeps = [],
        deps = [],
        visibility = [],
        testonly = 0,
        srcs_version = "PY2AND3",
        use_grpc_plugin = False):
    py_deps = tf_deps(protodeps, "_py")
    py_name = name + "_py"
    if not srcs:
        # This is a collection of sub-libraries. Build header-only and impl
        # libraries containing all the sources.
        proto_gen(
            name = py_name + "_genproto",
            protoc = "@protobuf_archive//:protoc",
            visibility = ["//visibility:public"],
            deps = [s + "_genproto" for s in py_deps],
        )
        native.py_library(
            name = py_name,
            deps = py_deps + ["@protobuf_archive//:protobuf_python"],
            testonly = testonly,
            visibility = visibility,
        )
        return

    py_proto_library(
        name = py_name,
        testonly = testonly,
        srcs = srcs,
        default_runtime = "@protobuf_archive//:protobuf_python",
        protoc = "@protobuf_archive//:protoc",
        srcs_version = srcs_version,
        use_grpc_plugin = use_grpc_plugin,
        visibility = visibility,
        deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
    )

def tf_jspb_proto_library(**kwargs):
    pass

def tf_nano_proto_library(**kwargs):
    pass

def tf_proto_library(
        name,
        srcs = [],
        has_services = None,
        protodeps = [],
        visibility = [],
        testonly = 0,
        cc_libs = [],
        cc_api_version = 2,
        cc_grpc_version = None,
        dart_api_version = 2,
        j2objc_api_version = 1,
        java_api_version = 2,
        py_api_version = 2,
        js_api_version = 2,
        js_codegen = "jspb",
        provide_cc_alias = False,
        default_header = False):
    """Make a proto library, possibly depending on other proto libraries."""
    _ignore = (js_api_version, js_codegen, provide_cc_alias)

    tf_proto_library_cc(
        name = name,
        testonly = testonly,
        srcs = srcs,
        cc_grpc_version = cc_grpc_version,
        cc_libs = cc_libs,
        default_header = default_header,
        protodeps = protodeps,
        visibility = visibility,
    )

    tf_proto_library_py(
        name = name,
        testonly = testonly,
        srcs = srcs,
        protodeps = protodeps,
        srcs_version = "PY2AND3",
        use_grpc_plugin = has_services,
        visibility = visibility,
    )

# A list of all files under platform matching the pattern in 'files'. In
# contrast with 'tf_platform_srcs' below, which selectively collects files that
# must be compiled in the 'default' platform, this is a list of all headers
# mentioned in the platform/* files.
def tf_platform_hdrs(files):
    return native.glob(["platform/*/" + f for f in files])

def tf_platform_srcs(files):
    base_set = ["platform/default/" + f for f in files]
    windows_set = base_set + ["platform/windows/" + f for f in files]
    posix_set = base_set + ["platform/posix/" + f for f in files]

    # Handle cases where we must also bring the posix file in. Usually, the list
    # of files to build on windows builds is just all the stuff in the
    # windows_set. However, in some cases the implementations in 'posix/' are
    # just what is necessary and historically we choose to simply use the posix
    # file instead of making a copy in 'windows'.
    for f in files:
        if f == "error.cc":
            windows_set.append("platform/posix/" + f)

    return select({
        "//tensorflow:windows": native.glob(windows_set),
        "//conditions:default": native.glob(posix_set),
    })

def tf_additional_lib_hdrs(exclude = []):
    windows_hdrs = native.glob([
        "platform/default/*.h",
        "platform/windows/*.h",
        "platform/posix/error.h",
    ], exclude = exclude)
    return select({
        "//tensorflow:windows": windows_hdrs,
        "//conditions:default": native.glob([
            "platform/default/*.h",
            "platform/posix/*.h",
        ], exclude = exclude),
    })

def tf_additional_lib_srcs(exclude = []):
    windows_srcs = native.glob([
        "platform/default/*.cc",
        "platform/windows/*.cc",
        "platform/posix/error.cc",
    ], exclude = exclude)
    return select({
        "//tensorflow:windows": windows_srcs,
        "//conditions:default": native.glob([
            "platform/default/*.cc",
            "platform/posix/*.cc",
        ], exclude = exclude),
    })

def tf_additional_minimal_lib_srcs():
    return [
        "platform/default/integral_types.h",
        "platform/default/mutex.h",
    ]

def tf_additional_proto_hdrs():
    return [
        "platform/default/integral_types.h",
        "platform/default/logging.h",
        "platform/default/protobuf.h",
    ] + if_windows([
        "platform/windows/integral_types.h",
    ])

def tf_additional_proto_compiler_hdrs():
    return [
        "platform/default/protobuf_compiler.h",
    ]

def tf_additional_proto_srcs():
    return [
        "platform/default/protobuf.cc",
    ]

def tf_additional_human_readable_json_deps():
    return []

def tf_additional_all_protos():
    return ["//tensorflow/core:protos_all"]

def tf_protos_all_impl():
    return ["//tensorflow/core:protos_all_cc_impl"]

def tf_protos_all():
    return if_static(
        extra_deps = tf_protos_all_impl(),
        otherwise = ["//tensorflow/core:protos_all_cc"],
    )

def tf_protos_grappler_impl():
    return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]

def tf_protos_grappler():
    return if_static(
        extra_deps = tf_protos_grappler_impl(),
        otherwise = ["//tensorflow/core/grappler/costs:op_performance_data_cc"],
    )

def tf_additional_cupti_wrapper_deps():
    return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]

def tf_additional_device_tracer_srcs():
    return ["platform/default/device_tracer.cc"]

def tf_additional_device_tracer_cuda_deps():
    return []

def tf_additional_device_tracer_deps():
    return []

def tf_additional_libdevice_data():
    return []

def tf_additional_libdevice_deps():
    return ["@local_config_cuda//cuda:cuda_headers"]

def tf_additional_libdevice_srcs():
    return ["platform/default/cuda_libdevice_path.cc"]

def tf_additional_test_deps():
    return []

def tf_additional_test_srcs():
    return [
        "platform/default/test_benchmark.cc",
    ] + select({
        "//tensorflow:windows": [
            "platform/windows/test.cc",
        ],
        "//conditions:default": [
            "platform/posix/test.cc",
        ],
    })

def tf_kernel_tests_linkstatic():
    return 0

def tf_additional_lib_defines():
    """Additional defines needed to build TF libraries."""
    return []

def tf_additional_lib_deps():
    """Additional dependencies needed to build TF libraries."""
    return [
        "@com_google_absl//absl/base:base",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:optional",
    ] + if_static(
        ["@nsync//:nsync_cpp"],
        ["@nsync//:nsync_headers"],
    )

def tf_additional_core_deps():
    return select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "//tensorflow/core/platform/cloud:gcs_file_system",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_hdfs_support": [],
        "//conditions:default": [
            "//tensorflow/core/platform/hadoop:hadoop_file_system",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "//tensorflow/core/platform/s3:s3_file_system",
        ],
    })

# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
    return select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
            "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
        ],
    })

# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
    return select({
        "//tensorflow:android": [],
        "//tensorflow:windows": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//conditions:default": [
            "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
            "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
        ],
    })

def tf_lib_proto_parsing_deps():
    return [
        ":protos_all_cc",
        "//third_party/eigen3",
        "//tensorflow/core/platform/default/build_config:proto_parsing",
    ]

def tf_lib_proto_compiler_deps():
    return [
        "@protobuf_archive//:protoc_lib",
    ]

def tf_additional_verbs_lib_defines():
    return select({
        "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
        "//conditions:default": [],
    })

def tf_additional_mpi_lib_defines():
    return select({
        "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
        "//conditions:default": [],
    })

def tf_additional_gdr_lib_defines():
    return select({
        "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
        "//conditions:default": [],
    })

def tf_py_clif_cc(name, visibility = None, **kwargs):
    pass

def tf_pyclif_proto_library(
        name,
        proto_lib,
        proto_srcfile = "",
        visibility = None,
        **kwargs):
    pass

def tf_additional_binary_deps():
    return ["@nsync//:nsync_cpp"] + if_cuda(
        [
            "//tensorflow/stream_executor:cuda_platform",
            "//tensorflow/core/platform/default/build_config:cuda",
        ],
    ) + [
        # TODO(allenl): Split these out into their own shared objects (they are
        # here because they are shared between contrib/ op shared objects and
        # core).
        "//tensorflow/core/kernels:lookup_util",
        "//tensorflow/core/util/tensor_bundle",
    ] + if_mkl_ml(
        [
            "//third_party/mkl:intel_binary_blob",
        ],
    )
  • tensorflow/tools/lib_package/BUILD
# Packaging for TensorFlow artifacts other than the Python API (pip whl).
# This includes the C API, Java API, and protocol buffer files.

package(default_visibility = ["//visibility:private"])

load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("//third_party/mkl:build_defs.bzl", "if_mkl")

genrule(
    name = "libtensorflow_proto",
    srcs = ["//tensorflow/core:protos_all_proto_srcs"],
    outs = ["libtensorflow_proto.zip"],
    cmd = "zip $@ $(SRCS)",
)

pkg_tar(
    name = "libtensorflow",
    extension = "tar.gz",
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
    deps = [
        ":cheaders",
        ":clib",
        ":clicenses",
        ":eager_cheaders",
    ],
)

pkg_tar(
    name = "libtensorflow_jni",
    extension = "tar.gz",
    files = [
        "include/tensorflow/jni/LICENSE",
        "//tensorflow/java:libtensorflow_jni",
    ],
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
    deps = [":common_deps"],
)

# Shared objects that all TensorFlow libraries depend on.
pkg_tar(
    name = "common_deps",
    files = tf_binary_additional_srcs(),
    tags = ["manual"],
)

pkg_tar(
    name = "cheaders",
    files = [
        "//tensorflow/c:headers",
    ],
    package_dir = "include/tensorflow/c",
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
)

pkg_tar(
    name = "eager_cheaders",
    files = [
        "//tensorflow/c/eager:headers",
    ],
    package_dir = "include/tensorflow/c/eager",
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
)

pkg_tar(
    name = "clib",
    files = ["//tensorflow:libtensorflow.so"],
    package_dir = "lib",
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
    deps = [":common_deps"],
)

pkg_tar(
    name = "clicenses",
    files = [":include/tensorflow/c/LICENSE"],
    package_dir = "include/tensorflow/c",
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    tags = ["manual"],
)

genrule(
    name = "clicenses_generate",
    srcs = [
        "//third_party/hadoop:LICENSE.txt",
        "//third_party/eigen3:LICENSE",
        "//third_party/fft2d:LICENSE",
        "@boringssl//:LICENSE",
        "@com_googlesource_code_re2//:LICENSE",
        "@curl//:COPYING",
        "@double_conversion//:LICENSE",
        "@eigen_archive//:COPYING.MPL2",
        "@farmhash_archive//:COPYING",
        "@fft2d//:fft/readme.txt",
        "@gemmlowp//:LICENSE",
        "@gif_archive//:COPYING",
        "@highwayhash//:LICENSE",
        "@icu//:icu4c/LICENSE",
        "@jpeg//:LICENSE.md",
        "@llvm//:LICENSE.TXT",
        "@lmdb//:LICENSE",
        "@local_config_sycl//sycl:LICENSE.text",
        "@nasm//:LICENSE",
        "@nsync//:LICENSE",
        "@png_archive//:LICENSE",
        "@protobuf_archive//:LICENSE",
        "@snappy//:COPYING",
        "@zlib_archive//:zlib.h",
    ] + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "@aws//:LICENSE",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
        ],
    }) + select({
        "//tensorflow/core/kernels:xsmm": [
            "@libxsmm_archive//:LICENSE.md",
        ],
        "//conditions:default": [],
    }) + if_cuda([
        "@cub_archive//:LICENSE.TXT",
    ]) + if_mkl([
        "//third_party/mkl:LICENSE",
        "//third_party/mkl_dnn:LICENSE",
    ]) + if_not_system_lib(
        "grpc",
        [
            "@grpc//:LICENSE",
            "@grpc//third_party/nanopb:LICENSE.txt",
            "@grpc//third_party/address_sorting:LICENSE",
        ],
    ),
    outs = ["include/tensorflow/c/LICENSE"],
    cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
    tools = [":concat_licenses.sh"],
)

genrule(
    name = "jnilicenses_generate",
    srcs = [
        "//third_party/hadoop:LICENSE.txt",
        "//third_party/eigen3:LICENSE",
        "//third_party/fft2d:LICENSE",
        "@boringssl//:LICENSE",
        "@com_googlesource_code_re2//:LICENSE",
        "@curl//:COPYING",
        "@double_conversion//:LICENSE",
        "@eigen_archive//:COPYING.MPL2",
        "@farmhash_archive//:COPYING",
        "@fft2d//:fft/readme.txt",
        "@gemmlowp//:LICENSE",
        "@gif_archive//:COPYING",
        "@highwayhash//:LICENSE",
        "@icu//:icu4j/main/shared/licenses/LICENSE",
        "@jpeg//:LICENSE.md",
        "@llvm//:LICENSE.TXT",
        "@lmdb//:LICENSE",
        "@local_config_sycl//sycl:LICENSE.text",
        "@nasm//:LICENSE",
        "@nsync//:LICENSE",
        "@png_archive//:LICENSE",
        "@protobuf_archive//:LICENSE",
        "@snappy//:COPYING",
        "@zlib_archive//:zlib.h",
    ] + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "@aws//:LICENSE",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
        ],
    }) + select({
        "//tensorflow/core/kernels:xsmm": [
            "@libxsmm_archive//:LICENSE.md",
        ],
        "//conditions:default": [],
    }) + if_cuda([
        "@cub_archive//:LICENSE.TXT",
    ]) + if_mkl([
        "//third_party/mkl:LICENSE",
        "//third_party/mkl_dnn:LICENSE",
    ]),
    outs = ["include/tensorflow/jni/LICENSE"],
    cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
    tools = [":concat_licenses.sh"],
)

sh_test(
    name = "libtensorflow_test",
    size = "small",
    srcs = ["libtensorflow_test.sh"],
    data = [
        "libtensorflow_test.c",
        ":libtensorflow.tar.gz",
    ],
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    # Till then, this test is explicitly executed when building
    # the release by tensorflow/tools/ci_build/builds/libtensorflow.sh
    tags = ["manual"],
)

sh_test(
    name = "libtensorflow_java_test",
    size = "small",
    srcs = ["libtensorflow_java_test.sh"],
    data = [
        ":LibTensorFlowTest.java",
        ":libtensorflow_jni.tar.gz",
        "//tensorflow/java:libtensorflow.jar",
    ],
    # Mark as "manual" till
    # https://github.com/bazelbuild/bazel/issues/2352
    # and https://github.com/bazelbuild/bazel/issues/1580
    # are resolved, otherwise these rules break when built
    # with Python 3.
    # Till then, this test is explicitly executed when building
    # the release by tensorflow/tools/ci_build/builds/libtensorflow.sh
    tags = ["manual"],
)
  • tensorflow/tools/pip_package/BUILD
# Description:
#  Tools for building the TensorFlow pip package.

package(default_visibility = ["//visibility:private"])

load(
    "//tensorflow:tensorflow.bzl",
    "if_not_windows",
    "if_windows",
    "transitive_hdrs",
)
load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
load(
    "//third_party/ngraph:build_defs.bzl",
    "if_ngraph",
)

# This returns a list of headers of all public header libraries (e.g.,
# framework, lib), and all of the transitive dependencies of those
# public headers.  Not all of the headers returned by the filegroup
# are public (e.g., internal headers that are included by public
# headers), but the internal headers need to be packaged in the
# pip_package for the public headers to be properly included.
#
# Public headers are therefore defined by those that are both:
#
# 1) "publicly visible" as defined by bazel
# 2) Have documentation.
#
# This matches the policy of "public" for our python API.
transitive_hdrs(
    name = "included_headers",
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:stream_executor",
        "//third_party/eigen3",
    ] + if_cuda([
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)

py_binary(
    name = "simple_console",
    srcs = ["simple_console.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

COMMON_PIP_DEPS = [
    ":licenses",
    "MANIFEST.in",
    "README",
    "setup.py",
    ":included_headers",
    "//tensorflow:tensorflow_py",
    "//tensorflow/contrib/autograph:autograph",
    "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
    "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
    "//tensorflow/contrib/compiler:xla",
    "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
    "//tensorflow/contrib/eager/python/examples:examples_pip",
    "//tensorflow/contrib/eager/python:evaluator",
    "//tensorflow/contrib/gan:gan",
    "//tensorflow/contrib/graph_editor:graph_editor_pip",
    "//tensorflow/contrib/keras:keras",
    "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
    "//tensorflow/contrib/nn:nn_py",
    "//tensorflow/contrib/predictor:predictor_pip",
    "//tensorflow/contrib/proto:proto",
    "//tensorflow/contrib/receptive_field:receptive_field_pip",
    "//tensorflow/contrib/rate:rate",
    "//tensorflow/contrib/rpc:rpc_pip",
    "//tensorflow/contrib/session_bundle:session_bundle_pip",
    "//tensorflow/contrib/signal:signal_py",
    "//tensorflow/contrib/signal:test_util",
    "//tensorflow/contrib/slim:slim",
    "//tensorflow/contrib/slim/python/slim/data:data_pip",
    "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
    "//tensorflow/contrib/specs:specs",
    "//tensorflow/contrib/summary:summary_test_util",
    "//tensorflow/contrib/tensor_forest:init_py",
    "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
    "//tensorflow/contrib/timeseries:timeseries_pip",
    "//tensorflow/contrib/tpu",
    "//tensorflow/examples/tutorials/mnist:package",
    # "//tensorflow/python/autograph/converters:converters",
    # "//tensorflow/python/autograph/core:core",
    "//tensorflow/python/autograph/core:test_lib",
    # "//tensorflow/python/autograph/impl:impl",
    # "//tensorflow/python/autograph/lang:lang",
    # "//tensorflow/python/autograph/operators:operators",
    # "//tensorflow/python/autograph/pyct:pyct",
    # "//tensorflow/python/autograph/pyct/testing:testing",
    # "//tensorflow/python/autograph/pyct/static_analysis:static_analysis",
    "//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
    "//tensorflow/python:cond_v2",
    "//tensorflow/python:distributed_framework_test_lib",
    "//tensorflow/python:meta_graph_testdata",
    "//tensorflow/python:spectral_ops_test_util",
    "//tensorflow/python:util_example_parser_configuration",
    "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
    "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
    "//tensorflow/python/data/kernel_tests:test_base",
    "//tensorflow/python/debug:debug_pip",
    "//tensorflow/python/eager:eager_pip",
    "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
    "//tensorflow/python/saved_model:saved_model",
    "//tensorflow/python/tools:tools_pip",
    "//tensorflow/python/tools/api/generator:create_python_api",
    "//tensorflow/python:test_ops",
    "//tensorflow/python:while_v2",
    "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]

# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree
# for building the pip package on Windows.
py_binary(
    name = "simple_console_for_windows",
    srcs = ["simple_console_for_windows.py"],
    data = COMMON_PIP_DEPS,
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

filegroup(
    name = "licenses",
    data = [
        "//third_party/eigen3:LICENSE",
        "//third_party/fft2d:LICENSE",
        "//third_party/hadoop:LICENSE.txt",
        "@absl_py//absl/flags:LICENSE",
        "@arm_neon_2_x86_sse//:LICENSE",
        "@astor_archive//:LICENSE",
        "@boringssl//:LICENSE",
        "@com_google_absl//:LICENSE",
        "@com_googlesource_code_re2//:LICENSE",
        "@curl//:COPYING",
        "@double_conversion//:LICENSE",
        "@eigen_archive//:COPYING.MPL2",
        "@farmhash_archive//:COPYING",
        "@fft2d//:fft/readme.txt",
        "@flatbuffers//:LICENSE.txt",
        "@gast_archive//:PKG-INFO",
        "@gemmlowp//:LICENSE",
        "@gif_archive//:COPYING",
        "@highwayhash//:LICENSE",
        "@icu//:icu4c/LICENSE",
        "@jpeg//:LICENSE.md",
        "@lmdb//:LICENSE",
        "@local_config_sycl//sycl:LICENSE.text",
        "@nasm//:LICENSE",
        "@nsync//:LICENSE",
        "@pcre//:LICENCE",
        "@png_archive//:LICENSE",
        "@protobuf_archive//:LICENSE",
        "@six_archive//:LICENSE",
        "@snappy//:COPYING",
        "@swig//:LICENSE",
        "@termcolor_archive//:COPYING.txt",
        "@zlib_archive//:zlib.h",
        "@org_python_pypi_backports_weakref//:LICENSE",
    ] + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_aws_support": [],
        "//conditions:default": [
            "@aws//:LICENSE",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_gcp_support": [],
        "//conditions:default": [
            "@com_github_googleapis_googleapis//:LICENSE",
            "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
        ],
    }) + select({
        "//tensorflow:android": [],
        "//tensorflow:ios": [],
        "//tensorflow:linux_s390x": [],
        "//tensorflow:windows": [],
        "//tensorflow:no_kafka_support": [],
        "//conditions:default": [
            "@kafka//:LICENSE",
        ],
    }) + select({
        "//tensorflow/core/kernels:xsmm": [
            "@libxsmm_archive//:LICENSE.md",
        ],
        "//conditions:default": [],
    }) + if_cuda([
        "@cub_archive//:LICENSE.TXT",
        "@local_config_nccl//:LICENSE",
    ]) + if_mkl([
        "//third_party/mkl:LICENSE",
        "//third_party/mkl_dnn:LICENSE",
    ]) + if_not_system_lib(
        "grpc",
        [
            "@grpc//:LICENSE",
            "@grpc//third_party/nanopb:LICENSE.txt",
            "@grpc//third_party/address_sorting:LICENSE",
        ],
    ) + if_ngraph([
        "@ngraph//:LICENSE",
        "@ngraph_tf//:LICENSE",
        "@nlohmann_json_lib//:LICENSE.MIT",
        "@tbb//:LICENSE",
    ]) + tf_additional_license_deps(),
)

sh_binary(
    name = "build_pip_package",
    srcs = ["build_pip_package.sh"],
    data = select({
        "//tensorflow:windows": [
            ":simple_console_for_windows",
            "//tensorflow/contrib/lite/python:interpreter_test_data",
            "//tensorflow/contrib/lite/python:tflite_convert",
            "//tensorflow/contrib/lite/toco/python:toco_from_protos",
        ],
        "//conditions:default": COMMON_PIP_DEPS + [
            ":simple_console",
            "//tensorflow/contrib/lite/python:interpreter_test_data",
            "//tensorflow/contrib/lite/python:tflite_convert",
            "//tensorflow/contrib/lite/toco/python:toco_from_protos",
        ],
    }) + if_mkl_ml(["//third_party/mkl:intel_binary_blob"]),
)

# A genrule for generating a marker file for the pip package on Windows
#
# This only works on Windows, because :simple_console_for_windows is a
# python zip file containing everything we need for building the pip package.
# However, on other platforms, due to https://github.com/bazelbuild/bazel/issues/4223,
# when C++ extensions change, this genrule doesn't rebuild.
genrule(
    name = "win_pip_package_marker",
    srcs = if_windows([
        ":build_pip_package",
        ":simple_console_for_windows",
    ]),
    outs = ["win_pip_package_marker_file"],
    cmd = select({
        "//conditions:default": "touch $@",
        "//tensorflow:windows": "md5sum $(locations :build_pip_package) $(locations :simple_console_for_windows) > $@",
    }),
    visibility = ["//visibility:public"],
)
$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot
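
After the reboot it can be worth confirming that the enlarged swap is actually active before starting the long Bazel build. The sketch below is an illustration, not part of the original procedure: it parses `/proc/meminfo`-style text with a hypothetical helper, `swap_total_mb`, and assumes the standard Linux `SwapTotal: <n> kB` line.

```python
import os


def swap_total_mb(meminfo_text):
    """Return the SwapTotal value in MiB parsed from /proc/meminfo-style text.

    Returns 0 if no SwapTotal line is present.
    """
    for line in meminfo_text.splitlines():
        if line.startswith("SwapTotal:"):
            # Line format: "SwapTotal:       2097148 kB"
            kib = int(line.split()[1])
            return kib // 1024
    return 0


if __name__ == "__main__":
    # /proc/meminfo only exists on Linux; skip quietly elsewhere.
    if os.path.exists("/proc/meminfo"):
        with open("/proc/meminfo") as f:
            print("Swap total: {} MiB".format(swap_total_mb(f.read())))
```

The total covers every active swap device, so the figure can be larger than the `CONF_SWAPFILE` value when other swap devices are enabled.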

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow

$ ./configure
WARNING: Processed legacy workspace file /home/pi/work/tensorflow/tools/bazel.rc. This file will not be processed in the next release of Bazel. Please read https://github.com/bazelbuild/bazel/issues/6319 for further information, including how to upgrade.
WARNING: Running Bazel server needs to be killed, because the startup options are different.
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.18.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /home/pi/inference_engine_vpu_arm/python/python3.5/armv7l
  /usr/local/lib
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
  /usr/lib/python3/dist-packages
Please input the desired Python library path to use.  Default is [/home/pi/inference_engine_vpu_arm/python/python3.5/armv7l]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apacha Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
Configuration finished

tensorflow/tensorflow#22819
https://github.com/tensorflow/tensorflow/commit/d80eb525e94763e09cbb9fa3cbef9a0f64e2cb2a
https://github.com/tensorflow/tensorflow/commit/5847293aeb9ab45a02c4231c40569a15bd4541c6
tensorflow/tensorflow#23721
tensorflow/tensorflow#25748
tensorflow/tensorflow#25120 (comment)
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/pip_package
tensorflow/tensorflow#24372
https://gist.github.com/fyhertz/4cef0b696b37d38964801d3ef21e8ce2

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
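
The build command above is long enough that it is easy to mistype. As a convenience sketch only, the same flags can be assembled in Python. `bazel_build_args` is a hypothetical helper, not part of TensorFlow or Bazel, and it only reproduces the flags shown above (without the leading `sudo`). In Bazel releases of this era, `--local_resources` takes available RAM in MB, CPU cores, and relative I/O capacity, in that order.

```python
import shlex

# Default-on features that the command above disables.
DISABLED_CONFIGS = ["noaws", "nogcp", "nohdfs", "noignite", "nokafka"]

# Compiler options passed through --copt in the command above.
COPTS = [
    "-mfpu=neon-vfpv4",
    "-ftree-vectorize",
    "-funsafe-math-optimizations",
    "-ftree-loop-vectorize",
    "-fomit-frame-pointer",
    "-DRASPBERRY_PI",
]


def bazel_build_args(ram_mb=1024.0, cpu=0.5, io=0.5, jvm_heap="512m"):
    """Build the argument list for the pip-package build shown above."""
    args = ["bazel", "--host_jvm_args=-Xmx" + jvm_heap, "build", "--config=opt"]
    args += ["--config=" + name for name in DISABLED_CONFIGS]
    args.append("--local_resources={},{},{}".format(ram_mb, cpu, io))
    args += ["--copt=" + flag for flag in COPTS]
    args.append("--host_copt=-DRASPBERRY_PI")
    args.append("//tensorflow/tools/pip_package:build_pip_package")
    return args


if __name__ == "__main__":
    # Print a copy-pasteable command line instead of running the build.
    print(" ".join(shlex.quote(a) for a in bazel_build_args()))
```

Keeping the flags in lists makes it simpler to adjust the memory limit or drop a `--config` entry for a different board without retyping the whole command.
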
Tensorflow v1.13.1

============================================================

Tensorflow v1.13.1 - Bazel 0.19.2

============================================================

Python3.x

$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.19.2/Raspbian_armhf/install.sh

$ cd ~
$ git clone https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
  • tensorflow/lite/python/interpreter.py
import sys
import numpy as np

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.util.lazy_loader import LazyLoader
  from tensorflow.python.util.tf_export import tf_export as _tf_export

  # Lazy load since some of the performance benchmark skylark rules
  # break dependencies. Must use double quotes to match code internal rewrite
  # rule.
  # pylint: disable=g-inconsistent-quotes
  _interpreter_wrapper = LazyLoader(
      "_interpreter_wrapper", globals(),
      "tensorflow.lite.python.interpreter_wrapper."
      "tensorflow_wrap_interpreter_wrapper")
  # pylint: enable=g-inconsistent-quotes

  del LazyLoader
except ImportError:
  # When the full TensorFlow Python pip package is not available, do not use
  # lazy loading; fall back to the tflite_runtime path instead.
  from tflite_runtime.lite.python import interpreter_wrapper as _interpreter_wrapper

  def tf_export_dummy(*x, **kwargs):
    del x, kwargs
    return lambda x: x
  _tf_export = tf_export_dummy


@_tf_export('lite.Interpreter')
class Interpreter(object):
  """Interpreter interface for TF-Lite Models."""

  def __init__(self, model_path=None, model_content=None):
    """Constructor.
    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
    Raises:
      ValueError: If the interpreter could not be created.
    """
    if model_path and not model_content:
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
              model_path))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    elif model_content and not model_path:
      # Take a reference, so the pointer remains valid.
      # Because Python strings are immutable, PyString_XX functions
      # will always return the same pointer.
      self._model_content = model_content
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
              model_content))
    elif not model_path and not model_content:
      raise ValueError('`model_path` or `model_content` must be specified.')
    else:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')

  def allocate_tensors(self):
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns true if there exist no numpy array buffers.
    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should be
    # 2 (1 in self and 1 in temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.
    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...
    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.
    Args:
      tensor_index: Tensor index of tensor to query.
    Returns:
      a dictionary containing the name, index, shape and type of the tensor.
    Raises:
      ValueError: If tensor_index is invalid.
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)

    if not tensor_name or not tensor_type:
      raise ValueError('Could not get tensor details')

    details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
    }

    return details

  def get_tensor_details(self):
    """Gets tensor details for every tensor with valid tensor details.
    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.
    Returns:
      A list of dictionaries containing tensor information.
    """
    tensor_details = []
    for idx in range(self._interpreter.NumTensors()):
      try:
        tensor_details.append(self._get_tensor_details(idx))
      except ValueError:
        pass
    return tensor_details

  def get_input_details(self):
    """Gets model input details.
    Returns:
      A list of input details.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor. Note this copies data in `value`.
    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.
    Args:
      tensor_index: Tensor index of tensor to set. This value can be gotten from
                    the 'index' field in get_input_details.
      value: Value of tensor to set.
    Raises:
      ValueError: If the interpreter could not set the tensor.
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size):
    """Resizes an input tensor.
    Args:
      input_index: Tensor index of input to set. This value can be gotten from
                   the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.
    Raises:
      ValueError: If the interpreter could not resize the input tensor.
    """
    self._ensure_safe()
    # `ResizeInputTensor` now only accepts an int32 numpy array as the
    # `tensor_size` parameter.
    tensor_size = np.array(tensor_size, dtype=np.int32)
    self._interpreter.ResizeInputTensor(input_index, tensor_size)

  def get_output_details(self):
    """Gets model output details.
    Returns:
      A list of output details.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_tensor(self, tensor_index):
    """Gets the value of the input tensor (get a copy).
    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.
    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
                    the 'index' field in get_output_details.
    Returns:
      a numpy array.
    """
    return self._interpreter.GetTensor(tensor_index)

  def tensor(self, tensor_index):
    """Returns function that gives a numpy view of the current tensor buffer.
    This allows reading and writing to this tensor without copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member, hence
    the name. Be careful to not hold these output references through calls
    to `allocate_tensors()` and `invoke()`. This function cannot be used to read
    intermediate results.
    Usage:
    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```
    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.
    WRONG:
    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # this will throw RuntimeError since input, output are still held
    ```
    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
                    the 'index' field in get_output_details.
    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function forever,
      but it is not safe to hold the numpy array forever.
    """
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invoke the interpreter.
    Be sure to set the input sizes, allocate tensors and fill values before
    calling this.
    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables(self):
    return self._interpreter.ResetVariableTensors()

  def set_num_threads(self, i):
    """Set number of threads used by TFLite kernels.
    If not set, kernels are running single-threaded. Note that currently,
    only some kernels, such as conv, are multithreaded.
    Args:
      i: number of threads.
    """
    return self._interpreter.SetNumThreads(i)
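The `_safe_to_run()` check in the listing above relies on CPython reference counting: each NumPy view produced by `tensor()` keeps a reference to the wrapper object, so an extra reference means a view is still alive. Below is a minimal, TensorFlow-free sketch of that idea. `BufferOwner`, `view()`, and the measured baseline are illustrative choices, not part of the TFLite API, and the sketch assumes CPython.

```python
import sys


class BufferOwner:
    """Toy stand-in for the Interpreter's buffer-safety check (not TFLite code)."""

    def __init__(self):
        self._resource = bytearray(16)  # plays the role of self._interpreter
        # Count seen while nobody else holds the resource. The listing above
        # hard-codes 2; measuring a baseline is an assumption made here so the
        # sketch does not depend on interpreter internals.
        self._baseline = sys.getrefcount(self._resource)

    def view(self):
        # Hands out a reference, like a NumPy array whose base is the wrapper.
        return self._resource

    def safe_to_run(self):
        return sys.getrefcount(self._resource) == self._baseline

    def invoke(self):
        if not self.safe_to_run():
            raise RuntimeError("a view of the internal buffer is still held")
        return "invoked"


owner = BufferOwner()
held = owner.view()
try:
    owner.invoke()
except RuntimeError as exc:
    print("blocked:", exc)
del held  # drop the view; invoking is allowed again
print(owner.invoke())
```

This is why the docstring recommends keeping the function returned by `tensor()` rather than the array it produces.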
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Change the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
ifeq ($(BUILD_WITH_NNAPI),true)
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/nnapi_delegate_disabled.cc
else
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/nnapi_delegate.cc
endif

ifeq ($(TARGET),ios)
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_android.cc
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_default.cc
else
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_android.cc
	CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_ios.cc
endif
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.19.2- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /home/b920405/git/caffe-jacinto/python
  /opt/intel//computer_vision_sdk_2018.5.455/python/python3.5/ubuntu16
  /opt/intel//computer_vision_sdk_2018.5.455/python/python3.5
  .
  /opt/intel//computer_vision_sdk_2018.5.455/deployment_tools/model_optimizer
  /opt/movidius/caffe/python
  /usr/lib/python3/dist-packages
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
    --config=gdr            # Build with GDR support.
    --config=verbs          # Build with libverbs support.
    --config=ngraph         # Build with Intel nGraph support.
    --config=dynamic_kernels    # (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
    --config=noaws          # Disable AWS S3 filesystem support.
    --config=nogcp          # Disable GCP support.
    --config=nohdfs         # Disable HDFS support.
    --config=noignite       # Disable Apache Ignite support.
    --config=nokafka        # Disable Apache Kafka support.
    --config=nonccl         # Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-1.13.1-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-1.13.1-cp35-cp35m-linux_armv7l.whl
Tensorflow v1.14.0

============================================================

Tensorflow v1.14.0 - Bazel 0.24.1 - Stretch - armhf

============================================================

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev openjdk-8-jdk

$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.24.1/Raspbian_Stretch_armhf/install.sh

$ cd ~
$ git clone -b v1.14.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v1.14.0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
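Once the wheel is rebuilt with this patch, the thread count can be set from Python before inference. A minimal usage sketch follows. `run_once` and `num_threads=4` are illustrative, and the call order (allocate, set threads, invoke) is an assumption rather than something the patch enforces. It expects a `tf.lite.Interpreter` created from the patched build.

```python
def run_once(interpreter, input_value, num_threads=4):
    """Run one inference on an Interpreter that has the set_num_threads patch.

    `interpreter` is expected to be a tf.lite.Interpreter built from the
    patched sources; only methods that appear in interpreter.py are used.
    """
    interpreter.allocate_tensors()
    interpreter.set_num_threads(num_threads)  # method added by the patch

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    interpreter.set_tensor(input_index, input_value)  # copies input_value
    interpreter.invoke()
    return interpreter.get_tensor(output_index)  # returns a copy
```

For example, `run_once(tf.lite.Interpreter(model_path="model.tflite"), frame)`, where `model.tflite` and `frame` are placeholders for your own model file and input array.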
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Change the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.24.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-1.14.0-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-1.14.0-cp35-cp35m-linux_armv7l.whl
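In the Bazel releases used in this document, `--local_resources` takes three comma-separated values: available RAM in MB, CPU cores, and relative I/O capacity. The build commands use `1024.0,0.5,0.5` on the device and larger values inside the emulator. The sketch below only formats such a command line so the three numbers are easier to reason about. `local_resources_flag` and `bazel_build_args` are hypothetical helper names, and only a subset of the flags from the commands above is included.

```python
def local_resources_flag(ram_mb, cpu_cores, io_capacity):
    """Format the legacy three-value --local_resources flag."""
    return "--local_resources={:.1f},{:.1f},{:.1f}".format(
        ram_mb, cpu_cores, io_capacity)


def bazel_build_args(ram_mb, cpu_cores, io_capacity, jvm_heap_mb=None):
    """Assemble a reduced version of the build command as an argument list."""
    args = ["bazel"]
    if jvm_heap_mb is not None:
        # Startup option: caps the heap of the Bazel server itself.
        args.append("--host_jvm_args=-Xmx{}m".format(jvm_heap_mb))
    args += [
        "build",
        "--config=opt",
        local_resources_flag(ram_mb, cpu_cores, io_capacity),
        "//tensorflow/tools/pip_package:build_pip_package",
    ]
    return args


# Settings comparable to the on-device build above.
print(" ".join(bazel_build_args(1024, 0.5, 0.5, jvm_heap_mb=512)))
```

Lower values make Bazel schedule fewer actions at once, which trades build time for a lower risk of running out of memory.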

============================================================

Tensorflow v1.14.0 - Bazel 0.24.1 - Buster - armhf

============================================================

First, prepare an emulation environment for armhf with QEMU 4.0.0 (CPU 4 cores, RAM 4GB).
How to create a Debian Buster armhf OS image from scratch in hardware emulation mode of QEMU 4.0.0 (Kernel 4.19.0-5-armmp-lpae, for building Tensorflow armhf)

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev openjdk-11-jdk

$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/numpy-1.16.4-cp37-cp37m-linux_armv7l.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/h5py-2.9.0-cp37-cp37m-linux_armv7l.whl
$ sudo pip3 install numpy-1.16.4-cp37-cp37m-linux_armv7l.whl
$ sudo pip3 install h5py-2.9.0-cp37-cp37m-linux_armv7l.whl
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.24.1/Raspbian_Buster_armhf/install.sh

$ cd ~
$ git clone -b v1.14.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v1.14.0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Change the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/contrib/__init__.py
from tensorflow.contrib import checkpoint
#if os.name != "nt" and platform.machine() != "s390x":
#  from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
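The edits to `tensorflow/contrib/__init__.py` in this section make the `cloud` import optional: the unconditional import is commented out and re-added inside `try`/`except ImportError`, so a build configured without cloud support still imports cleanly. The same pattern in isolation, as a sketch. `optional_import` is a hypothetical helper, not a TensorFlow function.

```python
import importlib


def optional_import(module_name):
    """Return the imported module, or None if it is not available."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # module ends up here as well.
        return None


json_mod = optional_import("json")  # standard library, always present
missing = optional_import("module_that_does_not_exist_123")
print(json_mod is not None, missing is None)
```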
  • tensorflow/contrib/__init__.py
from tensorflow.contrib.summary import summary

if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
                    "tensorflow.contrib.ffmpeg")
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.24.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=4096.0,2.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-1.14.0-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-1.14.0-cp37-cp37m-linux_armv7l.whl
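pip only installs a wheel whose filename tags match the running interpreter and platform, so a mistyped or mismatched tag fails at install time. The sketch below splits a wheel filename into its tags and compares the platform tag with `platform.machine()`. `parse_wheel_filename` and `matches_this_machine` are hypothetical helpers, and the check is rough: a 32-bit userland running on a 64-bit kernel can report `aarch64` even though it needs `armv7l` wheels.

```python
import platform


def parse_wheel_filename(filename):
    """Split name-version[-build]-python-abi-platform.whl into its tags."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel: " + filename)
    parts = filename[:-len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError("unexpected wheel name: " + filename)
    return {
        "distribution": parts[0],
        "version": parts[1],
        "python": parts[-3],
        "abi": parts[-2],
        "platform": parts[-1],
    }


def matches_this_machine(filename, machine=None):
    """Rough check: is the wheel's platform tag linux_<machine>?"""
    machine = machine or platform.machine()
    return parse_wheel_filename(filename)["platform"] == "linux_" + machine


tags = parse_wheel_filename("tensorflow-1.14.0-cp37-cp37m-linux_armv7l.whl")
print(tags["python"], tags["abi"], tags["platform"])
```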

============================================================

Tensorflow v1.14.0 - Bazel 0.24.1 - Buster - aarch64

============================================================

First, prepare an emulation environment for aarch64 with QEMU 4.0.0. How to create a Debian Buster aarch64 OS image from scratch in QEMU 4.0.0 hardware emulation mode (Kernel 4.19.0-5-arm64, for Tensorflow aarch64 build)

Next, build Bazel and Tensorflow according to the following procedure in the emulator environment.

$ sudo apt-get install -y \
libhdf5-dev libc-ares-dev libeigen3-dev \
libatlas3-base net-tools build-essential \
zip unzip python3-pip curl wget git
$ sudo pip3 install pip --upgrade
$ sudo pip3 install zipper
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/absl_py-0.7.1-cp37-none-any.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/gast-0.2.2-cp37-none-any.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/grpcio-1.21.1-cp37-cp37m-linux_aarch64.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/h5py-2.9.0-cp37-cp37m-linux_aarch64.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/numpy-1.16.4-cp37-cp37m-linux_aarch64.whl
$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/packages/wrapt-1.11.2-cp37-cp37m-linux_aarch64.whl
$ sudo pip3 install *.whl
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo pip3 install -U --user mock zipper wheel

$ sudo apt-get update
$ sudo apt-get remove -y openjdk-8* --purge
$ sudo apt-get install -y openjdk-11-jdk

$ cd ~
$ mkdir bazel;cd bazel
$ wget https://github.com/bazelbuild/bazel/releases/download/0.24.1/bazel-0.24.1-dist.zip
$ unzip bazel-0.24.1-dist.zip
$ env EXTRA_BAZEL_ARGS="--host_javabase=@local_jdk//:jdk"

$ nano compile.sh

#################################################################################
# Before
bazel_build "src:bazel_nojdk${EXE_EXT}" \
  --action_env=PATH \
  --host_platform=@bazel_tools//platforms:host_platform \
  --platforms=@bazel_tools//platforms:target_platform \
  || fail "Could not build Bazel"
#################################################################################
# After (adds --host_javabase)
bazel_build "src:bazel_nojdk${EXE_EXT}" \
  --host_javabase=@local_jdk//:jdk \
  --action_env=PATH \
  --host_platform=@bazel_tools//platforms:host_platform \
  --platforms=@bazel_tools//platforms:target_platform \
  || fail "Could not build Bazel"
#################################################################################

$ sudo bash ./compile.sh
$ sudo cp output/bazel /usr/local/bin

$ bazel version
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
Build label: 0.24.1- (@non-git)
Build target: bazel-out/aarch64-opt/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
Build time: Sun Jun 23 20:46:48 2019 (1561322808)
Build timestamp: 1561322808
Build timestamp as int: 1561322808
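Each TensorFlow release in this document is paired with a specific Bazel release (0.24.1 for v1.14.0, 0.26.1 for v1.15.0), so it is worth checking the `Build label` line before starting a long build. A small sketch that extracts the version from output like the one shown above. `bazel_label` and `REQUIRED_BAZEL` are illustrative names, and the pairs are copied from the section headings.

```python
import re

# TensorFlow version -> Bazel version, as listed in the section headings.
REQUIRED_BAZEL = {"1.14.0": "0.24.1", "1.15.0": "0.26.1"}


def bazel_label(version_output):
    """Return the x.y.z part of the 'Build label:' line, or None if absent."""
    match = re.search(r"^Build label:\s*(\d+\.\d+\.\d+)",
                      version_output, re.MULTILINE)
    return match.group(1) if match else None


sample = """Extracting Bazel installation...
Build label: 0.24.1- (@non-git)
Build timestamp: 1561322808
"""
print(bazel_label(sample) == REQUIRED_BAZEL["1.14.0"])
```

On the build machine, the text could come from `subprocess.check_output(["bazel", "version"]).decode()` instead of the `sample` string.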

$ cd ~
$ git clone -b v1.14.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v1.14.0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Change the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
# Settings for generic aarch64 boards such as Odroid C2 or Pine64.
ifeq ($(TARGET),aarch64)
  # The aarch64 architecture covers all 64-bit ARM chips. This arch mandates
  # NEON, so FPU flags are not needed below.
  TARGET_ARCH := armv8-a
  TARGET_TOOLCHAIN_PREFIX := aarch64-linux-gnu-

  CXXFLAGS += \
    -march=armv8-a \
    -funsafe-math-optimizations \
    -ftree-vectorize \
    -flax-vector-conversions \
    -fomit-frame-pointer \
    -fPIC

  CFLAGS += \
    -march=armv8-a \
    -funsafe-math-optimizations \
    -ftree-vectorize \
    -flax-vector-conversions \
    -fomit-frame-pointer \
    -fPIC

  LDFLAGS := \
    -Wl,--no-export-dynamic \
    -Wl,--exclude-libs,ALL \
    -Wl,--gc-sections \
    -Wl,--as-needed


  LIBS := \
    -lstdc++ \
    -lpthread \
    -lm \
    -ldl \
    -lrt

endif
            "/DTF_COMPILE_LIBRARY",
            "/wd4018",  # -Wno-sign-compare
        ],
+       str(Label("//tensorflow:linux_aarch64")): [
+           "-flax-vector-conversions",
+           "-fomit-frame-pointer",
+       ],
        "//conditions:default": [
            "-Wno-sign-compare",
        ],
  • tensorflow/contrib/__init__.py
from tensorflow.contrib import checkpoint
#if os.name != "nt" and platform.machine() != "s390x":
#  from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
  • tensorflow/contrib/__init__.py
from tensorflow.contrib.summary import summary

if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
                    "tensorflow.contrib.ffmpeg")
  • configure
$ ./configure
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib/python3.7/dist-packages
  /usr/lib/python3/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib/python3.7/dist-packages]

Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
$ sudo bazel build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=8192.0,4.0,1.0 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-flax-vector-conversions \
--copt=-fomit-frame-pointer \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-1.14.0-cp37-cp37m-linux_aarch64.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-1.14.0-cp37-cp37m-linux_aarch64.whl
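Before running `pip3 install`, it can help to confirm that the wheel's filename tags match the interpreter and machine in use. The following is a minimal sketch and not part of the original build procedure: `parse_wheel_name` and `wheel_matches` are hypothetical helper names, the parser assumes the five-field filename form used throughout this document (no optional build tag), and only the Python tag and machine string are compared.

```python
import platform
import sys


def parse_wheel_name(filename):
    """Split a wheel filename into its five dash-separated fields."""
    stem = filename[:-len(".whl")] if filename.endswith(".whl") else filename
    name, version, py_tag, abi_tag, plat_tag = stem.split("-")
    return {"name": name, "version": version, "python": py_tag,
            "abi": abi_tag, "platform": plat_tag}


def wheel_matches(filename, version_info=None, machine=None):
    """Return True if the wheel's Python tag and machine match this system."""
    version_info = version_info or sys.version_info
    machine = machine or platform.machine()
    tags = parse_wheel_name(filename)
    expected_py = "cp{}{}".format(version_info[0], version_info[1])
    return (tags["python"] == expected_py
            and tags["platform"] == "linux_" + machine)


if __name__ == "__main__":
    wheel = "tensorflow-1.14.0-cp37-cp37m-linux_aarch64.whl"
    print(parse_wheel_name(wheel))
    print(wheel_matches(wheel))
```

A mismatch here usually means the wheel was built for a different Python version or architecture than the one pip is running under.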
Tensorflow v1.15.0

============================================================

Tensorflow v1.15.0 - Bazel 0.26.1 - Buster - armhf

============================================================

First, install openjdk-8-jdk by following the procedure in this article: [Stable] Install openjdk-8-jdk safely in Raspbian Buster (Debian 10) environment.

Next, follow the steps below to build Tensorflow on RaspberryPi3/4.

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v1.15.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v1.15.0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Edit the end of the file as follows, adding SetNumThreads after ResetVariableTensors
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl
Tensorflow v2.0.0-alpha

============================================================

Tensorflow v2.0.0-alpha - Stretch - Bazel 0.19.2

============================================================

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.19.2/Raspbian_armhf/install.sh

$ cd ~
$ git clone -b v2.0.0-alpha0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-alpha0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Edit the end of the file as follows, adding SetNumThreads after ResetVariableTensors
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/contrib/__init__.py
from tensorflow.contrib import checkpoint
#if os.name != "nt" and platform.machine() != "s390x":
#  from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
  • tensorflow/contrib/__init__.py
from tensorflow.contrib.summary import summary

if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
                    "tensorflow.contrib.ffmpeg")
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.19.2- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /home/b920405/git/caffe-jacinto/python
  /opt/intel//computer_vision_sdk_2018.5.455/python/python3.5/ubuntu16
  /opt/intel//computer_vision_sdk_2018.5.455/python/python3.5
  .
  /opt/intel//computer_vision_sdk_2018.5.455/deployment_tools/model_optimizer
  /opt/movidius/caffe/python
  /usr/lib/python3/dist-packages
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
    --config=gdr            # Build with GDR support.
    --config=verbs          # Build with libverbs support.
    --config=ngraph         # Build with Intel nGraph support.
    --config=dynamic_kernels    # (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
    --config=noaws          # Disable AWS S3 filesystem support.
    --config=nogcp          # Disable GCP support.
    --config=nohdfs         # Disable HDFS support.
    --config=noignite       # Disable Apache Ignite support.
    --config=nokafka        # Disable Apache Kafka support.
    --config=nonccl         # Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0a0-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0a0-cp35-cp35m-linux_armv7l.whl
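The `--local_resources=1024.0,0.5,0.5` option in the build command above caps what Bazel assumes it may use. In these Bazel versions the three comma-separated values are available RAM in MB, CPU cores, and a relative I/O capacity (1.0 being an average workstation). A minimal sketch of composing that flag follows; `local_resources_flag` is a hypothetical helper and the values simply mirror the ones used in this document.

```python
def local_resources_flag(ram_mb, cpu_cores, io_capacity):
    """Format Bazel's --local_resources flag from three positive numbers."""
    values = (float(ram_mb), float(cpu_cores), float(io_capacity))
    if any(v <= 0 for v in values):
        raise ValueError("all resource values must be positive")
    return "--local_resources={:.1f},{:.1f},{:.1f}".format(*values)


if __name__ == "__main__":
    # Conservative setting used for the 32-bit RaspberryPi builds.
    print(local_resources_flag(1024, 0.5, 0.5))
    # Larger setting used for the aarch64 build earlier in the document.
    print(local_resources_flag(8192, 4, 1))
```

Lower values make the build slower but reduce the chance of the compiler being killed when memory runs out.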
Tensorflow v2.0.0-beta0

============================================================

Tensorflow v2.0.0-beta0 - Stretch - Bazel 0.24.1

============================================================

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.24.1/Raspbian_Stretch_armhf/install.sh

$ cd ~
$ git clone -b v2.0.0-beta0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-beta0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Edit the end of the file as follows, adding SetNumThreads after ResetVariableTensors
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/contrib/__init__.py
from tensorflow.contrib import checkpoint
#if os.name != "nt" and platform.machine() != "s390x":
#  from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
  • tensorflow/contrib/__init__.py
from tensorflow.contrib.summary import summary

if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
                    "tensorflow.contrib.ffmpeg")
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.24.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0b0-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0b0-cp35-cp35m-linux_armv7l.whl
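`build_pip_package` writes the wheel into `/tmp/tensorflow_pkg`, and the exact filename depends on the Tensorflow version, the Python version, and the architecture. Rather than typing the name by hand, a script can look it up. This is a minimal sketch under that assumption; `find_built_wheel` is a hypothetical helper, not something shipped with Tensorflow.

```python
import glob
import os


def find_built_wheel(pkg_dir="/tmp/tensorflow_pkg"):
    """Return the path of the single tensorflow wheel found in pkg_dir."""
    matches = sorted(glob.glob(os.path.join(pkg_dir, "tensorflow-*.whl")))
    if len(matches) != 1:
        raise RuntimeError(
            "expected exactly one wheel in {}, found {}".format(
                pkg_dir, len(matches)))
    return matches[0]


if __name__ == "__main__":
    try:
        print(find_built_wheel())
    except RuntimeError as err:
        # Raised when the build has not been run yet or left several wheels.
        print(err)
```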
Tensorflow v2.0.0-beta1

============================================================

Tensorflow v2.0.0-beta1 - Stretch - Bazel 0.24.1

============================================================

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.24.1/Raspbian_Stretch_armhf/install.sh

$ cd ~
$ git clone -b v2.0.0-beta1 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-beta1
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Edit the end of the file as follows, adding SetNumThreads after ResetVariableTensors
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/contrib/__init__.py
from tensorflow.contrib import checkpoint
#if os.name != "nt" and platform.machine() != "s390x":
#  from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
  • tensorflow/contrib/__init__.py
from tensorflow.contrib.summary import summary

if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass

from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
                    "tensorflow.contrib.ffmpeg")
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.24.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0b1-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0b1-cp35-cp35m-linux_armv7l.whl
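The `tensorflow/contrib/__init__.py` change above wraps the `cloud` import in `try` / `except ImportError`, so a build without the cloud components still imports cleanly. The same pattern can be written as a general helper. This is a minimal sketch: `optional_import` is a hypothetical name, and `no_such_module_for_demo` is a deliberately missing module used only to show the fallback.

```python
import importlib


def optional_import(module_name):
    """Import module_name, returning None instead of raising if it is absent."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


if __name__ == "__main__":
    # A standard-library module is found and returned.
    print(optional_import("json") is not None)
    # A missing module yields None rather than aborting the program.
    print(optional_import("no_such_module_for_demo") is None)
```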
Tensorflow v2.0.0-rc0

============================================================

Tensorflow v2.0.0-rc0 - Buster - Bazel 0.26.1

============================================================

First, install openjdk-8-jdk by following the procedure in one of these articles: How to install openjdk-8-jdk on Raspbian Buster armhf, or How to install openjdk-8-jdk on Debian Buster (Debian 10) armhf.

Next, follow the steps below to build Tensorflow on RaspberryPi3.

$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v2.0.0-rc0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-rc0
  • tensorflow/lite/python/interpreter.py
# Append the following two lines to the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Edit the end of the file as follows, adding SetNumThreads after ResetVariableTensors
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.7
  /usr/local/lib/python3.7/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.7/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v2             # Build Tensorflow 2.x instead of 1.x
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0rc0-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0rc0-cp37-cp37m-linux_armv7l.whl
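
The wheel installed above encodes its target in the filename, and pip refuses a wheel whose tags do not match the running interpreter and platform. The following is a minimal sketch, using only the standard library, that splits a wheel filename into its fields and compares the platform tag with the machine name. The helper names `parse_wheel_filename` and `matches_machine` are hypothetical, and the sketch assumes the common five-field filename form without a build tag.

```python
import platform


def parse_wheel_filename(filename):
    """Split a wheel filename into its dash-separated fields.

    Assumes the five-field form with no build tag:
    name-version-python_tag-abi_tag-platform_tag.whl
    """
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: " + filename)
    parts = filename[:-len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError("unexpected number of fields in: " + filename)
    keys = ("name", "version", "python_tag", "abi_tag", "platform_tag")
    return dict(zip(keys, parts))


def matches_machine(wheel, machine=None):
    """Return True when the platform tag ends with the machine name (e.g. armv7l)."""
    machine = machine or platform.machine()
    return wheel["platform_tag"].endswith("_" + machine)


info = parse_wheel_filename("tensorflow-2.0.0rc0-cp37-cp37m-linux_armv7l.whl")
print(info["python_tag"], info["platform_tag"])
print(matches_machine(info, machine="armv7l"))
```

On the device itself, calling `matches_machine(info)` without the second argument compares against `platform.machine()`, which should agree with the output of `uname -m`.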
Tensorflow v2.0.0-rc1

============================================================

Tensorflow v2.0.0-rc1 - Buster - Bazel 0.26.1

============================================================

First, install openjdk-8-jdk by following one of these procedures: "How to install openjdk-8-jdk on Raspbian Buster armhf" or "How to install openjdk-8-jdk on Debian Buster (Debian 10) armhf".

Next, follow the steps below to build Tensorflow on RaspberryPi3.

$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v2.0.0-rc1 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-rc1
  • tensorflow/lite/python/interpreter.py
# Add the following two lines at the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.7
  /usr/local/lib/python3.7/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.7/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v2             # Build Tensorflow 2.x instead of 1.x
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0rc1-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0rc1-cp37-cp37m-linux_armv7l.whl
Tensorflow v2.0.0-rc2

============================================================

Tensorflow v2.0.0-rc2 - Buster - Bazel 0.26.1

============================================================

First, install openjdk-8-jdk by following one of these procedures: "How to install openjdk-8-jdk on Raspbian Buster armhf" or "How to install openjdk-8-jdk on Debian Buster (Debian 10) armhf".

Next, follow the steps below to build Tensorflow on RaspberryPi3.

$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v2.0.0-rc2 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-rc2
  • tensorflow/lite/python/interpreter.py
# Add the following two lines at the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
  • tensorflow/tensorflow/core/kernels/BUILD - Delete the following
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • configure
$ ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.7
  /usr/local/lib/python3.7/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.7/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v2             # Build Tensorflow 2.x instead of 1.x
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0rc2-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0rc2-cp37-cp37m-linux_armv7l.whl
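
The answers typed into `./configure` in the transcript above can, as far as I know, also be supplied through environment variables so that the script does not prompt. The sketch below builds such an environment. The variable names are assumptions based on TensorFlow's `configure.py` of this era and should be checked against the checked-out version before relying on them. `configure_env` is a hypothetical helper.

```python
import os


def configure_env(python_bin, python_lib, base_env=None):
    """Return an environment dict mirroring the answers in the transcript.

    Variable names are assumptions to verify against configure.py.
    """
    env = dict(os.environ if base_env is None else base_env)
    env.update({
        "PYTHON_BIN_PATH": python_bin,
        "PYTHON_LIB_PATH": python_lib,
        "TF_ENABLE_XLA": "0",             # XLA JIT support: n
        "TF_NEED_OPENCL_SYCL": "0",       # OpenCL SYCL support: n
        "TF_NEED_ROCM": "0",              # ROCm support: n
        "TF_NEED_CUDA": "0",              # CUDA support: n
        "TF_DOWNLOAD_CLANG": "0",         # fresh release of clang: n
        "TF_NEED_MPI": "0",               # MPI support: n
        "CC_OPT_FLAGS": "-march=native -Wno-sign-compare",  # default accepted
        "TF_SET_ANDROID_WORKSPACE": "0",  # Android WORKSPACE: n
    })
    return env


env = configure_env(
    "/usr/bin/python3",
    "/usr/local/lib/python3.7/dist-packages",
    base_env={},
)
for key in sorted(env):
    print(key + "=" + env[key])
```

To use it, the dict could be passed as `env=` to `subprocess.run(["./configure"], ...)` from the tensorflow checkout, with `base_env` left as `None` so that `PATH` and the rest of the current environment are inherited.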
Tensorflow v2.0.0

============================================================

Tensorflow v2.0.0 - Stretch - Bazel 0.26.1

============================================================

$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev openjdk-8-jdk
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v2.0.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0
  • tensorflow/lite/python/interpreter.py
# Add the following two lines at the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/lite/experimental/ruy/pack_arm.cc - Line 1292
"mov r0, 0\n"

↓

"mov r0, #0\n"
  • configure
$ sudo ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.7
  /usr/local/lib/python3.7/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.7/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v2             # Build Tensorflow 2.x instead of 1.x
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0-cp35-cp35m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0-cp35-cp35m-linux_armv7l.whl
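
Before starting a long build like the one above, it can help to confirm that the enlarged swap configured at the start of the procedure is actually active. The following is a minimal sketch that reads the `SwapTotal` field from `/proc/meminfo`-style text. `swap_total_mb` is a hypothetical helper and the sample text is illustrative, not taken from a real device.

```python
def swap_total_mb(meminfo_text):
    """Return SwapTotal in whole MB parsed from /proc/meminfo-style text, or None."""
    for line in meminfo_text.splitlines():
        if line.startswith("SwapTotal:"):
            # Field layout: "SwapTotal:  <number> kB"
            return int(line.split()[1]) // 1024
    return None


# Illustrative sample; on the device, read the real file instead:
#   with open("/proc/meminfo") as f:
#       print(swap_total_mb(f.read()))
sample = "MemTotal:  948304 kB\nSwapTotal:  2097148 kB\nSwapFree:  2097148 kB\n"
print(swap_total_mb(sample))
```

The reported figure covers all active swap devices, so a zram device set up by `zram.sh` would be included in the total alongside the dphys-swapfile.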

============================================================

Tensorflow v2.0.0 - Buster - Bazel 0.26.1

============================================================

First, install openjdk-8-jdk by following the procedure in "[Stable] Install openjdk-8-jdk safely in Raspbian Buster (Debian 10) environment".

Next, follow the steps below to build Tensorflow on RaspberryPi3/4.

$ sudo nano /etc/dphys-swapfile
CONF_SWAPFILE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev libatlas-base-dev libopenblas-dev
$ sudo pip3 install keras_applications==1.0.8 --no-deps
$ sudo pip3 install keras_preprocessing==1.1.0 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin
$ ./0.26.1/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh

$ cd ~
$ git clone -b v2.0.0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0
  • tensorflow/lite/python/interpreter.py
# Add the following two lines at the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
  • tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  • tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false
  • tensorflow/lite/experimental/ruy/pack_arm.cc - Line 1292
"mov r0, 0\n"

↓

"mov r0, #0\n"
  • configure
$ sudo ./configure
Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.26.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.7
  /usr/local/lib/python3.7/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.7/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]:


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=gdr         	# Build with GDR support.
	--config=verbs       	# Build with libverbs support.
	--config=ngraph      	# Build with Intel nGraph support.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v2             # Build Tensorflow 2.x instead of 1.x
Preconfigured Bazel build configs to DISABLE default on features:
	--config=noaws       	# Disable AWS S3 filesystem support.
	--config=nogcp       	# Disable GCP support.
	--config=nohdfs      	# Disable HDFS support.
	--config=noignite    	# Disable Apache Ignite support.
	--config=nokafka     	# Disable Apache Kafka support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  • build

(1) RaspberryPi3

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package

(2) RaspberryPi4

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0-cp37-cp37m-linux_armv7l.whl ~
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0-cp37-cp37m-linux_armv7l.whl
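
The two build commands above differ only in `--local_resources`. In Bazel versions of this era, that flag takes three comma-separated numbers: available RAM in MB, CPU cores, and available I/O. The sketch below generates the argument list from a per-device table so the shared flags are written once. `PROFILES`, `COMMON_FLAGS` and `bazel_build_args` are hypothetical names. The values are copied from the commands above.

```python
# Per-device values for --local_resources: RAM in MB, CPU cores, I/O.
PROFILES = {
    "RaspberryPi3": (1024.0, 0.5, 0.5),
    "RaspberryPi4": (4096.0, 3.0, 1.0),
}

# Flags shared by both build commands.
COMMON_FLAGS = [
    "--config=opt",
    "--config=noaws",
    "--config=nohdfs",
    "--config=noignite",
    "--config=nokafka",
    "--config=nonccl",
    "--config=v2",
    "--copt=-mfpu=neon-vfpv4",
    "--copt=-ftree-vectorize",
    "--copt=-funsafe-math-optimizations",
    "--copt=-ftree-loop-vectorize",
    "--copt=-fomit-frame-pointer",
    "--copt=-DRASPBERRY_PI",
    "--host_copt=-DRASPBERRY_PI",
]

TARGET = "//tensorflow/tools/pip_package:build_pip_package"


def bazel_build_args(device):
    """Return the bazel argument list for the given device profile."""
    ram_mb, cpu, io = PROFILES[device]
    args = ["bazel", "--host_jvm_args=-Xmx512m", "build"]
    args += COMMON_FLAGS
    args.append("--local_resources={},{},{}".format(ram_mb, cpu, io))
    args.append(TARGET)
    return args


print(" ".join(bazel_build_args("RaspberryPi4")))
```

The printed line can be pasted after `sudo`, or the list can be handed to `subprocess.run` from the tensorflow checkout. The flag order differs from the commands above, which should not affect the build.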
Tensorflow v2.1.0-rc0

============================================================

Tensorflow v2.1.0-rc0 - Buster - Bazel 0.29.1

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.1.0-rc1

============================================================

Tensorflow v2.1.0-rc1 - Buster - Bazel 0.29.1

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.1.0-rc2

============================================================

Tensorflow v2.1.0-rc2 - Buster - Bazel 0.29.1

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.1.0

============================================================

Tensorflow v2.1.0 - Ubuntu 19.10 aarch64 - Bazel 0.29.1

============================================================

Update grpc dependency for glibc 2.30 compatibility

$ curl -L https://github.com/tensorflow/tensorflow/compare/master...hi-ogawa:grpc-backport-pr-18950.patch | git apply

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.1.0 - Buster - Bazel 0.29.1

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.2.0

============================================================

Tensorflow v2.2.0 - Buster - Bazel 2.0.0

============================================================

$ sudo nano .tf_configure.bazelrc

build --action_env PYTHON_BIN_PATH="/usr/bin/python3"
build --action_env PYTHON_LIB_PATH="/usr/local/lib/python3.7/dist-packages"
build --python_path="/usr/bin/python3"
build --config=xla
build:opt --copt=-march=native
build:opt --copt=-Wno-sign-compare
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
test --flaky_test_attempts=3
test --test_size_filters=small,medium
test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial
test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu
test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-v1only
build --action_env TF_CONFIGURE_IOS="0"

↓

build --action_env PYTHON_BIN_PATH="/usr/bin/python3"
build --action_env PYTHON_LIB_PATH="/usr/local/lib/python3.7/dist-packages"
build --python_path="/usr/bin/python3"
build:opt --copt=-march=native
build:opt --copt=-Wno-sign-compare
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
test --flaky_test_attempts=3
test --test_size_filters=small,medium
test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial
test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu
test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-v1only
build --action_env TF_CONFIGURE_IOS="0"
build --action_env TF_ENABLE_XLA="0"
build --define with_xla_support=false
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,2.0,1.0 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package
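The `.tf_configure.bazelrc` change shown above (drop `build --config=xla`, append the two XLA-off lines) can also be applied without an editor. This is a minimal sketch, not part of the original procedure: `disable_xla_in_bazelrc` is a hypothetical helper name, and it assumes the file is writable by the current user (the original uses `sudo nano`).

```shell
# Hypothetical helper: turn XLA off in a .tf_configure.bazelrc-style file.
disable_xla_in_bazelrc() {
  rc_file=$1
  # Remove the line that enables XLA (sed exits 0 even if nothing matches).
  sed '/^build --config=xla$/d' "$rc_file" > "$rc_file.tmp" && mv "$rc_file.tmp" "$rc_file"
  # Append the two overrides from the edited file, but only once each.
  for line in 'build --action_env TF_ENABLE_XLA="0"' \
              'build --define with_xla_support=false'; do
    grep -qxF -- "$line" "$rc_file" || printf '%s\n' "$line" >> "$rc_file"
  done
}

# Usage (run from the TensorFlow source root after ./configure):
#   disable_xla_in_bazelrc .tf_configure.bazelrc
```

Running the helper twice leaves the file unchanged the second time, so it is safe to re-run after another `./configure`.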

============================================================

Tensorflow v2.2.0 - Ubuntu 19.10 aarch64 - Bazel 2.0.0

============================================================

$ nano tensorflow/third_party/py/python_configure.bzl

def _get_python_include(repository_ctx, python_bin):
    """Gets the python include path."""
    result = execute(
        repository_ctx,
        [
            python_bin,

↓

def _get_python_include(repository_ctx, python_bin):
    """Gets the python include path."""
    result = execute(
        repository_ctx,
        [
            "python3",

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_resources=4096.0,3.0,1.0 \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.3.0

============================================================

Tensorflow v2.3.0-rc0 - Buster - Bazel 3.1.0

============================================================

$ sudo nano .tf_configure.bazelrc

build --action_env PYTHON_BIN_PATH="/usr/bin/python3"
build --action_env PYTHON_LIB_PATH="/usr/local/lib/python3.7/dist-packages"
build --python_path="/usr/bin/python3"
build --config=xla
build:opt --copt=-march=native
build:opt --copt=-Wno-sign-compare
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
test --flaky_test_attempts=3
test --test_size_filters=small,medium
test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial
test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu
test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-v1only
build --action_env TF_CONFIGURE_IOS="0"

↓

build --action_env PYTHON_BIN_PATH="/usr/bin/python3"
build --action_env PYTHON_LIB_PATH="/usr/local/lib/python3.7/dist-packages"
build --python_path="/usr/bin/python3"
build:opt --copt=-march=native
build:opt --copt=-Wno-sign-compare
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
test --flaky_test_attempts=3
test --test_size_filters=small,medium
test:v1 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial
test:v1 --build_tag_filters=-benchmark-test,-no_oss,-gpu
test:v2 --test_tag_filters=-benchmark-test,-no_oss,-gpu,-oss_serial,-v1only
test:v2 --build_tag_filters=-benchmark-test,-no_oss,-gpu,-v1only
build --action_env TF_CONFIGURE_IOS="0"
build --action_env TF_ENABLE_XLA="0"
build --define with_xla_support=false
$ wget https://gitlab.com/libeigen/eigen/-/archive/386d809bde475c65b7940f290efe80e6a05878c4/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz
$ nano tensorflow/workspace.bzl

    tf_http_archive(
        name = "eigen_archive",
        build_file = clean_dep("//third_party:eigen.BUILD"),
        patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
        sha256 = "f632d82e43ffc46adfac9043beace700b0265748075e7edc0701d81380258038",  # SHARED_EIGEN_SHA
        strip_prefix = "eigen-386d809bde475c65b7940f290efe80e6a05878c4",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/386d809bde475c65b7940f290efe80e6a05878c4/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz",
            "https://gitlab.com/libeigen/eigen/-/archive/386d809bde475c65b7940f290efe80e6a05878c4/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz",
        ],
    )
↓
    tf_http_archive(
        name = "eigen_archive",
        build_file = clean_dep("//third_party:eigen.BUILD"),
        patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
        sha256 = "f632d82e43ffc46adfac9043beace700b0265748075e7edc0701d81380258038",  # SHARED_EIGEN_SHA
        strip_prefix = "eigen-386d809bde475c65b7940f290efe80e6a05878c4",
        urls = [
            "file:///home/pi/tensorflow/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz",
            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/386d809bde475c65b7940f290efe80e6a05878c4/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz",
            "https://gitlab.com/libeigen/eigen/-/archive/386d809bde475c65b7940f290efe80e6a05878c4/eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz",
        ],
    )
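Because the `eigen_archive` entry above keeps its pinned `sha256`, the locally downloaded tarball presumably has to match that hash for the `file://` URL to be accepted. A minimal sketch for checking the file before starting the build; `verify_sha256` is a hypothetical helper name and it assumes GNU coreutils `sha256sum` is installed:

```shell
# Hypothetical helper: succeed only if FILE hashes to EXPECTED (hex SHA-256).
verify_sha256() {
  file=$1
  expected=$2
  # sha256sum prints "<hash>  <file>"; keep only the hash field.
  actual=$(sha256sum "$file" | cut -d' ' -f1)
  [ "$actual" = "$expected" ]
}

# Usage with the values from the workspace.bzl entry above:
#   verify_sha256 eigen-386d809bde475c65b7940f290efe80e6a05878c4.tar.gz \
#     f632d82e43ffc46adfac9043beace700b0265748075e7edc0701d81380258038 \
#     && echo "eigen tarball OK"
```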
$ nano tensorflow/third_party/ruy/workspace.bzl

def repo():
    third_party_http_archive(
        name = "ruy",
        sha256 = "8fd4adeeff4f29796bf7cdda64806ec0495a2435361569f02afe3fe33406f07c",
        strip_prefix = "ruy-34ea9f4993955fa1ff4eb58e504421806b7f2e8f",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/ruy/archive/34ea9f4993955fa1ff4eb58e504421806b7f2e8f.zip",
            "https://github.com/google/ruy/archive/34ea9f4993955fa1ff4eb58e504421806b7f2e8f.zip",
        ],
        build_file = "//third_party/ruy:BUILD",
    )

↓

def repo():
    third_party_http_archive(
        name = "ruy",
        sha256 = "89b8b56b4e1db894e75a0abed8f69757b37c23dde6e64bfb186656197771138a",
        strip_prefix = "ruy-388ffd28ba00ffb9aacbe538225165c02ea33ee3",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/ruy/archive/388ffd28ba00ffb9aacbe538225165c02ea33ee3.zip",
            "https://github.com/google/ruy/archive/388ffd28ba00ffb9aacbe538225165c02ea33ee3.zip",
        ],
        build_file = "//third_party/ruy:BUILD",
    )
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_ram_resources=4096 \
--local_cpu_resources=2 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
--linkopt=-Wl,-latomic \
--host_linkopt=-Wl,-latomic \
--define=tensorflow_mkldnn_contraction_kernel=0 \
--define=raspberry_pi_with_neon=true \
--define=tflite_pip_with_flex=true \
//tensorflow/tools/pip_package:build_pip_package
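The Bazel 2.0.0 commands above pass a single `--local_resources=RAM,CPU,IO` triple, while the Bazel 3.1.0 commands use `--local_ram_resources` and `--local_cpu_resources`. A minimal sketch of that mapping, under the assumption that the third (I/O) value has no counterpart in the newer flags; `convert_local_resources` is a hypothetical helper name, not part of Bazel or this repository:

```shell
# Hypothetical helper: rewrite "RAM,CPU,IO" as the two newer Bazel flags.
convert_local_resources() {
  ram=${1%%,*}      # first field, e.g. "4096.0"
  rest=${1#*,}      # remaining fields, e.g. "2.0,1.0"
  cpu=${rest%%,*}   # second field, e.g. "2.0"
  # Drop the fractional part; the I/O field is intentionally discarded.
  printf '%s\n' "--local_ram_resources=${ram%.*} --local_cpu_resources=${cpu%.*}"
}

# Usage:
#   convert_local_resources 4096.0,2.0,1.0
```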

============================================================

Tensorflow v2.3.0-rc0 - Debian Buster aarch64 - Bazel 3.1.0

============================================================

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--local_ram_resources=30720 \
--local_cpu_resources=10 \
//tensorflow/tools/pip_package:build_pip_package
Tensorflow v2.4.0
  • tensorflow/tensorflow/lite/kernels/BUILD Add a custom kernel for MediaPipe.
cc_library(
    name = "builtin_op_kernels",
    srcs = BUILTIN_KERNEL_SRCS + [
        "max_pool_argmax.cc",
        "max_unpooling.cc",
        "transpose_conv_bias.cc",
    ],
    hdrs = [
        "dequantize.h",
        "max_pool_argmax.h",
        "max_unpooling.h",
        "transpose_conv_bias.h",
    ],
    compatible_with = get_compatible_with_portable(),
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = BUILTIN_KERNEL_DEPS + [
        "@ruy//ruy/profiler:instrumentation",
        "//tensorflow/lite/kernels/internal:cppmath",
        "//tensorflow/lite:string",
        "@farmhash_archive//:farmhash",
    ],
)
$ sudo pip3 install gdown
$ cd tensorflow/tensorflow/lite/kernels
$ sudo gdown --id 17qEXPvo5l72j4O5qEcSoLcmJAthaqSws
$ tar -zxvf kernels.tar.gz && rm kernels.tar.gz -f
$ cd ../../..

============================================================

Tensorflow v2.4.0 - Buster - Bazel 3.1.0

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_ram_resources=4096 \
--local_cpu_resources=2 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
--linkopt=-Wl,-latomic \
--host_linkopt=-Wl,-latomic \
--define=tensorflow_mkldnn_contraction_kernel=0 \
--define=raspberry_pi_with_neon=true \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.4.0 - Debian Buster aarch64 - Bazel 3.1.0

============================================================

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--local_ram_resources=30720 \
--local_cpu_resources=10 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.4.0-cp37-cp37m-linux_armv7l.whl ~
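The wheel filename follows the `{name}-{version}-{python tag}-{abi tag}-{platform tag}.whl` layout, so the platform tag can be read off the name before copying the file to a device. A small sketch, assuming the `linux_<machine>` convention seen in the filenames of this document; `wheel_platform_tag` and `wheel_matches_machine` are hypothetical helper names:

```shell
# Hypothetical helper: print the platform tag (last dash-separated field).
wheel_platform_tag() {
  name=${1##*/}        # strip any directory part
  name=${name%.whl}    # strip the extension
  printf '%s\n' "${name##*-}"
}

# Hypothetical helper: succeed if the wheel targets linux_<machine>.
wheel_matches_machine() {
  [ "$(wheel_platform_tag "$1")" = "linux_$2" ]
}

# Usage on the target device:
#   wheel_matches_machine /tmp/tensorflow_pkg/tensorflow-2.4.0-*.whl "$(uname -m)" \
#     || echo "wheel was built for a different architecture"
```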
Tensorflow v2.5.0
  • tensorflow/tensorflow/lite/kernels/BUILD Add a custom kernel for MediaPipe.
cc_library(
    name = "builtin_op_kernels",
    srcs = BUILTIN_KERNEL_SRCS + [
        "max_pool_argmax.cc",
        "max_unpooling.cc",
        "transpose_conv_bias.cc",
    ],
    hdrs = [
        "dequantize.h",
        "max_pool_argmax.h",
        "max_unpooling.h",
        "transpose_conv_bias.h",
    ],
    compatible_with = get_compatible_with_portable(),
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = BUILTIN_KERNEL_DEPS + [
        "@ruy//ruy/profiler:instrumentation",
        "//tensorflow/lite/kernels/internal:cppmath",
        "//tensorflow/lite:string",
        "@farmhash_archive//:farmhash",
    ],
)
$ sudo pip3 install gdown h5py==3.1.0
$ cd tensorflow/lite/kernels
$ sudo gdown --id 1fuB2m7B_-3u7-kxuNcALUp9wkrHsfCQB
$ tar -zxvf kernels.tar.gz && rm kernels.tar.gz -f
$ cd ../../..
$ sudo bazel clean --expunge

============================================================

Tensorflow v2.5.0 - Buster armv7l/armhf - Bazel 3.7.2 Native Build

============================================================

$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_ram_resources=4096 \
--local_cpu_resources=2 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
--linkopt=-Wl,-latomic \
--host_linkopt=-Wl,-latomic \
--define=tensorflow_mkldnn_contraction_kernel=0 \
--define=raspberry_pi_with_neon=true \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.5.0 - Buster armv7l/armhf - Bazel 3.7.2 Cross-compilation by x86 host

============================================================

$ git clone https://github.com/PINTO0309/tensorflow-on-arm.git && \
  cd tensorflow-on-arm/build_tensorflow
$ docker build -t tf-arm -f Dockerfile .
$ docker run -it --rm \
  -v /tmp/tensorflow_pkg/:/tmp/tensorflow_pkg/ \
  --env TF_PYTHON_VERSION=3.7 \
  tf-arm ./build_tensorflow.sh configs/rpi.conf
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.5.0-cp37-none-linux_armv7l.whl .
$ sudo chmod 777 tensorflow-2.5.0-cp37-none-linux_armv7l.whl

============================================================

Tensorflow v2.5.0 - Debian Buster aarch64 - Bazel 3.7.2 Using EC2 m6g.16xlarge

============================================================

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.5.0-cp37-cp37m-linux_armv7l.whl ~
Tensorflow v2.6.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install libhdf5-dev && \
sudo pip3 install pip --upgrade && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.0 --no-deps && \
sudo pip3 install gdown h5py==3.1.0 && \
sudo pip3 install pybind11 && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias)
cd tensorflow/lite/kernels
sudo gdown --id 124YrrMZjj_lZxVnpxePs-F69i0xz7Qru
tar -zxvf kernels.tar.gz && rm kernels.tar.gz -f
cd ../../..
  • Apply multi-threading support for XNNPACK.
# interpreter.py
cd tensorflow/lite/python
sudo gdown --id 1LuEW11VLhR4gO1RPlymELDvXBFqU7WSK
cd ../../..
# interpreter_wrapper.cc, interpreter_wrapper.h, interpreter_wrapper_pybind11.cc
cd tensorflow/lite/python/interpreter_wrapper
sudo gdown --id 1zTO0z6Pe_a6RJxw7N_3gyqhFxGunFK-y
tar -zxvf interpreter_wrapper.tar.gz && rm interpreter_wrapper.tar.gz -f
cd ../../../..

============================================================

Tensorflow v2.6.0 - Buster armv7l/armhf - Bazel 3.7.2 Native Build

============================================================

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_ram_resources=4096 \
--local_cpu_resources=2 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
--linkopt=-Wl,-latomic \
--host_linkopt=-Wl,-latomic \
--define=tensorflow_mkldnn_contraction_kernel=0 \
--define=raspberry_pi_with_neon=true \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.6.0 - Buster armv7l/armhf - Bazel 3.7.2 Cross-compilation by x86 host

============================================================

$ git clone https://github.com/PINTO0309/tensorflow-on-arm.git && \
  cd tensorflow-on-arm/build_tensorflow
$ docker build -t tf-arm -f Dockerfile .
$ docker run -it --rm \
  -v /tmp/tensorflow_pkg/:/tmp/tensorflow_pkg/ \
  --env TF_PYTHON_VERSION=3.7 \
  tf-arm ./build_tensorflow.sh configs/rpi.conf
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.6.0-cp37-none-linux_armv7l.whl .
$ sudo chmod 777 tensorflow-2.6.0-cp37-none-linux_armv7l.whl

============================================================

Tensorflow v2.6.0 - Debian Buster aarch64 - Bazel 3.7.2 Using EC2 m6g.16xlarge

============================================================

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.6.0 - CUDA x86_64 - Bazel 3.7.2

============================================================

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 5.3,6.1,6.2,7.2,7.5,8.6

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.6.0-cp37-cp37m-linux_armv7l.whl ~
Tensorflow v2.7.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install libhdf5-dev && \
sudo pip3 install pip --upgrade && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.1.0 && \
sudo pip3 install pybind11 && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias)
cd tensorflow/lite/kernels
sudo gdown --id 1Az4hEvLXAb71e52gBORQz87Z0FExUz2B
tar -zxvf kernels.tar.gz && rm kernels.tar.gz -f
cd ../../..
  • Apply multi-threading support for XNNPACK (Python).
# interpreter_wrapper.cc
sudo gdown --id 1iNc8qC1y5CJdMWCcTXhl6SiDQg3M1DRv
git apply xnnpack_python.patch

============================================================

Tensorflow v2.7.0 - Buster armv7l/armhf - Bazel 3.7.2 Native Build

============================================================

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--local_ram_resources=4096 \
--local_cpu_resources=2 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
--linkopt=-Wl,-latomic \
--host_linkopt=-Wl,-latomic \
--define=tensorflow_mkldnn_contraction_kernel=0 \
--define=raspberry_pi_with_neon=true \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.7.0 - Buster armv7l/armhf - Bazel 3.7.2 Cross-compilation by x86 host

============================================================

$ git clone https://github.com/PINTO0309/tensorflow-on-arm.git && \
  cd tensorflow-on-arm/build_tensorflow
$ docker build -t tf-arm -f Dockerfile .
$ docker run -it --rm \
  -v /tmp/tensorflow_pkg/:/tmp/tensorflow_pkg/ \
  --env TF_PYTHON_VERSION=3.7 \
  tf-arm ./build_tensorflow.sh configs/rpi.conf
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.7.0-cp37-none-linux_armv7l.whl .
$ sudo chmod 777 tensorflow-2.7.0-cp37-none-linux_armv7l.whl

============================================================

Tensorflow v2.7.0 - Debian Buster aarch64 - Bazel 3.7.2 Using EC2 m6g.16xlarge

============================================================

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.7.0 - CUDA11.4 - TensorRT8.2 - x86_64 - Bazel 3.7.2

============================================================

$ sudo bazel clean --expunge
$ cp tensorflow/compiler/tf2tensorrt/stub/NvInfer_8_0.inc tensorflow/compiler/tf2tensorrt/stub/NvInfer_8_2.inc \
&& sed -i '62a #elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 2' tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc \
&& sed -i '63a #include "tensorflow/compiler/tf2tensorrt/stub/NvInfer_8_2.inc"' tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc \
&& cp tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_8_0.inc tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_8_2.inc \
&& sed -i '62a #elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 2' tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc \
&& sed -i '63a #include "tensorflow/compiler/tf2tensorrt/stub/NvInferPlugin_8_2.inc"' tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc
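The `sed -i '62a ...'` and `sed -i '63a ...'` pairs above rely on GNU sed's one-line `a` (append) command: the second insert targets line 63 because the first insert has already shifted the file down by one line. A small sketch of that mechanic on a throwaway file; `insert_two_lines_after` is a hypothetical helper, and the line number 62 is only meaningful for the source revision used above.

```shell
# Hypothetical helper: insert FIRST after line N, then SECOND right after FIRST.
# Requires GNU sed (for -i and the one-line "a text" form).
insert_two_lines_after() {
  file=$1; n=$2; first=$3; second=$4
  sed -i "${n}a ${first}" "$file"
  # FIRST now occupies line N+1, so SECOND goes after line N+1.
  sed -i "$((n + 1))a ${second}" "$file"
}

# Usage on a scratch file:
#   printf '%s\n' 'line 1' 'line 2' 'line 3' > demo.txt
#   insert_two_lines_after demo.txt 2 '#elif EXAMPLE' '#include "example.inc"'
#   cat demo.txt
```

Since the edit is keyed to a line number rather than to content, it may be worth confirming the result afterwards, for example with `grep -n 'NvInfer_8_2' tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc`.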

$ ./configure

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 6.1,7.5,8.6

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.7.0*.whl ~
Tensorflow v2.8.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install libhdf5-dev && \
sudo pip3 install pip --upgrade && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.1.0 && \
sudo pip3 install pybind11 && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias)
cd tensorflow/lite/kernels
sudo gdown --id 1qTVQ9qnbvzxxWm-1mGGkO7NRB9Rd_Uht
tar -zxvf kernels.tar.gz && rm kernels.tar.gz -f
cd ../../..

============================================================

Tensorflow v2.8.0 - Debian Bullseye aarch64 - Bazel 4.2.1 Using EC2 m6g.16xlarge

============================================================

$ sudo bazel clean --expunge
$ ./configure
$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.8.0 - CUDA11.4 - TensorRT8.2 - x86_64 - Bazel 4.2.1

============================================================

$ wget https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-installer-linux-x86_64.sh
$ sudo chmod +x bazel-4.2.1-installer-linux-x86_64.sh && sudo ./bazel-4.2.1-installer-linux-x86_64.sh
$ sudo bazel clean --expunge
$ ./configure

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 6.1,7.5,8.6

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.8.0*.whl ~
Tensorflow v2.9.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install -y libhdf5-dev unzip pkg-config python3-pip cmake make python-is-python3 && \
sudo pip3 install pip --upgrade && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.6.0 && \
sudo pip3 install pybind11==2.9.2 && \
sudo pip3 install packaging && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias, TransformLandmarks, TransformTensorBilinear, Landmarks2TransformMatrix)
$ curl -OL https://github.com/PINTO0309/TensorflowLite-bin/releases/download/v2.9.0/mediapipe_customop_patch.zip \
&& unzip -d mediapipe_customop_patch mediapipe_customop_patch.zip \
&& git apply mediapipe_customop_patch/*

============================================================

Tensorflow v2.9.0 - Debian Bullseye aarch64 - Bazel 5.0.0 Using EC2 m6g.16xlarge

============================================================

$ wget -O bazel https://github.com/bazelbuild/bazel/releases/download/5.0.0/bazel-5.0.0-linux-arm64 \
&& sudo chmod 777 bazel \
&& sudo cp bazel /usr/local/bin \
&& sudo bazel clean --expunge \
&& ./configure

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.9.0 - CUDA11.6 - TensorRT8.4 - x86_64 - Bazel 5.0.0

============================================================

$ wget https://github.com/bazelbuild/bazel/releases/download/5.0.0/bazel-5.0.0-installer-linux-x86_64.sh \
&& sudo chmod +x bazel-5.0.0-installer-linux-x86_64.sh \
&& sudo ./bazel-5.0.0-installer-linux-x86_64.sh \
&& sudo bazel clean --expunge \
&& ./configure

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 6.1,7.5,8.6

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.9.0*.whl ~
Tensorflow v2.10.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
  libhdf5-dev unzip pkg-config python3-pip \
  cmake make python-is-python3 && \
sudo pip3 install pip --upgrade && \
sudo pip3 install numpy==1.23.2 && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.6.0 && \
sudo pip3 install pybind11==2.9.2 && \
sudo pip3 install packaging && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias, TransformLandmarks, TransformTensorBilinear, Landmarks2TransformMatrix)
$ curl -OL https://github.com/PINTO0309/TensorflowLite-bin/releases/download/v2.10.0/mediapipe_customop_patch.zip \
&& unzip -d mediapipe_customop_patch mediapipe_customop_patch.zip \
&& git apply mediapipe_customop_patch/*

============================================================

Tensorflow v2.10.0 - Debian Bullseye aarch64 - Bazel 5.1.1 Using EC2 m6g.16xlarge

============================================================

$ wget -O bazel https://github.com/bazelbuild/bazel/releases/download/5.1.1/bazel-5.1.1-linux-arm64 \
&& sudo chmod 777 bazel \
&& sudo cp bazel /usr/local/bin \
&& sudo bazel clean --expunge \
&& ./configure

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.10.0 - CUDA11.7 - TensorRT8.4.3 - x86_64 - Bazel 5.1.1

============================================================

$ wget https://github.com/bazelbuild/bazel/releases/download/5.1.1/bazel-5.1.1-installer-linux-x86_64.sh \
&& sudo chmod +x bazel-5.1.1-installer-linux-x86_64.sh \
&& sudo ./bazel-5.1.1-installer-linux-x86_64.sh \
&& sudo bazel clean --expunge \
&& ./configure

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 6.1,7.5,8.6

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.10.0*.whl ~
<INVALID> Tensorflow v2.11.0
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
  libhdf5-dev unzip pkg-config python3-pip \
  cmake make python-is-python3 && \
sudo pip3 install pip --upgrade && \
sudo pip3 install numpy==1.23.4 && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.6.0 && \
sudo pip3 install pybind11==2.9.2 && \
sudo pip3 install packaging && \
pip3 install -U --user six wheel mock
  • Apply customization to add custom operations for MediaPipe. (max_pool_argmax, max_unpooling, transpose_conv_bias, TransformLandmarks, TransformTensorBilinear, Landmarks2TransformMatrix)
$ curl -OL https://github.com/PINTO0309/TensorflowLite-bin/releases/download/v2.11.0/mediapipe_customop_patch.zip \
&& unzip -d mediapipe_customop_patch mediapipe_customop_patch.zip \
&& git apply mediapipe_customop_patch/*

============================================================

Tensorflow v2.11.0 - Debian Bullseye aarch64 - Bazel 5.3.0 Using EC2 m6g.16xlarge

============================================================

$ wget -O bazel https://github.com/bazelbuild/bazel/releases/download/5.3.0/bazel-5.3.0-linux-arm64 \
&& sudo chmod 777 bazel \
&& sudo cp bazel /usr/local/bin \
&& sudo bazel clean --expunge \
&& ./configure

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

Tensorflow v2.11.0 - CUDA11.7 - TensorRT8.4.3 - x86_64 - Bazel 5.3.0

============================================================

$ wget https://github.com/bazelbuild/bazel/releases/download/5.3.0/bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo chmod +x bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo ./bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo bazel clean --expunge \
&& ./configure

supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 6.1,7.5,8.6

$ sudo bazel build \
--config=monolithic \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_pip_with_flex=true \
--define=tflite_with_xnnpack=true \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.11.0*.whl ~
Tensorflow v2.12.0 https://zenn.dev/pinto0309/scraps/a735fde5301bdc
$ sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
  libhdf5-dev unzip pkg-config python3-pip \
  cmake make python-is-python3 wget && \
sudo pip3 install pip --upgrade && \
sudo pip3 install numpy==1.24.2 && \
sudo pip3 install keras_applications==1.0.8 --no-deps && \
sudo pip3 install keras_preprocessing==1.1.2 --no-deps && \
sudo pip3 install gdown h5py==3.6.0 && \
sudo pip3 install pybind11==2.9.2 && \
sudo pip3 install packaging && \
sudo pip3 install protobuf==3.20.3 && \
pip3 install -U --user six wheel mock

$ sed -i '15a #include <assert.h>' tensorflow/tsl/framework/fixedpoint/MatMatProductAVX2.h

============================================================

Tensorflow v2.12.0 - Debian 11/Debian 12/Ubuntu 20.04/22.04 aarch64 - Bazel 5.3.0 Using EC2 m6g.16xlarge

============================================================

$ wget -O bazel https://github.com/bazelbuild/bazel/releases/download/5.3.0/bazel-5.3.0-linux-arm64 \
&& sudo chmod 777 bazel \
&& sudo cp bazel /usr/local/bin \
&& sudo bazel clean --expunge \
&& ./configure

$ sudo bazel build \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_with_xnnpack=true \
--copt="-Wno-stringop-overflow" \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

<INVALID> Tensorflow v2.12.0 - CUDA11.7 - TensorRT8.4.3 - x86_64 - Bazel 5.3.0

============================================================

$ wget https://github.com/bazelbuild/bazel/releases/download/5.3.0/bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo chmod +x bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo ./bazel-5.3.0-installer-linux-x86_64.sh \
&& sudo bazel clean --expunge \
&& ./configure

# https://developer.nvidia.com/cuda-gpus
supports compute capabilities >= 3.5 [Default is: 3.5,7.0]: 8.6

$ sudo bazel build \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_with_xnnpack=true \
--define=with_xla_support=false \
--ui_actions_shown=20 \
//tensorflow/tools/pip_package:build_pip_package

============================================================

$ sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.12.0*.whl ~
Tensorflow v2.15.0
# Bullseye, Ubuntu22.04
sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
    libhdf5-dev \
    unzip \
    pkg-config \
    python3-pip \
    cmake \
    make \
    git \
    python-is-python3 \
    wget \
    patchelf && \
pip install -U pip && \
pip install numpy==1.26.2 && \
pip install keras_applications==1.0.8 --no-deps && \
pip install keras_preprocessing==1.1.2 --no-deps && \
pip install h5py==3.6.0 && \
pip install pybind11==2.9.2 && \
pip install packaging && \
pip install protobuf==3.20.3 && \
pip install six wheel mock gdown

# Bookworm
sudo apt update && sudo apt upgrade -y && \
sudo apt install -y \
    libhdf5-dev \
    unzip \
    pkg-config \
    python3-pip \
    cmake \
    make \
    git \
    python-is-python3 \
    wget \
    patchelf && \
pip install -U pip --break-system-packages && \
pip install numpy==1.26.2 --break-system-packages && \
pip install keras_applications==1.0.8 --no-deps --break-system-packages && \
pip install keras_preprocessing==1.1.2 --no-deps --break-system-packages && \
pip install h5py==3.10.0 --break-system-packages && \
pip install pybind11==2.9.2 --break-system-packages && \
pip install packaging --break-system-packages && \
pip install protobuf==3.20.3 --break-system-packages && \
pip install six wheel mock gdown --break-system-packages


git clone -b r2.15-tflite-build https://github.com/PINTO0309/tensorflow.git
cd tensorflow

export TF_PYTHON_VERSION=3.xx

wget -O bazel https://github.com/bazelbuild/bazel/releases/download/6.1.0/bazel-6.1.0-linux-arm64 \
&& sudo chmod 777 bazel \
&& sudo cp bazel /usr/local/bin \
&& sudo bazel clean --expunge \
&& ./configure

bazel build \
--config=noaws \
--config=nohdfs \
--config=nonccl \
--config=v2 \
--define=tflite_with_xnnpack=true \
--define=xnnpack_force_float_precision=fp16 \
--copt="-Wno-stringop-overflow" \
--ui_actions_shown=64 \
//tensorflow/tools/pip_package:build_pip_package
sudo ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
sudo cp /tmp/tensorflow_pkg/tensorflow-2.15.0*.whl ~
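
The `TF_PYTHON_VERSION=3.xx` line in the steps above is a placeholder. The following is a minimal sketch, assuming the variable expects the `major.minor` version of the Python interpreter used for the build; the helper name `tf_python_version` is hypothetical and not part of this repository.

```python
import sys


def tf_python_version(version_info=sys.version_info):
    """Return 'major.minor' for the given version tuple, e.g. '3.11'."""
    return f"{version_info[0]}.{version_info[1]}"


if __name__ == "__main__":
    # Prints a line that can be pasted into the shell before running ./configure.
    print(f"export TF_PYTHON_VERSION={tf_python_version()}")
```

Running it with the same `python3` that will be used for the build keeps the exported value consistent with the interpreter that later installs the wheel.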

Reference articles

tensorflow-bin's People

Contributors

doom4535, halilsafakkilic, nobuotsukamoto, pinto0309, rhenerose

tensorflow-bin's Issues

Get Error Message When Import Tensorflow 2.5

[Required] Your device : Raspberry Pi 3

[Required] Your device's CPU architecture : armv7l

[Required] Your OS : Raspbian, I think

[Required] Details of the work you did before the problem occurred: I wanted to install TensorFlow 2.5. After installing it (TensorFlow 2.5 is listed when I run "pip3 list"), importing it fails with an error message, so I can't use it.

[Required] Error message: >>> import tensorflow
RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/__init__.py", line 41, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/__init__.py", line 40, in <module>
    from tensorflow.python.eager import context
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/context.py", line 37, in <module>
    from tensorflow.python.client import pywrap_tf_session
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/pywrap_tf_session.py", line 23, in <module>
    from tensorflow.python._pywrap_tf_session import *
ImportError: SystemError: <built-in method __contains__ of dict object at 0x6fc25840> returned a result with an error set

2021-06-20-170108_1440x900_scrot

[Required] Overview of problems and questions: How can I solve this problem?
I followed this tutorial to install TensorFlow on my Raspberry Pi 3 (screenshot: 2021-06-20-170430_1440x900_scrot).

My plan is to build custom object detection with TensorFlow Lite on my Raspberry Pi 3.
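
The `RuntimeError` in this report carries two hexadecimal NumPy C-API versions. The sketch below is only a reading aid: it parses the message and compares the two numbers. The function name `parse_api_mismatch` is hypothetical, and the mapping from C-API version `0xe` to the NumPy 1.20 series is stated from memory as an assumption that should be checked against NumPy's own documentation.

```python
import re

# Assumption: C-API version 0xe first shipped with the NumPy 1.20 series.
ASSUMED_FIRST_SERIES = {0xE: "1.20"}


def parse_api_mismatch(message):
    """Extract (compiled_against, installed) C-API versions from the error text."""
    match = re.search(
        r"compiled against API version (0x[0-9a-fA-F]+) "
        r"but this version of numpy is (0x[0-9a-fA-F]+)",
        message,
    )
    if match is None:
        return None
    return int(match.group(1), 16), int(match.group(2), 16)


msg = ("RuntimeError: module compiled against API version 0xe "
       "but this version of numpy is 0xd")
compiled, installed = parse_api_mismatch(msg)
if compiled > installed:
    # The extension module was built against a newer NumPy than the one installed.
    print("Installed NumPy is older than the one the wheel was built against.")
    print("Assumed minimum series:", ASSUMED_FIRST_SERIES.get(compiled, "unknown"))
```

If the assumption holds, the message suggests that the installed NumPy predates the one used to build the wheel, so upgrading NumPy would be the first thing to try.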

tensorflow-1.11.0-cp36-cp36m-linux_aarch64.whl failed.

Successfully installed tensorboard-1.11.0 tensorflow-1.11.0
[toybrick@localhost work]$ python3
Python 3.6.8 (default, Jan 31 2019, 09:06:51)
[GCC 8.2.1 20181215 (Red Hat 8.2.1-6)] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import tensorflow as tf
Traceback (most recent call last):
  File "/home/toybrick/.local/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow.py", line 58, in <module>
    from tensorflow.python.pywrap_tensorflow_internal import *
  File "/home/toybrick/.local/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 28, in <module>
    _pywrap_tensorflow_internal = swig_import_helper()
  File "/home/toybrick/.local/lib/python3.6/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 24, in swig_import_helper
    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
  File "/usr/lib64/python3.6/imp.py", line 243, in load_module
    return load_dynamic(name, filename, file)
  File "/usr/lib64/python3.6/imp.py", line 343, in load_dynamic
    return _load(spec)
ImportError: libcublas.so.9.0: cannot open shared object file: No such file or directory

TF 2.7.0 fails on Raspberry Pi 4 64-bit OS

Your device: Raspberry Pi 4

Your device's CPU architecture: aarch64

Your OS: Raspberry Pi OS Buster (64-bit)

Details of the work you did before the problem occurred:

  • Follow the Usage: "Example of Python 3.x + Tensorflow v2 series" steps up to the wget step
  • From the wget step on, use the tensorflow-2.7.0-cp37-none-linux_aarch64_numpy1214_download.sh script instead

Error message:

ERROR: Could not find a version that satisfies the requirement libclang>= 9.0.1 (from tensorflow)

Overview of problems and questions:

Tensorflow 2.7 seems to depend on the libclang 9.0.1 package.
That version doesn't seem to be available for ARM (aarch64 and armv7l) via pip.
Only the unsuitable version 7 can be installed with $ sudo apt-get install libclang.
What is the solution here? TF 2.6 depended on clang only optionally.

ERROR: tensorflow-2.5.0-cp 37-none-linux_armv71.whl is not a supported wheel on this platform.

Issue Type

Bug

OS

RaspberryPi OS Buster

OS architecture

armv7

Hardware

RaspberryPi4

Description

Running pip install on tensorflow-2.5.0-cp 37-none-linux_armv71.whl fails with "is not a supported wheel on this platform".
My Python version is 3.9.2.

Relevant Log Output

ERROR: tensorflow-2.5.0-cp 37-none-linux_armv71.whl is not a supported wheel on this platform.
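
The report above combines a `cp37` wheel with Python 3.9.2. As a simplified sketch (pip's real compatibility check considers many more tag combinations), the wheel filename can be split into its tags and compared with the interpreter version. The helper names are hypothetical, and the filename used here assumes the intended spelling `cp37` / `armv7l`.

```python
import sys


def wheel_tags(filename):
    """Split a wheel filename into (python_tag, abi_tag, platform_tag)."""
    stem = filename[:-len(".whl")] if filename.endswith(".whl") else filename
    parts = stem.split("-")
    # name-version[-build]-python-abi-platform: the last three fields are the tags.
    return tuple(parts[-3:])


def python_tag_matches(filename, version_info=sys.version_info):
    """True if the wheel's CPython tag equals the given interpreter version."""
    python_tag, _, _ = wheel_tags(filename)
    return python_tag == f"cp{version_info[0]}{version_info[1]}"


name = "tensorflow-2.5.0-cp37-none-linux_armv7l.whl"
print(wheel_tags(name))
print(python_tag_matches(name, (3, 9)))  # interpreter version from the report
```

A mismatch between the `cpXY` tag and the running interpreter is one common reason for the "not a supported wheel" message; the platform tag has to match the machine architecture as well.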

Dependency issue with tensorflow 2.7 on python 3.9

Your device: Raspberry Pi 3

Your device's CPU architecture: aarch64

Your OS: Raspberry Pi OS Bullseye (64-bit)

Details of the work you did before the problem occurred:

  1. Follow the Usage: "Example of Python 3.x + Tensorflow v2 series" steps up to the wget step
  2. From the wget step on, use the tensorflow-2.7.0-cp39-none-linux_aarch64_numpy1214_download.sh script instead

Error message:

ERROR: Could not find a version that satisfies the requirement tensorflow-io-gcs-filesystem>=0.21.0 (from tensorflow)
ERROR: No matching distribution found for tensorflow-io-gcs-filesystem>=0.21.0

Overview of problems and questions:

  • Tensorflow 2.7 seems to depend on the tensorflow-io-gcs-filesystem package
  • That package doesn't seem to be available for arm versions (aarch64 and armv7l) via pip if I didn't miss something
  • Is there way to get around this, maybe by building it manually or including it in the wheel?

Tensorflow with python3.9 for armhf/armv7l

Issue Type

Feature Request, Others

OS

Other

OS architecture

armv7

Hardware

RaspberryPi4, RaspberryPi3

Description

I am trying to get tensorflow running on the newly released Raspberry Pi OS Bullseye in the 32 bit / armv7l / armhf edition. I need the 32 bit version due to compatibility reasons. Since this new release comes with python3.9, there are no matching wheels provided. All wheels using python3.9 (*-cp39-*) are also aarch64 only.

Is there a reason why those versions are incompatible with armv7l?

And if not, could you consider adding a wheel that works with python3.9 and armv7l?

Relevant Log Output

-

'GLIBC_2.29' error not found for Python 3.7 + Tensorflow v2

Thank you. I was installing Python 3.7 + Tensorflow v2 on my Raspberry Pi 4 (Buster).

Everything went well and pip3 list shows the v2 tensorflow. But when I import tf in python3, the error below occurs:

Traceback (most recent call last):
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow.py", line 58, in <module>
    from tensorflow.python.pywrap_tensorflow_internal import *
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow_internal.py", line 28, in <module>
    _pywrap_tensorflow_internal = swig_import_helper()
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow_internal.py", line 24, in swig_import_helper
    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
  File "/usr/lib/python3.7/imp.py", line 242, in load_module
    return load_dynamic(name, filename, file)
  File "/usr/lib/python3.7/imp.py", line 342, in load_dynamic
    return _load(spec)
ImportError: /lib/arm-linux-gnueabihf/libm.so.6: version `GLIBC_2.29' not found (required by /home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/_pywrap_tensorflow_internal.so)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow/__init__.py", line 98, in <module>
    from tensorflow_core import *
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/__init__.py", line 40, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow/__init__.py", line 50, in __getattr__
    module = self._load()
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow/__init__.py", line 44, in _load
    module = _importlib.import_module(self.__name__)
  File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/__init__.py", line 49, in <module>
    from tensorflow.python import pywrap_tensorflow
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow.py", line 74, in <module>
    raise ImportError(msg)
ImportError: Traceback (most recent call last):
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow.py", line 58, in <module>
    from tensorflow.python.pywrap_tensorflow_internal import *
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow_internal.py", line 28, in <module>
    _pywrap_tensorflow_internal = swig_import_helper()
  File "/home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/pywrap_tensorflow_internal.py", line 24, in swig_import_helper
    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
  File "/usr/lib/python3.7/imp.py", line 242, in load_module
    return load_dynamic(name, filename, file)
  File "/usr/lib/python3.7/imp.py", line 342, in load_dynamic
    return _load(spec)
ImportError: /lib/arm-linux-gnueabihf/libm.so.6: version `GLIBC_2.29' not found (required by /home/pi/.local/lib/python3.7/site-packages/tensorflow_core/python/_pywrap_tensorflow_internal.so)

Failed to load the native TensorFlow runtime.

See https://www.tensorflow.org/install/errors

for some common reasons and solutions. Include the entire stack trace
above this error message when asking for help.

I tried the 【Appendix】 C Library + Tensorflow v1.x.x / v2.x.x you provided, but it did not help.

Any good idea how to go on?
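
The `GLIBC_2.29' not found` message indicates that the wheel's native library expects glibc symbols from version 2.29, while the table at the top of this README lists Buster with glibc 2.28. The sketch below shows one way to compare the two from Python; `glibc_satisfies` and `version_tuple` are hypothetical helper names, and `platform.libc_ver()` may return empty strings on some systems.

```python
import platform


def version_tuple(text):
    """Turn '2.29' into (2, 29) so versions compare numerically, not as strings."""
    return tuple(int(part) for part in text.split("."))


def glibc_satisfies(installed, required):
    """True if the installed glibc version is at least the required one."""
    return version_tuple(installed) >= version_tuple(required)


if __name__ == "__main__":
    lib, version = platform.libc_ver()
    print(lib, version)
    if lib == "glibc" and version:
        print("meets 2.29:", glibc_satisfies(version, "2.29"))
```

If the check reports `False`, a wheel built against the older glibc of the installed distribution is the safer choice than trying to replace the system C library.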

stop working

I tried to install TensorFlow on a Raspberry Pi 3, but when I run the build command "sudo bazel build --config ...." the process stops responding at 4233/8000s.

What should I do?

Thank you.

Allocation of 4000000 exceeds 10% of system memory

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): pi4

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): raspbian

[Required] Details of the work you did before the problem occurred:

When I run this test code:

python3 -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.reduce_sum(tf.random_normal([1000, 1000])))"

I get the error:

Allocation of 4000000 exceeds 10% of system memory

[Required] Error message:

Allocation of 4000000 exceeds 10% of system memory

[Required] Overview of problems and questions:

The command "vcgencmd get_mem gpu" shows:
gpu=256M
The command "vcgencmd get_mem arm" shows:
arm=768M

My pi has 4G memory, I don't know why it's only 768M.
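
As a rough piece of arithmetic, not a diagnosis: the figure 4000000 matches the byte size of a 1000 x 1000 tensor of 4-byte floats, which is what the test command creates if `float32` is assumed. The sketch below computes that size and compares it with 10% of the reported `arm=768M`, reading `M` as MiB. It does not reproduce TensorFlow's own memory check.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Byte size of a dense tensor; 4 bytes per element assumes float32."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


allocation = tensor_bytes([1000, 1000])
ten_percent_of_arm = 768 * 1024 * 1024 // 10  # 10% of arm=768M, M read as MiB

print(allocation)
print(ten_percent_of_arm)
print(allocation > ten_percent_of_arm)
```

Under these assumptions the allocation is far smaller than 10% of 768 MiB, so the arm/gpu memory split alone does not obviously account for the message.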

Struggling with numpy version

pip3 install tensorflow-2.5.0-cp37-none-linux_aarch64.whl
Looking in indexes: https://pypi.org/simple, https://www.piwheels.org/simple
Processing ./tensorflow-2.5.0-cp37-none-linux_aarch64.whl
Collecting opt-einsum~=3.3.0
  Downloading https://www.piwheels.org/simple/opt-einsum/opt_einsum-3.3.0-py3-none-any.whl (65 kB)
     |████████████████████████████████| 65 kB 654 kB/s 
Collecting tensorflow-estimator<2.6.0,>=2.5.0rc0
  Downloading tensorflow_estimator-2.5.0-py2.py3-none-any.whl (462 kB)
     |████████████████████████████████| 462 kB 113 kB/s 
Requirement already satisfied: absl-py~=0.10 in ./venv/lib/python3.7/site-packages (from tensorflow==2.5.0) (0.12.0)
Collecting google-pasta~=0.2
  Downloading https://www.piwheels.org/simple/google-pasta/google_pasta-0.2.0-py3-none-any.whl (57 kB)
     |████████████████████████████████| 57 kB 820 kB/s 
Collecting wheel~=0.35
  Downloading https://www.piwheels.org/simple/wheel/wheel-0.36.2-py2.py3-none-any.whl (35 kB)
Collecting keras-preprocessing~=1.1.2
  Downloading https://www.piwheels.org/simple/keras-preprocessing/Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)
     |████████████████████████████████| 42 kB 209 kB/s 
Collecting wrapt~=1.12.1
  Downloading wrapt-1.12.1.tar.gz (27 kB)
Collecting tensorboard~=2.5
  Downloading tensorboard-2.5.0-py3-none-any.whl (6.0 MB)
     |████████████████████████████████| 6.0 MB 123 kB/s 
Collecting termcolor~=1.1.0
  Downloading https://www.piwheels.org/simple/termcolor/termcolor-1.1.0-py3-none-any.whl (4.8 kB)
Collecting protobuf>=3.9.2
  Downloading protobuf-3.17.0-cp37-cp37m-manylinux2014_aarch64.whl (925 kB)
     |████████████████████████████████| 925 kB 374 kB/s 
Collecting flatbuffers~=1.12.0
  Downloading https://www.piwheels.org/simple/flatbuffers/flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Requirement already satisfied: numpy~=1.19.2 in ./venv/lib/python3.7/site-packages (from tensorflow==2.5.0) (1.19.5)
Collecting h5py~=3.1.0
  Downloading h5py-3.1.0.tar.gz (371 kB)
     |████████████████████████████████| 371 kB 799 kB/s 
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
    Preparing wheel metadata ... done
Collecting grpcio~=1.34.0
  Downloading grpcio-1.34.1.tar.gz (21.1 MB)
     |████████████████████████████████| 21.1 MB 23 kB/s 
Collecting typing-extensions~=3.7.4
  Downloading https://www.piwheels.org/simple/typing-extensions/typing_extensions-3.7.4.3-py3-none-any.whl (22 kB)
Collecting astunparse~=1.6.3
  Downloading https://www.piwheels.org/simple/astunparse/astunparse-1.6.3-py2.py3-none-any.whl (12 kB)
Collecting keras-nightly~=2.5.0.dev
  Downloading keras_nightly-2.5.0.dev2021032900-py2.py3-none-any.whl (1.2 MB)
     |████████████████████████████████| 1.2 MB 462 kB/s 
Collecting gast==0.4.0
  Downloading https://www.piwheels.org/simple/gast/gast-0.4.0-py3-none-any.whl (9.9 kB)
Collecting six~=1.15.0
  Downloading https://www.piwheels.org/simple/six/six-1.15.0-py2.py3-none-any.whl (10 kB)
Collecting cached-property
  Downloading https://www.piwheels.org/simple/cached-property/cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)
Collecting google-auth<2,>=1.6.3
  Downloading https://www.piwheels.org/simple/google-auth/google_auth-1.30.0-py2.py3-none-any.whl (146 kB)
     |████████████████████████████████| 146 kB 345 kB/s 
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Downloading https://www.piwheels.org/simple/google-auth-oauthlib/google_auth_oauthlib-0.4.4-py2.py3-none-any.whl (18 kB)
Requirement already satisfied: requests<3,>=2.21.0 in /usr/lib/python3/dist-packages (from tensorboard~=2.5->tensorflow==2.5.0) (2.21.0)
Collecting tensorboard-plugin-wit>=1.6.0
  Downloading tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB)
     |████████████████████████████████| 781 kB 415 kB/s 
Collecting werkzeug>=0.11.15
  Downloading https://www.piwheels.org/simple/werkzeug/Werkzeug-2.0.0-py3-none-any.whl (288 kB)
     |████████████████████████████████| 288 kB 555 kB/s 
Requirement already satisfied: setuptools>=41.0.0 in ./venv/lib/python3.7/site-packages (from tensorboard~=2.5->tensorflow==2.5.0) (56.2.0)
Collecting markdown>=2.6.8
  Downloading https://www.piwheels.org/simple/markdown/Markdown-3.3.4-py3-none-any.whl (97 kB)
     |████████████████████████████████| 97 kB 153 kB/s 
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Downloading tensorboard_data_server-0.6.1-py3-none-any.whl (2.4 kB)
Collecting rsa<5,>=3.1.4
  Downloading https://www.piwheels.org/simple/rsa/rsa-4.7.2-py3-none-any.whl (34 kB)
Collecting cachetools<5.0,>=2.0.0
  Downloading https://www.piwheels.org/simple/cachetools/cachetools-4.2.2-py3-none-any.whl (11 kB)
Collecting pyasn1-modules>=0.2.1
  Downloading https://www.piwheels.org/simple/pyasn1-modules/pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)
     |████████████████████████████████| 155 kB 560 kB/s 
Collecting requests-oauthlib>=0.7.0
  Downloading https://www.piwheels.org/simple/requests-oauthlib/requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB)
Collecting importlib-metadata
  Downloading https://www.piwheels.org/simple/importlib-metadata/importlib_metadata-4.0.1-py3-none-any.whl (16 kB)
Collecting pyasn1<0.5.0,>=0.4.6
  Downloading https://www.piwheels.org/simple/pyasn1/pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)
     |████████████████████████████████| 77 kB 362 kB/s 
Collecting oauthlib>=3.0.0
  Downloading https://www.piwheels.org/simple/oauthlib/oauthlib-3.1.0-py2.py3-none-any.whl (147 kB)
     |████████████████████████████████| 147 kB 249 kB/s 
Collecting zipp>=0.5
  Downloading https://www.piwheels.org/simple/zipp/zipp-3.4.1-py3-none-any.whl (5.2 kB)
Building wheels for collected packages: grpcio, h5py, wrapt
  Building wheel for grpcio (setup.py) ... done
  Created wheel for grpcio: filename=grpcio-1.34.1-cp37-cp37m-linux_aarch64.whl size=36263770 sha256=a3b926ab6d28e4499d7efa240f1164bc86434883e3afbda27bac57979f3aa593
  Stored in directory: /home/pi/.cache/pip/wheels/7c/54/6a/ba879c1e8a943659b0a31178a458a8f119531aa6d548081e66
  Building wheel for h5py (PEP 517) ... done
  Created wheel for h5py: filename=h5py-3.1.0-cp37-cp37m-linux_aarch64.whl size=5351491 sha256=1ebe5a19383d8e2d29e5d5e6f69ece0b0f28a98e3dda5dc8e75e0d23a199f42b
  Stored in directory: /home/pi/.cache/pip/wheels/8b/ff/b0/8ec15768fa86bc9635867a6acbbb6f203c82f264afd667bb36
  Building wheel for wrapt (setup.py) ... done
  Created wheel for wrapt: filename=wrapt-1.12.1-cp37-cp37m-linux_aarch64.whl size=72884 sha256=6a8f019dd848a8206419ebd4da28e6ffa4435cd81046d09a11877754adb6383a
  Stored in directory: /home/pi/.cache/pip/wheels/62/76/4c/aa25851149f3f6d9785f6c869387ad82b3fd37582fa8147ac6
Successfully built grpcio h5py wrapt
Installing collected packages: pyasn1, zipp, typing-extensions, six, rsa, pyasn1-modules, oauthlib, cachetools, requests-oauthlib, importlib-metadata, google-auth, wheel, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, markdown, grpcio, google-auth-oauthlib, cached-property, wrapt, termcolor, tensorflow-estimator, tensorboard, opt-einsum, keras-preprocessing, keras-nightly, h5py, google-pasta, gast, flatbuffers, astunparse, tensorflow
  Attempting uninstall: six
    Found existing installation: six 1.12.0
    Not uninstalling six at /usr/lib/python3/dist-packages, outside environment /home/pi/g-kws/venv
    Can't uninstall 'six'. No files were found to uninstall.
  Attempting uninstall: wheel
    Found existing installation: wheel 0.32.3
    Not uninstalling wheel at /usr/lib/python3/dist-packages, outside environment /home/pi/g-kws/venv
    Can't uninstall 'wheel'. No files were found to uninstall.
  Attempting uninstall: h5py
    Found existing installation: h5py 2.10.0
    Uninstalling h5py-2.10.0:
      Successfully uninstalled h5py-2.10.0
Successfully installed astunparse-1.6.3 cached-property-1.5.2 cachetools-4.2.2 flatbuffers-1.12 gast-0.4.0 google-auth-1.30.0 google-auth-oauthlib-0.4.4 google-pasta-0.2.0 grpcio-1.34.1 h5py-3.1.0 importlib-metadata-4.0.1 keras-nightly-2.5.0.dev2021032900 keras-preprocessing-1.1.2 markdown-3.3.4 oauthlib-3.1.0 opt-einsum-3.3.0 protobuf-3.17.0 pyasn1-0.4.8 pyasn1-modules-0.2.8 requests-oauthlib-1.3.0 rsa-4.7.2 six-1.15.0 tensorboard-2.5.0 tensorboard-data-server-0.6.1 tensorboard-plugin-wit-1.8.0 tensorflow-2.5.0 tensorflow-estimator-2.5.0 termcolor-1.1.0 typing-extensions-3.7.4.3 werkzeug-2.0.0 wheel-0.36.2 wrapt-1.12.1 zipp-3.4.1
(venv) pi@raspberrypi:~/g-kws $ python3
Python 3.7.3 (default, Jan 22 2021, 20:04:44) 
[GCC 8.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow
RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pi/g-kws/venv/lib/python3.7/site-packages/tensorflow/__init__.py", line 41, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/home/pi/g-kws/venv/lib/python3.7/site-packages/tensorflow/python/__init__.py", line 40, in <module>
    from tensorflow.python.eager import context
  File "/home/pi/g-kws/venv/lib/python3.7/site-packages/tensorflow/python/eager/context.py", line 37, in <module>
    from tensorflow.python.client import pywrap_tf_session
  File "/home/pi/g-kws/venv/lib/python3.7/site-packages/tensorflow/python/client/pywrap_tf_session.py", line 23, in <module>
    from tensorflow.python._pywrap_tf_session import *
ImportError: SystemError: <built-in method __contains__ of dict object at 0x7f98e18630> returned a result with an error set

issue on wget "https://raw.githubusercontent.com/PINTO0309/Tensorflow-bin/master/tensorflow-2.6.0-cp37-none-linux_aarch64_download.sh"

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):
RaspberryPi3 B+
[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
armv7l
[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Raspbian
[Required] Details of the work you did before the problem occurred:
I tried $ wget "https://raw.githubusercontent.com/PINTO0309/Tensorflow-bin/master/tensorflow-2.6.0-cp37-none-linux_aarch64_download.sh" but it returns a 404 error; the path cannot be found.
[Required] Error message:
HTTP request sent, awaiting response... 404 Not Found
[Required] Overview of problems and questions:
I am not sure why the linked path cannot be found.

TF2.5 installation issue

Does anyone know what could cause the installation of TF2.5 to force installing TF 2.4rc0 (see installation output below)?

Thanks, Filip

Processing ./tensorflow-2.5.0-cp38-none-linux_aarch64.whl
Requirement already satisfied: protobuf~=3.13.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (3.13.0)
Requirement already satisfied: grpcio~=1.32.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.32.0)
Requirement already satisfied: keras-preprocessing~=1.1.2 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.1.2)
Requirement already satisfied: google-pasta~=0.2 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (0.2.0)
Requirement already satisfied: six~=1.15.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.15.0)
Requirement already satisfied: wheel~=0.35 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (0.36.2)
Requirement already satisfied: absl-py~=0.10 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (0.12.0)
Requirement already satisfied: wrapt~=1.12.1 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.12.1)
Requirement already satisfied: termcolor~=1.1.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.1.0)
Requirement already satisfied: tensorboard~=2.3 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (2.4.1)
Requirement already satisfied: flatbuffers~=1.12.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.12)
Requirement already satisfied: opt-einsum~=3.3.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (3.3.0)
Requirement already satisfied: typing-extensions~=3.7.4 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (3.7.4.3)
Requirement already satisfied: astunparse~=1.6.3 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.6.3)
Requirement already satisfied: gast==0.3.3 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (0.3.3)
Requirement already satisfied: numpy~=1.19.2 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (1.19.5)
Requirement already satisfied: h5py~=2.10.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (2.10.0)
Requirement already satisfied: tensorflow-estimator~=2.3.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorflow==2.5.0) (2.3.0)
Requirement already satisfied: setuptools in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from protobuf~=3.13.0->tensorflow==2.5.0) (44.0.0)
Requirement already satisfied: werkzeug>=0.11.15 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (1.0.1)
Requirement already satisfied: markdown>=2.6.8 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (3.3.4)
Requirement already satisfied: google-auth<2,>=1.6.3 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (1.27.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (1.8.0)
Requirement already satisfied: requests<3,>=2.21.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (2.25.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from tensorboard~=2.3->tensorflow==2.5.0) (0.4.3)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.3->tensorflow==2.5.0) (4.2.1)
Requirement already satisfied: rsa<5,>=3.1.4 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.3->tensorflow==2.5.0) (4.7.2)
Requirement already satisfied: pyasn1-modules>=0.2.1 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.3->tensorflow==2.5.0) (0.2.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.3->tensorflow==2.5.0) (1.3.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.3->tensorflow==2.5.0) (0.4.8)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard~=2.3->tensorflow==2.5.0) (1.26.3)
Requirement already satisfied: certifi>=2017.4.17 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard~=2.3->tensorflow==2.5.0) (2020.12.5)
Requirement already satisfied: idna<3,>=2.5 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard~=2.3->tensorflow==2.5.0) (2.10)
Requirement already satisfied: chardet<5,>=3.0.2 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard~=2.3->tensorflow==2.5.0) (4.0.0)
Requirement already satisfied: oauthlib>=3.0.0 in ./python_virtual_envs/camera_env/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.3->tensorflow==2.5.0) (3.1.0)
Installing collected packages: tensorflow
Successfully installed tensorflow-2.4.0rc0

Issue in running inference with tf2.5 and tf2.3 on armv7l

Issue Type

Support

OS

RaspberryPi OS Buster

OS architecture

armv7

Hardware

RaspberryPi3

Description

I have trained a TRILL model in tf2.5.0 and have converted it into tflite by the following lines:

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
converter._experimental_lower_tensor_list_ops = False
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()

open("TRILL_quant.tflite", "wb").write(tflite_quant_model)

Inference works fine when tested in a Colab environment.

But when I run the same inference with the same model on the RaspberryPi3, I get the error below.

I've tried TensorFlow 2.3 and 2.5 from your repository.

Relevant Log Output

RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you apply/link the Flex delegate before inference.Node number 2 (FlexTensorListReserve) failed to prepare.

TF-2.3.0-cp38-cp38m-linux_aarch64.whl is not supported wheel

device: RaspberryPi3 B+
CPU architecture : aarch64
OS: Ubuntu 20.04 64bit
Work you did before the problem occurred: Clean installation of Ubuntu 20.04
Error message: tensorflow-2.3.0-cp38-cp38m-linux_aarch64.whl is not a supported wheel on this platform.

I am not able to install the wheel because of the error above, even though all conditions look correct.
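One possible explanation, offered as a guess rather than a confirmed diagnosis: pip compares the three tags at the end of the wheel filename (Python tag, ABI tag, platform tag) with the tags of the running interpreter. CPython 3.8 dropped the "m" ABI suffix, so a `cp38m` wheel is rejected on Python 3.8. The sketch below is a simplified check using only the standard library; the helper names are hypothetical, and pip's real logic accepts a longer list of tags than the single triple compared here.

```python
import platform
import sys


def wheel_tags(filename):
    """Return (python_tag, abi_tag, platform_tag) parsed from a wheel filename."""
    stem = filename[:-4] if filename.endswith(".whl") else filename
    # Wheel names follow name-version[-build]-python-abi-platform.
    return tuple(stem.split("-")[-3:])


def interpreter_tags():
    """Approximate the primary tags of the running CPython on Linux."""
    python_tag = "cp{}{}".format(*sys.version_info[:2])
    # sys.abiflags is "m" on typical CPython 3.7 and earlier builds, "" from 3.8 on.
    abi_tag = python_tag + getattr(sys, "abiflags", "")
    platform_tag = "linux_" + platform.machine()
    return python_tag, abi_tag, platform_tag


def tag_mismatches(filename, local=None):
    """List the tag fields where the wheel and the interpreter disagree."""
    local = local or interpreter_tags()
    names = ("python", "abi", "platform")
    problems = []
    for name, wheel_tag, local_tag in zip(names, wheel_tags(filename), local):
        if name == "abi" and wheel_tag == "none":
            continue  # "none" means the wheel is not tied to a specific ABI
        if wheel_tag != local_tag:
            problems.append(
                "{}: wheel={} interpreter={}".format(name, wheel_tag, local_tag)
            )
    return problems


if __name__ == "__main__":
    for line in tag_mismatches("tensorflow-2.3.0-cp38-cp38m-linux_aarch64.whl"):
        print(line)
```

Running it on the device prints every field that differs. The same check also flags filenames whose platform tag carries a suffix such as `linux_armv7l_jemalloc`, which may be why other reports on this page rename the file on download.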

it hangs after a while (RPI3B)

I ran through the steps:

$ sudo apt-get install python-pip python3-pip python3-scipy libhdf5-dev
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo pip3 uninstall tensorflow
$ wget -O tensorflow-1.11.0-cp35-cp35m-linux_armv7l.whl https://github.com/PINTO0309/Tensorflow-bin/raw/master/tensorflow-1.11.0-cp35-cp35m-linux_armv7l_jemalloc.whl
$ sudo pip3 install tensorflow-1.11.0-cp35-cp35m-linux_armv7l.whl

This has been the only package that has allowed me to load a tflite model, thanks for that.
But I notice that it takes only 25% of the CPU, i.e. it uses only 1 thread, which makes it slower than the regular tensorflow package with a normal model.

So I tried to install:

tensorflow-1.11.0-cp35-cp35m-linux_armv7l_jemalloc.whl
tensorflow-1.11.0-cp35-cp35m-linux_armv7l_jemalloc_mpi.whl
tensorflow-1.11.0-cp35-cp35m-linux_armv7l_jemalloc_mpi_multithread.whl

but none of them work; I get:
*.whl is not a supported wheel on this platform.

I am using an RTSP camera to get the images, and the script hangs after a while with no response, so I need to end the process manually. This doesn't happen with regular tensorflow.

Am I doing something wrong? What am I missing? I was expecting better results.

pip3 install tensorflow package:
Tensorflow 1.11 = 0.7 FPS (frozen.pb model)

pip3 install tensorflow-1.11.0-cp35-cp35m-linux_armv7l.whl
Tensorflow 1.11 = 0.4 FPS (frozen.pb model)
Tensorflow Lite 1.11 = 0.2 FPS (frozen.tflite model)

import error on jetson nano

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):
Jetson Nano

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
aarch64

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Ubuntu 18.04 (Latest jetson nano sd image - JP4.2.1)

[Required] Details of the work you did before the problem occurred:
Installed provided package:
https://github.com/PINTO0309/Tensorflow-bin/blob/master/tensorflow-2.0.0b1-cp37-cp37m-linux_aarch64.whl

[Required] Error message:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/__init__.py", line 28, in <module>
    from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/__init__.py", line 49, in <module>
    from tensorflow.python import pywrap_tensorflow
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/pywrap_tensorflow.py", line 74, in <module>
    raise ImportError(msg)
ImportError: Traceback (most recent call last):
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/pywrap_tensorflow.py", line 58, in <module>
    from tensorflow.python.pywrap_tensorflow_internal import *
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 28, in <module>
    _pywrap_tensorflow_internal = swig_import_helper()
  File "/srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/pywrap_tensorflow_internal.py", line 24, in swig_import_helper
    _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
  File "/srv/maverick/software/python/lib/python3.7/imp.py", line 242, in load_module
    return load_dynamic(name, filename, file)
  File "/srv/maverick/software/python/lib/python3.7/imp.py", line 342, in load_dynamic
    return _load(spec)
ImportError: /lib/aarch64-linux-gnu/libc.so.6: version `GLIBC_2.28' not found (required by /srv/maverick/software/python/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so)

tf 1.15.0 and tensor2tensor on raspberry pi -- get_data_files problem

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): raspberry pi

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): raspbian

[Required] Details of the work you did before the problem occurred: designed a t2t model

[Required] Error message: see below

[Required] Overview of problems and questions:

I am seeing this error. I'm using your tf 1.15.0 for armhf and python3.7. Ultimately I will be using raspbian and buster. I also use tensor2tensor and this is where I see my error now.

File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/bin/t2t_decoder.py", line 209, in main
decode(estimator, hp, decode_hp)
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/bin/t2t_decoder.py", line 110, in decode
checkpoint_path=FLAGS.checkpoint_path)
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/utils/decoding.py", line 227, in decode_from_dataset
checkpoint_path=checkpoint_path)
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/utils/decoding.py", line 316, in decode_once
for num_predictions, prediction in enumerate(predictions):
File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 620, in predict
input_fn, ModeKeys.PREDICT)
File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 996, in _get_features_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1116, in _call_input_fn
return input_fn(**kwargs)
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/data_generators/problem.py", line 803, in estimator_input_fn
dataset_kwargs=dataset_kwargs)
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/data_generators/problem.py", line 888, in input_fn
self.dataset(*dataset_kwargs),
File "/usr/local/lib/python3.7/dist-packages/tensor2tensor/data_generators/problem.py", line 653, in dataset
contrib.slim().parallel_reader.get_data_files(data_filepattern))
File "/usr/local/lib/python3.7/dist-packages/tensorflow_core/contrib/slim/python/slim/data/parallel_reader.py", line 316, in get_data_files
raise ValueError('No data files found in %s' % (data_sources,))
ValueError: No data files found in /app/transformer/../data/t2t_data/chat_movie_30/chat_line_problem-dev

The error says that the 'dev' files don't exist but in fact they do. I don't know how to fix this. In the printout above the error shows up in the test phase. I can run the t2t code in the train phase and the code will save the data files but later when it comes time to read them it crashes again.

Error executing pip3 install tensorflow-2.8.0-cp39-none-linux_aarch64.whl

Issue Type

Support

OS

Other

OS architecture

aarch64

Hardware

Other

Description

I have the problem described in the title when I execute pip3 install tensorflow-2.8.0-cp39-none-linux_aarch64.whl on my Raspberry Pi 3B+, with these specs:

PRETTY_NAME="Debian GNU/Linux 11 (bullseye)"
NAME="Debian GNU/Linux"
VERSION_ID="11"
VERSION="11 (bullseye)"
VERSION_CODENAME=bullseye
ID=debian
I get this message in bash:
ERROR: Wheel 'tensorflow' located at /home/pi/Project/tensorflow-2.8.0-cp39-none-linux_aarch64.whl is invalid.

Please help.

I installed these dependencies in a virtualenv before executing the command:
pip list

Package             Version
------------------- -------
h5py                3.1.0
Keras-Applications  1.0.8
Keras-Preprocessing 1.1.0
mock                4.0.3
numpy               1.22.1
pip                 22.0.3
pybind11            2.9.1
setuptools          60.6.0
six                 1.16.0
wheel               0.37.1

Relevant Log Output

$ pip3 install tensorflow-2.8.0-cp39-none-linux_aarch64.whl

ERROR: Wheel 'tensorflow' located at /home/pi/Project/tensorflow-2.8.0-cp39-none-linux_aarch64.whl is invalid.

THANK YOU!!!

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):
Pi Zero W

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
ARMV6?

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Buster

[Required] Details of the work you did before the problem occurred:
I looked around the internet for 2 hours trying to figure out how to get TFLite onto a Pi. Every tutorial involved complicated, time-consuming cross-compilation. Thank you for making it easier.
[Required] Error message:
None
[Required] Overview of problems and questions:
It looks like you already did the job of cross-compiling. If you don't mind, in beginner-speak, could you explain how to use and install these things onto a Pi4b or a Pi Zero? Thank you!

Error after installing tensorflow

NVIDIA-Drive PX2

[Required] Details of the work you did before the problem occurred:

Installed tensorflow, pip list |grep tensorflow shows the correct version

I set the environment variable LD_LIBRARY_PATH to /usr/local/cuda-9.2/lib64


[Required] Error message:

ImportError: libcublas.so.9.0: cannot open shared object file: No such file or directory




[Required] Overview of problems and questions:





Using tensorflow-2.5.0-cp37-none-linux_armv7l_numpy1200_download.sh downloads .whl with numpy~=1.19.2 in requirements

[Device] Your device (RaspberryPi3, LaptopPC, or other device name):
RaspberryPi4

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Raspbian GNU/Linux 10 (buster)

[Required] Details of the work you did before the problem occurred:
I followed the README section "Example of Python 3.x + Tensorflow v2 series", except that I worked in a virtual env and did not use sudo with the pip3 commands.

[Required] Error message:

  ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.5.0 requires numpy~=1.19.2, but you have numpy 1.20.3 which is incompatible.

[Required] Overview of problems and questions:
Following the above section of the README leaves the user with a broken tensorflow installation, because the .whl file replaces the manually installed numpy 1.20.3 with numpy 1.19.5.

Installing collected packages: numpy, tensorflow
  Attempting uninstall: numpy
    Found existing installation: numpy 1.20.3
    Uninstalling numpy-1.20.3:
      Successfully uninstalled numpy-1.20.3
Successfully installed numpy-1.19.5 tensorflow-2.5.0

To fix this, the user has to reinstall numpy, which produces the misleading pip error message.

$ pip3 install install numpy==1.20.3
Looking in indexes: https://pypi.org/simple, https://www.piwheels.org/simple
Collecting install
  Downloading https://www.piwheels.org/simple/install/install-1.3.4-py3-none-any.whl (3.1 kB)
Collecting numpy==1.20.3
  Using cached https://www.piwheels.org/simple/numpy/numpy-1.20.3-cp37-cp37m-linux_armv7l.whl (11.6 MB)
Installing collected packages: numpy, install
  Attempting uninstall: numpy
    Found existing installation: numpy 1.19.5
    Uninstalling numpy-1.19.5:
      Successfully uninstalled numpy-1.19.5
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.5.0 requires numpy~=1.19.2, but you have numpy 1.20.3 which is incompatible.
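The `~=1.19.2` in the message is a "compatible release" specifier: it means at least 1.19.2 and still within the 1.19 series. The sketch below reproduces that rule for plain numeric versions only, to show why 1.19.5 passes and 1.20.3 is reported as incompatible. The function names are hypothetical, and real tools handle many more version forms (pre-releases, post-releases, and so on).

```python
def parse(version):
    """Turn a plain dotted version such as '1.19.2' into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def satisfies_compatible_release(installed, spec):
    """Return True if `installed` satisfies `~=spec` for plain numeric versions.

    `~=1.19.2` is read as: >= 1.19.2 and the leading components equal 1.19.
    """
    spec_parts = parse(spec)
    installed_parts = parse(installed)
    prefix = spec_parts[:-1]  # every component except the last is pinned
    return (
        installed_parts >= spec_parts
        and installed_parts[: len(prefix)] == prefix
    )


if __name__ == "__main__":
    for candidate in ("1.19.5", "1.20.3"):
        print(candidate, satisfies_compatible_release(candidate, "1.19.2"))
```

Whether tensorflow 2.5.0 actually breaks with numpy 1.20.3 is a separate question; the check only explains why pip prints the warning.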

Question on performance

Device: RaspberryPi3+:

CPU architecture: armv7l

OS: Raspbian stretch

Conditions
I installed TensorFlow Lite for Python 3.5, multi-threaded (this wheel: tensorflow-1.14.0-cp35-cp35m-linux_armv7l.whl).
I ran the test as documented in the README.

Question
I cannot get better timing than 0.27 seconds with num_threads=4 while the readme reports a much faster 0.16.
Any hints?
Thanks!

Raspberry pi -> install tensorflow -> error

pi@raspberrypi:~/tensorflow $ ./configure
WARNING: Running Bazel server needs to be killed, because the startup options are different.
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.19.2- (@non-git) installed.
Please upgrade your bazel installation to version 0.24.1 or higher to build TensorFlow!

Python 3.9 build

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): Raspberry Pi 4B

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): aarch64

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): Debian 11 Bullseye

[Required] Details of the work you did before the problem occurred: There is no build for python 3.9

[Required] Overview of problems and questions:

Ubuntu Hirsute
https://packages.ubuntu.com/hirsute/python3
and Debian 11 Bullseye
https://packages.debian.org/ja/bullseye/python3
now have python3.9 as the default python. Is there any plan to build binaries for python 3.9?
TF 2.4 and earlier do not support Python 3.9. Only 2.5 supports it.

TF Lite problem in tensorflow-2.3.0-cp38-none-linux_aarch64

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):

Raspberry Pi 4B 8GB running Ubuntu Mate 20.04 64bit

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):

AArch64 (ARM 64bit)

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):

Ubuntu Mate 20.04 64bit

[Required] Details of the work you did before the problem occurred:

My own program (based on Google's example) at

https://qiita.com/kakinaguru_zo/items/c875ca7452c30a22289d#tensorflow-lite%E3%81%AE%E3%82%B5%E3%83%B3%E3%83%97%E3%83%AB%E3%81%AE%E6%94%B9%E9%80%A0%E7%89%88

works fine on both

But it fails with TensorFlow 2.3 from here on Ubuntu 20.04 arm64.
TensorFlow 2.3 from here on Ubuntu 20.04 can perfectly execute examples on https://qiita.com/karaage0703/items/8c3197d11f61812546a9 which do not use TF Lite.

[Required] Error message:

emojifreak@raspi-mate:~/tensorflow-lite$ python3 detect_usbcamera.py --model detect_usbcamera.py --labels coco_labels.txt
Traceback (most recent call last):
File "detect_usbcamera.py", line 158, in <module>
main()
File "detect_usbcamera.py", line 124, in main
interpreter = tflite.Interpreter(args.model)
File "/usr/local/lib/python3.8/dist-packages/tensorflow/lite/python/interpreter.py", line 197, in __init__
_interpreter_wrapper.CreateWrapperFromFile(
ValueError: Model provided has model identifier 'thon', should be 'TFL3'

[Required] Overview of problems and questions:

The TFLite included in tensorflow-2.3.0-cp38-none-linux_aarch64 seems somewhat different from Google's version and lhelontra's version.
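A hedged observation on the log above: the command passes `detect_usbcamera.py` as the `--model` argument, and the reported identifier 'thon' is consistent with the interpreter reading a text file rather than a model. TFLite models are FlatBuffers files whose 4-byte file identifier, `TFL3`, sits at byte offset 4. The sketch below reads that identifier so a path can be checked before it is handed to the interpreter; the function names are hypothetical.

```python
def flatbuffer_identifier(path):
    """Return bytes 4..8 of a file, where FlatBuffers stores its file identifier."""
    with open(path, "rb") as handle:
        header = handle.read(8)
    return header[4:8]


def looks_like_tflite(path):
    """Return True if the file carries the 'TFL3' identifier of a TFLite model."""
    return flatbuffer_identifier(path) == b"TFL3"


if __name__ == "__main__":
    import sys

    for candidate in sys.argv[1:]:
        print(candidate, flatbuffer_identifier(candidate), looks_like_tflite(candidate))
```

Passing both the script and the intended `.tflite` file on the command line shows which of the two carries the expected identifier.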

Request - TF1.15

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):

[Required] Details of the work you did before the problem occurred:






[Required] Error message:






[Required] Overview of problems and questions:

Hello, please note this is not an issue. It is rather a request for a new piwheel package for TensorFlow 1.15.
Could you provide a pre-built package for the nightly build of TF?
Thanks

This is important: when it comes to deploying deep learning models on edgetpu devices, you need tf2 or tf1.15. The issue with tf2 is that the tflite_runtime package is only tested with tf1.15, and there are a lot of issues with tf2.


Permission denied

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): RaspberryPi3

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): Raspbian

[Required] Details of the work you did before the problem occurred: install tensorflow

[Required] Error message: bash: ./tensorflow-2.2.0-cp37-cp37m-linux_armv7l_download.sh: Permission denied

[Required] Overview of problems and questions: It's not working.
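When bash reports "Permission denied" for `./script.sh`, the file usually lacks the execute bit; the common shell remedies are `chmod +x` on the script or running it through `bash` explicitly. As a minimal sketch of the same idea in Python (the helper name is hypothetical, and this assumes the missing execute bit is really the cause here):

```python
import os
import stat


def make_executable(path):
    """Add the execute bit for every class that already has read permission."""
    mode = os.stat(path).st_mode
    # Copy each read bit (r = 4) down to the matching execute bit (x = 1).
    mode |= (mode & 0o444) >> 2
    os.chmod(path, mode)
    # Return the new mode in `ls -l` style for quick inspection.
    return stat.filemode(os.stat(path).st_mode)


if __name__ == "__main__":
    import sys

    for script in sys.argv[1:]:
        print(script, make_executable(script))
```

If the script still cannot be started afterwards, the filesystem may be mounted with `noexec`, in which case running it via `bash` is the simpler route.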

ImportError: cannot import name 'cloud'

Thank you for your great work.
I followed your instruction and installed tensorflow-2.0.0a0-cp35-cp35m-linux_armv7l.whl.
Then I ran the command below.

python3 label_image.py --num_threads 4 --image grace_hopper.bmp --model_file mobilenet_v1_1.0_224_quant.tflite --label_file labels.txt

I got the error below.

Traceback (most recent call last):
File "label_image.py", line 7, in <module>
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/__init__.py", line 31, in <module>
from tensorflow.contrib import cloud
ImportError: cannot import name 'cloud'

I use a RaspberryPi Model B+ and the Raspbian 2018-11-13-raspbian-stretch image.

Your kind advice would be very helpful.

Question: will XNNPACK compile flag be set in future builds?

The TensorFlow/Google Research team just released the following blog post: https://blog.tensorflow.org/2020/07/accelerating-tensorflow-lite-xnnpack-integration.html, outlining a significant performance improvement with the XNNPACK library. For Linux, this would require setting an additional flag during the Bazel build: --define tflite_with_xnnpack=true

Is this something that will be considered to be included in future versions of your TF builds for the raspberry pi?

BTW - I'd like to thank you for your work to publish the TF builds for rpi!

Thanks,
Filip

Found existing installation: wrapt 1.10.11 (when installing tensorflow 2.5)

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): Raspberry Pi3

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): i don't know

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): Raspbian

[Required] Details of the work you did before the problem occurred: Installing Tensorflow Lite by following the instructions.

[Required] Error message: Installing collected packages: wrapt, termcolor, tensorflow-estimator, tensorboard, opt-einsum, keras-preprocessing, keras-nightly, h5py, google-pasta, gast, flatbuffers, astunparse, tensorflow
Attempting uninstall: wrapt
Found existing installation: wrapt 1.10.11
ERROR: Cannot uninstall 'wrapt'. It is a distutils installed project and thus we cannot accurately determine which files belong to it which would lead to only a partial uninstall.

[Required] Overview of problems and questions: How do I solve this? I still can't install the tensorflow library by following these steps.

ARMv6 support missing

The installation steps for using tf under Buster, as outlined in the documentation, do not work. Any plans to extend this invaluable endeavor to the following?

  • ARMv6-compatible processor rev 7 (v6l)
  • Linux 4.19.66+ armv6l GNU/Linux (Buster)

Kind regards.

Is it really multi-threaded?

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):
Raspberry Pi 3B+

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Raspbian Stretch

[Required] Details of the work you did before the problem occurred:
I just followed the instructions you mentioned step-by-step

[Required] Error message:
There is no error message.

[Required] Overview of problems and questions:
I'm using htop to see how many cores are used. Setting --num_threads=1 works: it uses a single core. But setting --num_threads=4 doesn't show 4 cores in use; again only a single core is used.

Any chance to get aarch64 for 1.13.1 too?

[Required] RaspberryPi3 / Jetson Nano

[Required] armv7l and aarch64

[Required] Raspbian, Ubuntu Tegra

[Required] Overview of problems and questions:

I have a current binding against 1.13.1. I see there is only support for armv7l, but I would love to have aarch64 too. I see 1.14.1 does have support for both, but using it would mean upgrading the binding I am using.

Any chance for this to happen?

Do I need 3.7.3?

[Required] Your device (RaspberryPi3, LaptopPC, or other device name):
RaspberryPi4

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name):
armv7

[Required] Your OS (Raspbian, Ubuntu1604, or other os name):
Raspbian Buster

[Required] Details of the work you did before the problem occurred:
Expected wheel to install

[Required] Error message:
(serval) pi@serval-001:~ $ pip3 install tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl
ERROR: tensorflow-1.15.0-cp37-cp37m-linux_armv7l.whl is not a supported wheel on this platform.
(serval) pi@serval-001:~ $ python3
Python 3.7.7 (default, Apr 1 2020, 10:55:29)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.

[Required] Overview of problems and questions:
On Buster the current version of python3 is 3.8.x.
I installed python3.7 side by side and set up a virtual env with python3.7.7.

Do I specifically need 3.7.3 for this wheel to install, or am I missing something else here?

Please provide FlexDelegate fix for armv7l in tf 2.3

[Required] Your device (RaspberryPi3, LaptopPC, or other device name): RPI 4

[Required] Your device's CPU architecture (armv7l, x86_64, or other architecture name): armv7l

[Required] Your OS (Raspbian, Ubuntu1604, or other os name): Raspbian

[Required] Details of the work you did before the problem occurred: Tried the 2.2 armv7l wheel, but the FlexDelegate fix is still not included.

[Required] Error message: This is what I get when trying to load a tflite model with the 2.2 armv7l wheel: RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you apply/link the Flex delegate before inference.Node number 0 (FlexConv3D) failed to prepare.

[Required] Overview of problems and questions: I've gotten this working with tf-nightly 2.3 on my desktop but need a rpi4 friendly whl. I assume you are probably planning to do this given your arch wheels but just thought I'd post the issue anyways.

GLIBC_2.28' not found

device name: RaspberryPi 4

architecture name: aarch64

OS: Ubuntu 18.04

Python version: 3.7.5

Details of the work you did before the problem occurred: I tried different versions of tensorflow that you provided (very much appreciated). When I try to import tensorflow, the error shows up. Also, the tflite_runtime module is not installed this way, so I tried installing it from here. But the problem is with the set_num_threads method: in TensorFlow 2.x the Interpreter has no such method, so I cannot fully utilize the CPU.

Error message:

ImportError: /lib/aarch64-linux-gnu/libc.so.6: version `GLIBC_2.28' not found (required by /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/_pywrap_tensorflow_internal.so)

AttributeError: 'Interpreter' object has no attribute 'set_num_threads'

My GLIBC version is 2.27 from ldd --version. I searched about upgrading it but I found it is risky and it may corrupt my other binary files.
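The import error states which glibc symbol version the binary needs, so the mismatch can be checked before installing a wheel. The sketch below uses the standard library's `platform.libc_ver()` and compares versions as integer tuples (so that 2.9 sorts below 2.28). The function names are hypothetical, and the required version has to be taken from the error message or from the README's glibc column.

```python
import platform


def version_tuple(text):
    """Convert a dotted version such as '2.27' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def glibc_is_at_least(required, found=None):
    """Return True if the C library version meets `required`.

    `found` defaults to the glibc version detected for the running Python.
    """
    if found is None:
        lib, found = platform.libc_ver()
        if lib != "glibc" or not found:
            return False  # not glibc, or the version could not be determined
    return version_tuple(found) >= version_tuple(required)


if __name__ == "__main__":
    print(platform.libc_ver(), glibc_is_at_least("2.28"))
```

On a system with glibc 2.27, as described here, the check for "2.28" returns False, which matches the `GLIBC_2.28' not found` error.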

Note: I reached 10fps on raspberry pi3 on ubuntu 19.04 but the same image of it would not boot on raspberry pi4 so I tried ubuntu 18.04 on raspberry 4 instead but I now have this problem and could only get less than 6 fps. I also could reach 10fps on this raspberry 4 using python 2.7.

So, I think my problem could be solved in 3 ways:

  1. I get a version of tensorflow which does not need GLIBC 2.28 on aarch64.

  2. I get a version of tflite_runtime which has a set_num_threads method.

  3. I find out how to fully utilize the CPU cores using tensorflow 2.x.

It would be very frustrating for me to change the OS again and install everything again.

Could you help me with any of these or give me solutions?

I'm also curious why Tensorflow removed the set_num_threads method in version 2.x. Do you have any idea why?

Thanks for the great work.
