attention-sampling's Issues

MNIST noise overlaps signal

Hi,

When I create a dataset with make_mnist.py at size 500x500, I noticed that the noise sometimes overlaps the digits.

This can be resolved by changing the order in which signal and noise are inserted, so that the noise is drawn first and the digits are drawn on top of it. For example, for the high-resolution image:

def high(self):
    ...
    # Draw the noise patterns first...
    if self._dataset._should_add_noise:
        for p, i in self.noise_positions_and_patterns:
            high[self._get_slice(p)] = \
                255*self._dataset._noise[i]

    # ...then draw the digits on top, so noise can never cover them
    for p, i in zip(self._positions, self._idxs):
        high[self._get_slice(p)] = \
            255*self._dataset._images[i]

instead of

def high(self):
    ...
    for p, i in zip(self._positions, self._idxs):
        high[self._get_slice(p)] = \
            255*self._dataset._images[i]

    if self._dataset._should_add_noise:
        for p, i in self.noise_positions_and_patterns:
            high[self._get_slice(p)] = \
                255*self._dataset._noise[i]

Let me know in case I missed something.
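Because each assignment overwrites whatever pixels are already in the slice, whichever layer is drawn last wins. A minimal, self-contained sketch of that effect using plain NumPy arrays in place of the dataset classes (the helpers paste and render, the 10x10 canvas and the constant-valued patches are all hypothetical stand-ins, not code from make_mnist.py):

```python
import numpy as np


def paste(canvas, top_left, patch):
    """Write `patch` into `canvas` at `top_left`, overwriting existing pixels."""
    y, x = top_left
    h, w = patch.shape
    canvas[y:y + h, x:x + w] = patch


def render(digit_pos, noise_pos, digit, noise, noise_first):
    """Draw one digit and one noise patch in the requested order."""
    canvas = np.zeros((10, 10), dtype=np.uint8)
    layers = [(noise_pos, noise), (digit_pos, digit)]
    if not noise_first:
        layers.reverse()  # original order: digit first, noise last
    for pos, patch in layers:
        paste(canvas, pos, patch)
    return canvas


digit = np.full((4, 4), 255, dtype=np.uint8)  # stand-in for a digit
noise = np.full((4, 4), 128, dtype=np.uint8)  # stand-in for a noise pattern

# Overlapping placements: the noise starts 2 pixels to the right of the digit.
old = render((2, 2), (2, 4), digit, noise, noise_first=False)
new = render((2, 2), (2, 4), digit, noise, noise_first=True)

print((old[2:6, 2:6] == 255).all())  # False: noise covers part of the digit
print((new[2:6, 2:6] == 255).all())  # True: the digit is fully visible
```

Drawing the noise first does not prevent overlap; it only guarantees that where overlap happens, the digit stays intact.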

What is the role of "receptive field"?

Hi guys,

I am working with your algorithm and am wondering what role the "receptive field" plays in your code. Why do we need it to shift the sampling offset? Could you share your intuition for tuning this parameter? Thanks.
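As general background only (this thread does not answer the question): in a CNN, each output unit sees a window of the input, its receptive field, and with unpadded layers the centre of that window is offset from the output index. The sketch uses the standard kernel/stride/padding recurrence; it is not taken from the ats code, the layer list is hypothetical, and it shows one plausible reason an offset is needed rather than the library's actual logic.

```python
def receptive_field(layers):
    """Compute receptive-field geometry for a stack of conv/pool layers.

    layers: iterable of (kernel, stride, padding) tuples, input to output.
    Returns (r, j, start):
      r: receptive field size, in input pixels
      j: jump, the input-pixel distance between adjacent output units
      start: centre of output unit 0 in continuous input coordinates,
             where input pixel n spans [n, n + 1)
    """
    r, j, start = 1, 1, 0.5  # a single input pixel, centred at 0.5
    for k, s, p in layers:
        # All three updates use the jump of this layer's input.
        start = start + ((k - 1) / 2 - p) * j
        r = r + (k - 1) * j
        j = j * s
    return r, j, start


def unit_center(i, j, start):
    """Centre, in input coordinates, of output unit i."""
    return start + i * j


# Hypothetical network: three unpadded 3x3 layers, the middle one strided.
layers = [(3, 1, 0), (3, 2, 0), (3, 1, 0)]
r, j, start = receptive_field(layers)
print(r, j, start)               # 9 2 4.5
print(unit_center(3, j, start))  # 10.5
```

In this example, output unit i summarizes input pixels centred at 4.5 + 2*i, not at 2*i, so any routine that maps a location in the feature map back to image coordinates has to add that offset.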

pip install runtime error: Couldn't compile and install ats.ops.extract_patches.libpatches

On Ubuntu 19.04
Python 3.7.3

pip3 install attention-sampling
Collecting attention-sampling
  Using cached https://files.pythonhosted.org/packages/ff/2d/4474d1f516865eb83419c67045d885adc03266f97858fa8eeb14117c709d/attention-sampling-0.2.tar.gz
Requirement already satisfied: keras>=2 in /home/jaiczay/.local/lib/python3.7/site-packages (from attention-sampling) (2.2.4)
Requirement already satisfied: numpy in /home/jaiczay/.local/lib/python3.7/site-packages (from attention-sampling) (1.16.4)
Requirement already satisfied: pyyaml in /usr/lib/python3/dist-packages (from keras>=2->attention-sampling) (3.13)
Requirement already satisfied: six>=1.9.0 in /usr/lib/python3/dist-packages (from keras>=2->attention-sampling) (1.12.0)
Requirement already satisfied: h5py in /home/jaiczay/.local/lib/python3.7/site-packages (from keras>=2->attention-sampling) (2.9.0)
Requirement already satisfied: keras-preprocessing>=1.0.5 in /home/jaiczay/.local/lib/python3.7/site-packages (from keras>=2->attention-sampling) (1.1.0)
Requirement already satisfied: scipy>=0.14 in /home/jaiczay/.local/lib/python3.7/site-packages (from keras>=2->attention-sampling) (1.3.1)
Requirement already satisfied: keras-applications>=1.0.6 in /home/jaiczay/.local/lib/python3.7/site-packages (from keras>=2->attention-sampling) (1.0.7)
Building wheels for collected packages: attention-sampling
  Running setup.py bdist_wheel for attention-sampling ... error
  Complete output from command /usr/bin/python3 -u -c "import setuptools, tokenize;__file__='/tmp/pip-install-pbfqhw21/attention-sampling/setup.py';f=getattr(tokenize, 'open', open)(__file__);code=f.read().replace('\r\n', '\n');f.close();exec(compile(code, __file__, 'exec'))" bdist_wheel -d /tmp/pip-wheel-axu0ud85 --python-tag cp37:
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-3.7
  creating build/lib.linux-x86_64-3.7/ats
  copying ats/__init__.py -> build/lib.linux-x86_64-3.7/ats
  creating build/lib.linux-x86_64-3.7/ats/ops
  copying ats/ops/__init__.py -> build/lib.linux-x86_64-3.7/ats/ops
  creating build/lib.linux-x86_64-3.7/ats/utils
  copying ats/utils/layers.py -> build/lib.linux-x86_64-3.7/ats/utils
  copying ats/utils/regularizers.py -> build/lib.linux-x86_64-3.7/ats/utils
  copying ats/utils/training.py -> build/lib.linux-x86_64-3.7/ats/utils
  copying ats/utils/__init__.py -> build/lib.linux-x86_64-3.7/ats/utils
  creating build/lib.linux-x86_64-3.7/ats/core
  copying ats/core/sampling.py -> build/lib.linux-x86_64-3.7/ats/core
  copying ats/core/builder.py -> build/lib.linux-x86_64-3.7/ats/core
  copying ats/core/expectation.py -> build/lib.linux-x86_64-3.7/ats/core
  copying ats/core/ats_layer.py -> build/lib.linux-x86_64-3.7/ats/core
  copying ats/core/__init__.py -> build/lib.linux-x86_64-3.7/ats/core
  creating build/lib.linux-x86_64-3.7/ats/data
  copying ats/data/from_tensors.py -> build/lib.linux-x86_64-3.7/ats/data
  copying ats/data/base.py -> build/lib.linux-x86_64-3.7/ats/data
  copying ats/data/__init__.py -> build/lib.linux-x86_64-3.7/ats/data
  creating build/lib.linux-x86_64-3.7/ats/ops/extract_patches
  copying ats/ops/extract_patches/__init__.py -> build/lib.linux-x86_64-3.7/ats/ops/extract_patches
  running build_ext
  Building ats.ops.extract_patches.libpatches
  -- The C compiler identification is GNU 8.3.0
  -- The CXX compiler identification is GNU 8.3.0
  -- Check for working C compiler: /usr/bin/cc
  -- Check for working C compiler: /usr/bin/cc -- works
  -- Detecting C compiler ABI info
  -- Detecting C compiler ABI info - done
  -- Detecting C compile features
  -- Detecting C compile features - done
  -- Check for working CXX compiler: /usr/bin/c++
  -- Check for working CXX compiler: /usr/bin/c++ -- works
  -- Detecting CXX compiler ABI info
  -- Detecting CXX compiler ABI info - done
  -- Detecting CXX compile features
  -- Detecting CXX compile features - done
  CUDA_TOOLKIT_ROOT_DIR not found or specified
  -- Could NOT find CUDA (missing: CUDA_TOOLKIT_ROOT_DIR CUDA_NVCC_EXECUTABLE CUDA_INCLUDE_DIRS CUDA_CUDART_LIBRARY)
    File "<string>", line 1
      import tensorflow as tf;         print(tf.sysconfig.get_include(), end='')
                                                                            ^
  SyntaxError: invalid syntax
  -- Found TensorFlow include:
    File "<string>", line 1
      import tensorflow as tf;         print(tf.sysconfig.get_lib(), end='')
                                                                        ^
  SyntaxError: invalid syntax
    File "<string>", line 1
      import tensorflow as tf;         lib = next(f for f in tf.sysconfig.get_link_flags()                    if f.startswith('-l'));         print(lib[2:] if lib[2] != ':' else lib[3:], end='')
                                                                                                                                                                                              ^
  SyntaxError: invalid syntax
  -- Found TensorFlow lib: TENSORFLOW_LIB-NOTFOUND
    File "<string>", line 1
      import tensorflow as tf;                 print(tf.__version__, end='')
                                                                        ^
  SyntaxError: invalid syntax
  CMake Warning at /tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/cmake/FindTensorFlow.cmake:60 (message):
    You are using an unsupported compiler! It is recommended to build your
    TensorFlow extensions with gcc 4.8 .
  Call Stack (most recent call first):
    CMakeLists.txt:14 (find_package)
  
  
    File "<string>", line 1
      import tensorflow as tf;         print(' '.join(             f for f in tf.sysconfig.get_compile_flags()             if not f.startswith('-I')), end='')
                                                                                                                                                          ^
  SyntaxError: invalid syntax
  -- Added TensorFlow flags:
  CMake Error at /usr/share/cmake-3.13/Modules/FindPackageHandleStandardArgs.cmake:137 (message):
    Could NOT find TensorFlow (missing: TENSORFLOW_INCLUDE_DIR TENSORFLOW_LIB)
  Call Stack (most recent call first):
    /usr/share/cmake-3.13/Modules/FindPackageHandleStandardArgs.cmake:378 (_FPHSA_FAILURE_MESSAGE)
    /tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/cmake/FindTensorFlow.cmake:82 (find_package_handle_standard_args)
    CMakeLists.txt:14 (find_package)
  
  
  -- Configuring incomplete, errors occurred!
  See also "/tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/extract_patches/build/CMakeFiles/CMakeOutput.log".
  Traceback (most recent call last):
    File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 33, in compile
      cwd=extension_dir
    File "/usr/lib/python3.7/subprocess.py", line 347, in check_call
      raise CalledProcessError(retcode, cmd)
  subprocess.CalledProcessError: Command '['cmake', '-DCMAKE_BUILD_TYPE=Release', '..']' returned non-zero exit status 1.
  
  The above exception was the direct cause of the following exception:
  
  Traceback (most recent call last):
    File "<string>", line 1, in <module>
    File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 144, in <module>
      setup_package()
    File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 139, in setup_package
      cmdclass={"build_ext": custom_build_ext}
    File "/home/jaiczay/.local/lib/python3.7/site-packages/setuptools/__init__.py", line 145, in setup
      return distutils.core.setup(**attrs)
    File "/usr/lib/python3.7/distutils/core.py", line 148, in setup
      dist.run_commands()
    File "/usr/lib/python3.7/distutils/dist.py", line 966, in run_commands
      self.run_command(cmd)
    File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/lib/python3/dist-packages/wheel/bdist_wheel.py", line 188, in run
      self.run_command('build')
    File "/usr/lib/python3.7/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/usr/lib/python3.7/distutils/command/build.py", line 135, in run
      self.run_command(cmd_name)
    File "/usr/lib/python3.7/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
      cmd_obj.run()
    File "/home/jaiczay/.local/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 78, in run
      _build_ext.run(self)
    File "/usr/lib/python3.7/distutils/command/build_ext.py", line 340, in run
      self.build_extensions()
    File "/usr/lib/python3.7/distutils/command/build_ext.py", line 449, in build_extensions
      self._build_extensions_serial()
    File "/usr/lib/python3.7/distutils/command/build_ext.py", line 474, in _build_extensions_serial
      self.build_extension(ext)
    File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 52, in build_extension
      ext.compile()
    File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 46, in compile
      ) from e
  RuntimeError: Couldn't compile and install ats.ops.extract_patches.libpatches
  
  ----------------------------------------
  Failed building wheel for attention-sampling
  Running setup.py clean for attention-sampling
Failed to build attention-sampling
Installing collected packages: attention-sampling
  Running setup.py install for attention-sampling ... error
    Complete output from command /usr/bin/python3 -u -c "import setuptools, tokenize;__file__='/tmp/pip-install-pbfqhw21/attention-sampling/setup.py';f=getattr(tokenize, 'open', open)(__file__);code=f.read().replace('\r\n', '\n');f.close();exec(compile(code, __file__, 'exec'))" install --record /tmp/pip-record-n262zdr0/install-record.txt --single-version-externally-managed --compile --user --prefix=:
    running install
    running build
    running build_py
    creating build
    creating build/lib.linux-x86_64-3.7
    creating build/lib.linux-x86_64-3.7/ats
    copying ats/__init__.py -> build/lib.linux-x86_64-3.7/ats
    creating build/lib.linux-x86_64-3.7/ats/ops
    copying ats/ops/__init__.py -> build/lib.linux-x86_64-3.7/ats/ops
    creating build/lib.linux-x86_64-3.7/ats/utils
    copying ats/utils/layers.py -> build/lib.linux-x86_64-3.7/ats/utils
    copying ats/utils/regularizers.py -> build/lib.linux-x86_64-3.7/ats/utils
    copying ats/utils/training.py -> build/lib.linux-x86_64-3.7/ats/utils
    copying ats/utils/__init__.py -> build/lib.linux-x86_64-3.7/ats/utils
    creating build/lib.linux-x86_64-3.7/ats/core
    copying ats/core/sampling.py -> build/lib.linux-x86_64-3.7/ats/core
    copying ats/core/builder.py -> build/lib.linux-x86_64-3.7/ats/core
    copying ats/core/expectation.py -> build/lib.linux-x86_64-3.7/ats/core
    copying ats/core/ats_layer.py -> build/lib.linux-x86_64-3.7/ats/core
    copying ats/core/__init__.py -> build/lib.linux-x86_64-3.7/ats/core
    creating build/lib.linux-x86_64-3.7/ats/data
    copying ats/data/from_tensors.py -> build/lib.linux-x86_64-3.7/ats/data
    copying ats/data/base.py -> build/lib.linux-x86_64-3.7/ats/data
    copying ats/data/__init__.py -> build/lib.linux-x86_64-3.7/ats/data
    creating build/lib.linux-x86_64-3.7/ats/ops/extract_patches
    copying ats/ops/extract_patches/__init__.py -> build/lib.linux-x86_64-3.7/ats/ops/extract_patches
    running build_ext
    Building ats.ops.extract_patches.libpatches
    CUDA_TOOLKIT_ROOT_DIR not found or specified
    -- Could NOT find CUDA (missing: CUDA_TOOLKIT_ROOT_DIR CUDA_NVCC_EXECUTABLE CUDA_INCLUDE_DIRS CUDA_CUDART_LIBRARY)
      File "<string>", line 1
        import tensorflow as tf;         print(tf.sysconfig.get_include(), end='')
                                                                              ^
    SyntaxError: invalid syntax
    -- Found TensorFlow include:
      File "<string>", line 1
        import tensorflow as tf;         print(tf.sysconfig.get_lib(), end='')
                                                                          ^
    SyntaxError: invalid syntax
      File "<string>", line 1
        import tensorflow as tf;         lib = next(f for f in tf.sysconfig.get_link_flags()                    if f.startswith('-l'));         print(lib[2:] if lib[2] != ':' else lib[3:], end='')
                                                                                                                                                                                                ^
    SyntaxError: invalid syntax
    -- Found TensorFlow lib: TENSORFLOW_LIB-NOTFOUND
      File "<string>", line 1
        import tensorflow as tf;                 print(tf.__version__, end='')
                                                                          ^
    SyntaxError: invalid syntax
    CMake Warning at /tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/cmake/FindTensorFlow.cmake:60 (message):
      You are using an unsupported compiler! It is recommended to build your
      TensorFlow extensions with gcc 4.8 .
    Call Stack (most recent call first):
      CMakeLists.txt:14 (find_package)
    
    
      File "<string>", line 1
        import tensorflow as tf;         print(' '.join(             f for f in tf.sysconfig.get_compile_flags()             if not f.startswith('-I')), end='')
                                                                                                                                                            ^
    SyntaxError: invalid syntax
    -- Added TensorFlow flags:
    CMake Error at /usr/share/cmake-3.13/Modules/FindPackageHandleStandardArgs.cmake:137 (message):
      Could NOT find TensorFlow (missing: TENSORFLOW_INCLUDE_DIR TENSORFLOW_LIB)
    Call Stack (most recent call first):
      /usr/share/cmake-3.13/Modules/FindPackageHandleStandardArgs.cmake:378 (_FPHSA_FAILURE_MESSAGE)
      /tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/cmake/FindTensorFlow.cmake:82 (find_package_handle_standard_args)
      CMakeLists.txt:14 (find_package)
    
    
    -- Configuring incomplete, errors occurred!
    See also "/tmp/pip-install-pbfqhw21/attention-sampling/ats/ops/extract_patches/build/CMakeFiles/CMakeOutput.log".
    Traceback (most recent call last):
      File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 33, in compile
        cwd=extension_dir
      File "/usr/lib/python3.7/subprocess.py", line 347, in check_call
        raise CalledProcessError(retcode, cmd)
    subprocess.CalledProcessError: Command '['cmake', '-DCMAKE_BUILD_TYPE=Release', '..']' returned non-zero exit status 1.
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 144, in <module>
        setup_package()
      File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 139, in setup_package
        cmdclass={"build_ext": custom_build_ext}
      File "/home/jaiczay/.local/lib/python3.7/site-packages/setuptools/__init__.py", line 145, in setup
        return distutils.core.setup(**attrs)
      File "/usr/lib/python3.7/distutils/core.py", line 148, in setup
        dist.run_commands()
      File "/usr/lib/python3.7/distutils/dist.py", line 966, in run_commands
        self.run_command(cmd)
      File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
        cmd_obj.run()
      File "/home/jaiczay/.local/lib/python3.7/site-packages/setuptools/command/install.py", line 61, in run
        return orig.install.run(self)
      File "/usr/lib/python3.7/distutils/command/install.py", line 589, in run
        self.run_command('build')
      File "/usr/lib/python3.7/distutils/cmd.py", line 313, in run_command
        self.distribution.run_command(command)
      File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
        cmd_obj.run()
      File "/usr/lib/python3.7/distutils/command/build.py", line 135, in run
        self.run_command(cmd_name)
      File "/usr/lib/python3.7/distutils/cmd.py", line 313, in run_command
        self.distribution.run_command(command)
      File "/usr/lib/python3.7/distutils/dist.py", line 985, in run_command
        cmd_obj.run()
      File "/home/jaiczay/.local/lib/python3.7/site-packages/setuptools/command/build_ext.py", line 78, in run
        _build_ext.run(self)
      File "/usr/lib/python3.7/distutils/command/build_ext.py", line 340, in run
        self.build_extensions()
      File "/usr/lib/python3.7/distutils/command/build_ext.py", line 449, in build_extensions
        self._build_extensions_serial()
      File "/usr/lib/python3.7/distutils/command/build_ext.py", line 474, in _build_extensions_serial
        self.build_extension(ext)
      File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 52, in build_extension
        ext.compile()
      File "/tmp/pip-install-pbfqhw21/attention-sampling/setup.py", line 46, in compile
        ) from e
    RuntimeError: Couldn't compile and install ats.ops.extract_patches.libpatches
    
    ----------------------------------------
Command "/usr/bin/python3 -u -c "import setuptools, tokenize;__file__='/tmp/pip-install-pbfqhw21/attention-sampling/setup.py';f=getattr(tokenize, 'open', open)(__file__);code=f.read().replace('\r\n', '\n');f.close();exec(compile(code, __file__, 'exec'))" install --record /tmp/pip-record-n262zdr0/install-record.txt --single-version-externally-managed --compile --user --prefix=" failed with error code 1 in /tmp/pip-install-pbfqhw21/attention-sampling/
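Every failing probe in this log is a print(..., end='') call with the caret under end=''. That is how Python 2 rejects Python 3 print syntax, which suggests these TensorFlow queries ran under a Python 2 interpreter rather than the /usr/bin/python3 running pip. Also, tensorflow does not appear among the satisfied requirements, so it may not be installed for python3 at all. A minimal diagnostic sketch, assuming (not confirmed by the log) that the build calls whichever python is first on PATH; the helper name is hypothetical:

```python
import shutil
import subprocess


def python_major_version(cmd="python"):
    """Return the major version of the interpreter that `cmd` resolves to.

    Returns None if `cmd` cannot be found. The probe prints a single
    parenthesized value, which is valid syntax in both Python 2 and 3.
    """
    path = shutil.which(cmd)
    if path is None:
        return None
    result = subprocess.run(
        [path, "-c", "import sys; print(sys.version_info[0])"],
        stdout=subprocess.PIPE,
        universal_newlines=True,
        check=True,
    )
    return int(result.stdout.strip())


if __name__ == "__main__":
    for cmd in ("python", "python3"):
        print(cmd, "->", python_major_version(cmd))
```

If "python" reports 2 on the affected machine, that would be consistent with the SyntaxError lines above.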

CMake Error at patches_generated_extract_patches.cu.o.cmake:207

I'm building your ats repository inside a Docker container based on the latest TensorFlow GPU image, using this Dockerfile:

FROM tensorflow/tensorflow:latest-gpu-py3

WORKDIR /
RUN apt-get update && \
    apt-get install -y git g++ cmake

RUN git clone https://github.com/idiap/attention-sampling attention-sampling

RUN g++ --version && \
    cmake --version && \
    python --version && \
    python -c "import tensorflow; print(tensorflow.__version__)"

WORKDIR /attention-sampling
RUN pip install -e . && \
    cd ats/ops/extract_patches && \
    mkdir build && \
    cd build && \
    cmake -DCMAKE_BUILD_TYPE=Release .. && \
    make && \
    make install

WORKDIR /attention-sampling
RUN python -m unittest discover -s tests/ 

WORKDIR /

In order to build it, I set "default-runtime": "nvidia" in the Docker daemon configuration.
Other version info:

g++ (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
Copyright (C) 2015 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

cmake version 3.5.1

CMake suite maintained and supported by Kitware (kitware.com/cmake).
Python 3.5.2
tensorflow-gpu 1.13.1

CMake configures successfully, but the build fails when make compiles the CUDA source:

-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found CUDA: /usr/local/cuda (found version "10.0") 
-- Found TensorFlow: /usr/local/lib/python3.5/dist-packages/tensorflow/include  
-- Configuring done
-- Generating done
-- Build files have been written to: /attention-sampling/ats/ops/extract_patches/build
[ 33%] Building NVCC (Device) object CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o
In file included from /usr/local/lib/python3.5/dist-packages/tensorflow/include/unsupported/Eigen/CXX11/../../../Eigen/src/Core/util/ConfigureVectorization.h:384:0,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/unsupported/Eigen/CXX11/../../../Eigen/Core:22,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/unsupported/Eigen/CXX11/Tensor:14,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/third_party/eigen3/unsupported/Eigen/CXX11/Tensor:1,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/numeric_types.h:20,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/allocator.h:23,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /attention-sampling/ats/ops/extract_patches/extract_patches.cu:12:
/usr/local/cuda/include/host_defines.h:54:2: warning: #warning "host_defines.h is an internal header file and must not be used directly.  This file will be removed in a future CUDA release.  Please use cuda_runtime_api.h or cuda_runtime.h instead." [-Wcpp]
 #warning "host_defines.h is an internal header file and must not be used directly.  This file will be removed in a future CUDA release.  Please use cuda_runtime_api.h or cuda_runtime.h instead."
  ^
In file included from /usr/include/c++/5/atomic:38:0,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/lib/core/refcount.h:19,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/platform/tensor_coding.h:21,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/resource_handle.h:19,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/allocator.h:24,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /attention-sampling/ats/ops/extract_patches/extract_patches.cu:12:
/usr/include/c++/5/bits/c++0x_warning.h:32:2: error: #error This file requires compiler and library support for the ISO C++ 2011 standard. This support must be enabled with the -std=c++11 or -std=gnu++11 compiler options.
 #error This file requires compiler and library support \
  ^
In file included from /usr/local/lib/python3.5/dist-packages/tensorflow/include/absl/base/config.h:66:0,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/absl/strings/string_view.h:31,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/lib/core/stringpiece.h:29,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/platform/tensor_coding.h:22,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/resource_handle.h:19,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/allocator.h:24,
                 from /usr/local/lib/python3.5/dist-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /attention-sampling/ats/ops/extract_patches/extract_patches.cu:12:
/usr/local/lib/python3.5/dist-packages/tensorflow/include/absl/base/policy_checks.h:77:2: error: #error "C++ versions less than C++11 are not supported."
 #error "C++ versions less than C++11 are not supported."
  ^
CMake Error at patches_generated_extract_patches.cu.o.cmake:207 (message):
  Error generating
  /attention-sampling/ats/ops/extract_patches/build/CMakeFiles/patches.dir//./patches_generated_extract_patches.cu.o


make[2]: *** [CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o] Error 1
CMakeFiles/patches.dir/build.make:63: recipe for target 'CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o' failed
make[1]: *** [CMakeFiles/patches.dir/all] Error 2
CMakeFiles/Makefile2:67: recipe for target 'CMakeFiles/patches.dir/all' failed
make: *** [all] Error 2
Makefile:127: recipe for target 'all' failed

I am confused because ats/ops/extract_patches/CMakeLists.txt sets set(CMAKE_CXX_STANDARD 11), yet the error says "C++ versions less than C++11 are not supported." I'm not well versed in C++; is this something you've run into before? How can I move forward?

The CPU version compiles and passes the tests.

Segmentation fault (core dumped)

Hi,
@angeloskath Thank you and your team for creating this library.
I encountered this problem while trying to run the example MNIST program.

root@4d9b40a6f414:/vol/attention-sampling# ./mnist.py ./test_mnist/mnist-small ./test_mnist/mnist-experiment
2023-03-25 02:57:58.914902: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
WARNING:tensorflow:Deprecation warnings have been disabled. Set TF_ENABLE_DEPRECATION_WARNINGS=1 to re-enable them.
/vol/attention-sampling/ats/ops/extract_patches
Loaded dataset with the following parameters
{
    "n_train": 5000,
    "n_test": 1000,
    "width": 500,
    "height": 500,
    "scale": 0.2,
    "noise": false,
    "seed": 0
}
Segmentation fault (core dumped)

I tracked it down, and the crash appears to be associated with libpatches.so, which was built when I installed the library.
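A segfault produces no Python traceback by default, which makes it hard to tell whether the crash happens while libpatches.so is being loaded or later, when the op runs. Python's standard faulthandler module can show that. A minimal sketch that re-runs a script in a child process with faulthandler enabled (the helper name is hypothetical; the same effect is available directly as python -X faulthandler ./mnist.py ./test_mnist/mnist-small ./test_mnist/mnist-experiment, or by setting PYTHONFAULTHANDLER=1):

```python
import subprocess
import sys


def run_with_faulthandler(args):
    """Run the current interpreter with faulthandler enabled.

    args: arguments placed after "-X faulthandler", e.g. a script path
    followed by that script's own arguments.
    Returns (returncode, stderr). On a segmentation fault, faulthandler
    dumps each thread's Python traceback to stderr before the process
    dies; on POSIX the return code is then -11 (killed by SIGSEGV).
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler"] + list(args),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    return proc.returncode, proc.stderr
```

If the innermost frames of the dumped traceback sit in the code that loads the compiled op, the crash occurs at load time. One common cause worth checking is that the extension was built against a different TensorFlow version or compiler ABI than the one installed; that is a hypothesis to test, not a conclusion from this report.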

My environment:

Package                          Version        Location
-------------------------------- -------------- -----------------------
absl-py                          0.10.0
astor                            0.8.1
async-generator                  1.10
attention-sampling               0.2
attrs                            20.2.0
audioread                        2.1.8
backcall                         0.2.0
bleach                           3.2.1
certifi                          2020.6.20
cffi                             1.14.3
chardet                          3.0.4
cloudpickle                      1.6.0
contextlib2                      0.6.0.post1
cupy                             8.0.0rc1
cycler                           0.10.0
Cython                           0.29.21
dataclasses                      0.7
decorator                        4.4.2
defusedxml                       0.6.0
DLLogger                         0.1.0          /workspace/src/dllogger
entrypoints                      0.3
fastrlock                        0.5
future                           0.18.2
gast                             0.2.2
google-pasta                     0.2.0
googledrivedownloader            0.4
graphsurgeon                     0.4.5
grpcio                           1.32.0
h5py                             2.10.0
horovod                          0.20.0
html2text                        2020.1.16
idna                             2.10
imageio                          2.15.0
importlib-metadata               2.0.0
iniconfig                        1.1.1
ipykernel                        5.3.4
ipython                          7.16.1
ipython-genutils                 0.2.0
jedi                             0.17.2
Jinja2                           2.11.2
joblib                           0.14.0
json5                            0.9.5
jsonschema                       3.2.0
jupyter-client                   6.1.7
jupyter-core                     4.6.3
jupyter-tensorboard              0.2.0
jupyterlab                       1.2.14
jupyterlab-pygments              0.1.2
jupyterlab-server                1.2.0
jupytext                         1.6.0
Keras                            2.3.1
Keras-Applications               1.0.8
Keras-Preprocessing              1.1.2
kiwisolver                       1.2.0
librosa                          0.7.1
llvmlite                         0.30.0
Markdown                         3.3.1
markdown-it-py                   0.5.5
MarkupSafe                       1.1.1
matplotlib                       3.1.1
mistune                          0.8.4
mock                             3.0.5
mpi4py                           3.0.3
munch                            2.5.0
nbclient                         0.5.0
nbconvert                        6.0.7
nbformat                         5.0.7
nest-asyncio                     1.4.1
networkx                         2.5
nibabel                          3.1.1
nltk                             3.4.5
notebook                         6.0.3
numba                            0.46.0
numpy                            1.17.3
nvidia-dali-cuda110              0.26.0
nvidia-dali-tf-plugin-cuda110    0.26.0
nvidia-tensorboard               1.15.0+nv20.10
nvidia-tensorboard-plugin-dlprof 0.8
nvtx-plugins                     0.1.8
onnx                             1.7.0
opt-einsum                       3.3.0
packaging                        20.4
pandas                           0.25.3
pandocfilters                    1.4.2
parso                            0.7.1
pexpect                          4.7.0
pickleshare                      0.7.5
Pillow                           6.2.1
pip                              20.2.3
pluggy                           0.13.1
portalocker                      2.0.0
portpicker                       1.3.1
progressbar                      2.5
prometheus-client                0.8.0
prompt-toolkit                   3.0.8
protobuf                         3.13.0
psutil                           5.7.0
ptyprocess                       0.6.0
py                               1.9.0
pycocotools                      2.0.0
pycparser                        2.20
Pygments                         2.7.1
pyparsing                        2.4.7
pyrsistent                       0.17.3
pytest                           6.1.1
python-dateutil                  2.8.1
python-speech-features           0.6
pytz                             2020.1
PyWavelets                       1.1.1
PyYAML                           5.3.1
pyzmq                            19.0.2
requests                         2.24.0
resampy                          0.2.2
sacrebleu                        1.2.10
scikit-image                     0.17.2
scikit-learn                     0.23.2
scipy                            1.3.1
Send2Trash                       1.5.0
sentencepiece                    0.1.83
setuptools                       50.3.0
SimpleITK                        1.1.0
six                              1.13.0
SoundFile                        0.10.3.post1
tensorboard                      1.15.9999+nv
tensorflow                       1.15.4+nv
tensorflow-estimator             1.15.1
tensorrt                         7.2.1.4
termcolor                        1.1.0
terminado                        0.9.1
testpath                         0.4.4
tf2onnx                          1.7.1
threadpoolctl                    2.1.0
tifffile                         2020.9.3
toml                             0.10.1
toposort                         1.5
tornado                          6.0.4
tqdm                             4.50.2
traitlets                        4.3.3
typing                           3.7.4.3
typing-extensions                3.7.4.3
uff                              0.6.9
urllib3                          1.25.10
wcwidth                          0.2.5
webencodings                     0.5.1
Werkzeug                         1.0.1
wheel                            0.35.1
wrapt                            1.12.1
zipp                             3.3.0

g++ and gcc==4.8.5
Ubuntu version 18.04

Any help will be deeply appreciated. Thanks in advance!!

Allow use of a patch generator

When working with very big images, using a function that generates the patches, instead of passing the whole high-resolution image, can save a lot of memory. An example is reading from .tiff files: the openslide package has a very convenient read_region function that can be used to return a patch from an image.

It could be a nice feature to have the option of passing a function that generates patches instead of x_high as an input. I am not sure whether the current structure of the code would allow for it easily; I am still wrapping my head around it.
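A minimal sketch of the requested interface, assuming each sampled offset is an (image_index, row, col) triple giving the top-left corner of a patch. `PatchReader` and `patches_from_reader` are hypothetical names, not part of attention-sampling, and hooking such a reader into the model (which currently receives x_high) is not shown.

```python
from typing import Callable, Iterable, Sequence, Tuple

import numpy as np

# Hypothetical reader signature (not part of the library):
#   read_region(image_index, (row, col), (height, width)) -> array of shape (height, width, C)
PatchReader = Callable[[int, Tuple[int, int], Tuple[int, int]], np.ndarray]


def patches_from_reader(read_region: PatchReader,
                        offsets: Iterable[Sequence[int]],
                        size: Tuple[int, int]) -> np.ndarray:
    """Read only the sampled patches, one reader call per (image_index, row, col) offset."""
    patches = []
    for i, row, col in offsets:
        patch = np.asarray(read_region(int(i), (int(row), int(col)), size))
        if patch.shape[:2] != tuple(size):
            raise ValueError(f"reader returned shape {patch.shape[:2]}, expected {tuple(size)}")
        patches.append(patch)
    # Stack into a (N, height, width, C) batch of patches
    return np.stack(patches)
```

With OpenSlide, a reader could wrap `OpenSlide.read_region(location, level, size)`, which takes `location` as (x, y) = (column, row) in level-0 coordinates and `size` as (width, height) and returns an RGBA PIL image; the wrapper would therefore swap the coordinate order and drop the alpha channel before returning a NumPy array.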

Suggestion of Environment (OS, package version, etc.)

Hi,

I just posted an issue, but I think a separate one for this question would help too.

Can you suggest the simplest working environment for this library?
For the OS, I can use:

  • Windows 10
  • Windows 10 > PowerShell > WSL2 (Ubuntu 18.04 LTS)
  • VirtualBox

Please specify the versions of packages such as:

  • Python
  • CMake
  • gcc, g++
  • TensorFlow (I bet this is 1.13.1)
  • and any other things that I may have missed

Thanks in advance.

C++ versions less than C++11 are not supported

Hi,

I am having trouble building this library. It seems that the message "C++ versions less than C++11 are not supported" is the key to this error, but I am not sure.

Below are the environment I am running this in and the entire output of python3 setup.py build.

  • OS: Windows 10 > PowerShell > WSL2 (Ubuntu 18.04 LTS)
  • Python 3.6.9
  • TensorFlow 1.13.1
  • CMake 3.19.0-rc3
  • G++: 4.8.5 (checked by adding message(STATUS "G++ Version: " ${CMAKE_CXX_COMPILER_VERSION}) in ats/ops/cmake/FindTensorFlow.cmake).

Below are some changes I made to get things running; they should not interfere with the main build procedure.

  • I ran python3 setup.py build instead of python setup.py build, and also replaced every python -c with python3 -c in ats/ops/cmake/FindTensorFlow.cmake, because running python in WSL runs Python 2. Adding alias python=python3 to ~/.bash_aliases let me run python setup.py build, but the python -c lines in ats/ops/cmake/FindTensorFlow.cmake still ran Python 2 even with this alias set. The OS advised against removing Python 2, so I kept it.
  • In ats/ops/cmake/FindTensorFlow.cmake, I added
    import warnings; \ warnings.simplefilter(action='ignore', category=FutureWarning); \
    before every import tensorflow as tf; \
    because importing tensorflow prints a lot of FutureWarnings that clutter the output.

Below is the entire output when I run python3 setup.py build.

june@DESKTOP-7JTR782:~/as2/attention-sampling-02$ python3 setup.py build
running build
running build_py
running build_ext
Building ats.ops.extract_patches.libpatches
-- Found TensorFlow include: /home/june/.local/lib/python3.6/site-packages/tensorflow/include
-- Found TensorFlow lib: /home/june/.local/lib/python3.6/site-packages/tensorflow/libtensorflow_framework.so
-- Added TensorFlow flags: -D_GLIBCXX_USE_CXX11_ABI=0
-- Configuring done
-- Generating done
-- Build files have been written to: /home/june/as2/attention-sampling-02/ats/ops/extract_patches/build
[ 33%] Building NVCC (Device) object CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o
In file included from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/unsupported/Eigen/CXX11/../../../Eigen/src/Core/util/ConfigureVectorization.h:384:0,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/unsupported/Eigen/CXX11/../../../Eigen/Core:22,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/unsupported/Eigen/CXX11/Tensor:14,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/third_party/eigen3/unsupported/Eigen/CXX11/Tensor:1,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/numeric_types.h:20,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/allocator.h:23,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /home/june/as2/attention-sampling-02/ats/ops/extract_patches/extract_patches.cu:12:
/usr/local/cuda/include/host_defines.h:54:2: warning: #warning "host_defines.h is an internal header file and must not be used directly.  This file will be removed in a future CUDA release.  Please use cuda_runtime_api.h or cuda_runtime.h instead." [-Wcpp]
 #warning "host_defines.h is an internal header file and must not be used directly.  This file will be removed in a future CUDA release.  Please use cuda_runtime_api.h or cuda_runtime.h instead."
  ^
In file included from /usr/include/c++/4.8/atomic:38:0,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/lib/core/refcount.h:19,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/platform/tensor_coding.h:21,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/resource_handle.h:19,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/allocator.h:24,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /home/june/as2/attention-sampling-02/ats/ops/extract_patches/extract_patches.cu:12:
/usr/include/c++/4.8/bits/c++0x_warning.h:32:2: error: #error This file requires compiler and library support for the ISO C++ 2011 standard. This support is currently experimental, and must be enabled with the -std=c++11 or -std=gnu++11 compiler options.
 #error This file requires compiler and library support for the \
  ^
In file included from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/absl/base/config.h:66:0,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/absl/strings/string_view.h:31,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/lib/core/stringpiece.h:29,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/platform/tensor_coding.h:22,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/resource_handle.h:19,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/allocator.h:24,
                 from /home/june/.local/lib/python3.6/site-packages/tensorflow/include/tensorflow/core/framework/op_kernel.h:23,
                 from /home/june/as2/attention-sampling-02/ats/ops/extract_patches/extract_patches.cu:12:
/home/june/.local/lib/python3.6/site-packages/tensorflow/include/absl/base/policy_checks.h:77:2: error: #error "C++ versions less than C++11 are not supported."
 #error "C++ versions less than C++11 are not supported."
  ^
CMake Error at patches_generated_extract_patches.cu.o.Release.cmake:220 (message):
  Error generating
  /home/june/as2/attention-sampling-02/ats/ops/extract_patches/build/CMakeFiles/patches.dir//./patches_generated_extract_patches.cu.o


CMakeFiles/patches.dir/build.make:82: recipe for target 'CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o' failed
make[2]: *** [CMakeFiles/patches.dir/patches_generated_extract_patches.cu.o] Error 1
CMakeFiles/Makefile2:94: recipe for target 'CMakeFiles/patches.dir/all' failed
make[1]: *** [CMakeFiles/patches.dir/all] Error 2
Makefile:148: recipe for target 'all' failed
make: *** [all] Error 2
Traceback (most recent call last):
  File "setup.py", line 37, in compile
    cwd=extension_dir
  File "/usr/lib/python3.6/subprocess.py", line 311, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['make', '-j7']' returned non-zero exit status 2.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "setup.py", line 144, in <module>
    setup_package()
  File "setup.py", line 139, in setup_package
    cmdclass={"build_ext": custom_build_ext}
  File "/home/june/.local/lib/python3.6/site-packages/setuptools/__init__.py", line 153, in setup
    return distutils.core.setup(**attrs)
  File "/usr/lib/python3.6/distutils/core.py", line 148, in setup
    dist.run_commands()
  File "/usr/lib/python3.6/distutils/dist.py", line 955, in run_commands
    self.run_command(cmd)
  File "/usr/lib/python3.6/distutils/dist.py", line 974, in run_command
    cmd_obj.run()
  File "/usr/lib/python3.6/distutils/command/build.py", line 135, in run
    self.run_command(cmd_name)
  File "/usr/lib/python3.6/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/usr/lib/python3.6/distutils/dist.py", line 974, in run_command
    cmd_obj.run()
  File "/home/june/.local/lib/python3.6/site-packages/setuptools/command/build_ext.py", line 79, in run
    _build_ext.run(self)
  File "/usr/lib/python3/dist-packages/Cython/Distutils/old_build_ext.py", line 185, in run
    _build_ext.build_ext.run(self)
  File "/usr/lib/python3.6/distutils/command/build_ext.py", line 339, in run
    self.build_extensions()
  File "/usr/lib/python3/dist-packages/Cython/Distutils/old_build_ext.py", line 193, in build_extensions
    self.build_extension(ext)
  File "setup.py", line 52, in build_extension
    ext.compile()
  File "setup.py", line 46, in compile
    ) from e
RuntimeError: Couldn't compile and install ats.ops.extract_patches.libpatches
june@DESKTOP-7JTR782:~/as2/attention-sampling-02$

Thanks for the paper and this library. If you need any more information about my environment, please ask me. Any help is welcomed. Thanks!

Unable to install on MacBook Pro

Error:
error: can't copy 'ats/ops/extract_patches/libpatches.so': doesn't exist or not a regular file

libpatches.so file is not being generated.

Batch size for all the experiments in the paper

Hey, I am trying to reproduce your work. I saw that your paper implies that the batch size for all experiments is 1. However, if I set the batch size to 1, I cannot get the same error as in your experiments (there is about a 10 times difference). If I set the batch size to 32, I get decent results.

I would really appreciate your help explaining the details of your experiments!

Extracting weird patches

Hi!

As I have been using this implementation for my own classification tasks, I have started to see a weird trend.
In the first couple of epochs, the patches seem to be taken from diffuse and widely differing areas of the training images (which is fine, and expected).
In the later epochs, as the training starts to converge, the attention model seems to focus on extracting patches from one particular edge in the images, where background (=black) meets the real content of the image. I can see this from the attention maps, as they appear white in this region, and from the patches that are extracted. This does not make any sense to me, as that particular edge does not reveal any important information for the classification task, but the algorithm is still able to achieve >80% accuracy from those patches.
Also, it is always the same edge, even though there are multiple similar edges in the training images.
Is this something you have experienced before?
Is there anything, you think, I could be doing wrong?

Thanks in advance,
Anders

Implementation of eq. 12

Hi, thanks for your paper and the code base.

I have a question about eq. 12. In the paper, the derivative is taken of the features multiplied by the attention scores.
However, in the backward pass (in ExpectWithReplacement), only the features are considered.

I probably misunderstand something, so I'd appreciate clarification. Thanks in advance.
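The NumPy sketch below is based only on the `_expected_with_replacement` code quoted later on this page, not on the paper's notation. It illustrates why a backward pass that multiplies `grad · features` by `weights / attention` can stand in for the derivative of Σ_i a_i f_i. The function name, the uniform weights 1/N, and the all-ones upstream gradient are assumptions for illustration.

```python
import numpy as np


def estimate_with_attention_grad(attention, features, n_samples, rng):
    """Monte Carlo estimate of F = sum_i a_i f_i and of dL/da, mirroring the
    forward/backward of _expected_with_replacement for a single image.

    attention: (K,) probabilities a_i over K patch locations (sums to 1)
    features:  (K, D) feature vector f_i for every location
    Assumes uniform weights 1/N and an upstream gradient dL/dF of all ones,
    so the exact gradient is dL/da_i = sum_d f_{i,d}.
    """
    idx = rng.choice(len(attention), size=n_samples, p=attention)  # i.i.d. samples
    weights = np.full(n_samples, 1.0 / n_samples)
    f = features[idx]                                  # (N, D) sampled features
    F_hat = (weights[:, None] * f).sum(axis=0)         # forward: weighted average

    grad = np.ones(features.shape[1])                  # upstream dL/dF
    # Backward: only the features appear, rescaled by weight / sampled attention
    ga = (grad * f).sum(axis=1) * weights / attention[idx]
    grad_a = np.zeros_like(attention)
    np.add.at(grad_a, idx, ga)                         # accumulate per location
    return F_hat, grad_a


rng = np.random.default_rng(0)
a = np.array([0.5, 0.3, 0.2])
f = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
F_hat, grad_a = estimate_with_attention_grad(a, f, 200_000, rng)
print(F_hat, a @ f)              # sampled vs exact expectation
print(grad_a, f.sum(axis=1))     # sampled vs exact gradient w.r.t. a
```

Since location i is drawn with probability a_i, the expected value of `grad_a[i]` is N · a_i · f_i / (N · a_i) = f_i (summed over D here), which is the derivative of Σ_i a_i f_i with respect to a_i. In this reading, the attention factor enters through the sampling itself and the division by `attention` cancels it; this is an interpretation of the code, not an authoritative statement about eq. 12.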

Validation Accuracy Does not Change

I have used your code, but my problem is that validation accuracy stops changing after some epochs (it stays at 65). I have tried changing all the hyperparameters and nothing worked. Have you ever had this problem before? I would really appreciate any help.

Why using random sampling during inference and not pick instead the X patches with maximum attention?

Hi all,

I was reading the paper to understand the implementation, but something seems strange to me.

If I understand correctly, the goal of sampling during training is to give each patch an opportunity to have its attention score updated when it is sampled from the distribution. The authors also prove that this results in the minimum-variance estimator.

But for inference, why don't we just pick the N patches with the highest attention instead of repeating the same sampling process? How can sampling be more accurate than taking the highest-attention patches at inference time, once the model has been trained?

What confuses me even more is that the authors compare ATS-10 and ATS-50 at inference but never say what sample size they use during training.

TL;DR: Why sample during inference instead of taking the patches with the maximum attention values?

I also have a question about the manual selection of the patch size. Does it mean this algorithm will be inefficient for classification tasks where objects can occupy different proportions of the image? Could this work be adapted for object detection, similar to YOLO?
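To make the trade-off in this question concrete, here is a small NumPy sketch comparing the two inference strategies on a toy attention distribution. The helper names are hypothetical, not the ats API, and the "top-k" variant renormalizes the attention over the kept patches, which is one assumption about how such a variant would be built.

```python
import numpy as np


def sampled_estimate(attention, features, n, rng):
    # Draw n patch indices i.i.d. from the attention distribution (with replacement)
    idx = rng.choice(len(attention), size=n, replace=True, p=attention)
    # Plain average of the sampled features: unbiased for sum_i a_i * f_i
    return features[idx].mean(axis=0)


def topk_patch_indices(attention, n):
    # Deterministic alternative: the n locations with the highest attention
    return np.argsort(attention)[::-1][:n]


def topk_estimate(attention, features, n):
    idx = topk_patch_indices(attention, n)
    w = attention[idx] / attention[idx].sum()  # renormalize over the kept patches
    return (w[:, None] * features[idx]).sum(axis=0)


rng = np.random.default_rng(0)
attention = np.array([0.4, 0.3, 0.3])
features = np.array([[0.0], [1.0], [1.0]])
print(attention @ features)                               # exact expectation
print(sampled_estimate(attention, features, 100_000, rng))
print(topk_estimate(attention, features, 1))
```

Here the exact expectation is 0.6: the sampled average converges to it, while the top-1 estimate returns 0 because it ignores the 60% of attention mass spread over the other locations. This only illustrates the bias of a renormalized top-k average with respect to the expectation the classifier sees during training; it says nothing about which strategy gives better accuracy on real data.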

Installation document no longer available

Hi,

I am trying to install this library and have encountered issues similar to those posted by other people. It's likely that I overlooked some dependencies, and I wonder if you could share the installation document here, as it's no longer available at the original link.

Thanks a lot!

expected_with_replacement

Hi,

I recently read this work; it's a really good idea!
But I have a question. What if we change the function _expected_with_replacement from:

@K.tf.custom_gradient
def _expected_with_replacement(weights, attention, features):
    """Approximate the expectation as if the samples were i.i.d. from the
    attention distribution.
    The gradient is simply scaled w.r.t. the sampled attention probability to
    account for samples that are unlikely to be chosen.
    """
    # Compute the expectation
    wf = expand_many(weights, [-1] * (K.ndim(features) - 2))
    F = K.sum(wf * features, axis=1)

    # Compute the gradient
    def gradient(grad):
        grad = K.expand_dims(grad, 1)

        # Gradient w.r.t. the attention
        ga = grad * features
        ga = K.sum(ga, axis=list(range(2, K.ndim(ga))))
        ga = ga * weights / attention

        # Gradient w.r.t. the features
        gf = wf * grad

        return [None, ga, gf]

    return F, gradient

to:

def _expected_with_replacement(weights, attention, features):
    # Same forward computation, without the custom gradient
    wf = expand_many(weights, [-1] * (K.ndim(features) - 2))
    F = K.sum(wf * features, axis=1)
    return F

That means we do not use the back-propagation method described in the paper.
In this case, end-to-end training can still be done.

Would the experimental results get worse in this case?

Thanks!

file not found

Hi,

I am trying to install the package via pip, and there seems to be a problem with:
ats.ops.extract_patches.libpatches

So I downloaded the repo and tried installing manually, but I get:
[screenshot]

I have TensorFlow, CUDA, g++, etc. installed.

Do you have any idea how to solve it?

RuntimeError: Couldn't compile and install ats.ops.extract_patches.libpatches

Hello,
I am trying to install the package manually with "python setup.py install". However, I get RuntimeError: Couldn't compile and install ats.ops.extract_patches.libpatches
[screenshots]
I have tried the methods from other posts, including installing it with python and adding set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") to CMakeLists.txt. However, I still cannot solve my problem.

What's the softmax temperature?

I'm building this from scratch to avoid the additional C++ code, and it seems to be working. However, when I compute the attention-weighted sum of the features, Σ_i a_i f_i, and use that as the logits for the softmax, the "a" terms on large images tend to be very small and "f" is normalized, so the scalar product yields logits very close to 0. Since softmax is not scale invariant, I get close-to-uniform predictions...

Thus, I think you are using some form of temperature, but I see no reference to it in either the code or the paper. Could you clarify?

At the moment, the best I have been able to do (to avoid hand-picking the temperature) is to make it trainable:

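import tensorflow as tf

class SoftmaxWithTemperature(tf.keras.layers.Layer):
    """Softmax whose logits are multiplied by a trainable scalar t."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Pass name/shape as keywords: the positional order differs between Keras versions
        self.t = self.add_weight(name="temperature", shape=(1,), initializer="ones")
    def call(self, inputs):
        # t > 1 sharpens the distribution, t < 1 flattens it
        return tf.nn.softmax(self.t * inputs)
classification_network = tf.keras.models.Sequential([
    SoftmaxWithTemperature()
])
classification_network(tf.reshape(features * probabilities, ...))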

And while training, I see the temperature slowly increasing to 1.2, which is still not enough (at least in my case)
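A small NumPy illustration of the scale sensitivity described in this issue: softmax is invariant to adding a constant to every logit but not to multiplying them, so logits near 0 give a near-uniform distribution regardless of their ranking. The function name and the example logits are made up for illustration.

```python
import numpy as np


def softmax_scaled(logits, scale=1.0):
    """Softmax of scale * logits; `scale` plays the role of `t` in the
    SoftmaxWithTemperature layer above (it multiplies the logits)."""
    z = scale * np.asarray(logits, dtype=float)
    z = z - z.max()  # shift for numerical stability (softmax is shift-invariant)
    e = np.exp(z)
    return e / e.sum()


small = [0.02, -0.01, 0.0]           # logits near 0, as described in the issue
print(softmax_scaled(small))         # nearly uniform
print(softmax_scaled(small, 100.0))  # same ranking, much sharper
```

Multiplying by t is the inverse of the usual "divide by the temperature" convention, so in the layer above a larger t gives a sharper distribution.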

It's not learning

Hi,
The model does not seem to learn at all. I am using all the default hyperparameters when running mnist.py:
abc@sdur-3:~/ats/scripts$ python3 ./mnist.py --epochs 10 ~/ats/datadir ~/ats/model_output

The loss doesn't really drop (or only very marginally) - does this look right to you?

I have already checked whether the data might not have been loaded correctly, but it seems fine.

I have attached a snippet of the first 10 epochs, where you can see that it doesn't really learn.

[screenshot: Attention-Model-MNIST-test]

Offsets for extracting patches

Hi, I have a couple of questions about offsets used to crop the original images:

  • An offset for an image I is [I, X, Y], where X varies on height and Y on width, correct?
  • What does an offset with a negative X or Y value mean at this line?

Thanks in advance!
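One plausible reading of these offsets, sketched in NumPy and not verified against the library's CUDA kernel: the offset [I, X, Y] is the top-left corner of the patch in image I (X along height, Y along width), and it can become negative when a sampled location lies closer to the border than half a patch (assuming the offset is computed as the sample position minus half the patch size). In that case the out-of-bounds part of the patch would be filled with zeros. `extract_patch` is a hypothetical helper.

```python
import numpy as np


def extract_patch(images, offset, size):
    """Crop a (size[0], size[1]) patch from images[i] starting at (x, y),
    zero-filling pixels outside the image (hypothetical helper).

    images: (B, H, W, C) array; offset: (i, x, y) with x along height, y along width.
    """
    i, x, y = offset
    h, w = size
    _, H, W, C = images.shape
    patch = np.zeros((h, w, C), dtype=images.dtype)
    # Intersect the requested window with the valid image area
    x0, x1 = max(x, 0), min(x + h, H)
    y0, y1 = max(y, 0), min(y + w, W)
    if x0 < x1 and y0 < y1:
        patch[x0 - x:x1 - x, y0 - y:y1 - y] = images[i, x0:x1, y0:y1]
    return patch
```

For example, with a 4x4 image, the offset (0, -1, -1) and a 2x2 patch, only the bottom-right patch pixel comes from the image (its top-left pixel); the other three are zero padding.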
