
bmm_attentional_cnn's Introduction

BMM_attentional_CNN

A CNN with an attentional module that I built while attending the Brains, Minds, and Machines summer course.

BMM_attention_model.py trains the models. BMM_attention_results.py creates the plots.
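
Neither script is reproduced here, but the core idea the README describes, a convolutional feature map re-weighted by a learned attention map before classification, can be sketched in a few lines of Keras. This is only an illustrative sketch with made-up layer sizes, names, and class count, not the actual architecture from BMM_attention_model.py:

# Minimal sketch of a CNN with a soft spatial-attention module (illustrative only;
# the input shape, filter counts, and class count are assumptions, not the repo's).
from keras.layers import Input, Conv2D, Dense, GlobalAveragePooling2D, multiply
from keras.models import Model

inputs = Input(shape=(32, 32, 1))
features = Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)

# Learn a one-channel attention map over spatial locations and use it to
# re-weight every feature channel before pooling and classification.
attention = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='attention_map')(features)
attended = multiply([features, attention])

pooled = GlobalAveragePooling2D()(attended)
outputs = Dense(10, activation='softmax')(pooled)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Because the attention map has a single channel, the same spatial weighting is broadcast across every feature channel when the two tensors are multiplied.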

bmm_attentional_cnn's People

Contributors

dvatterott

bmm_attentional_cnn's Issues

Hey, I'm sorry to bother you, but I'm going crazy trying to add your attention block to my model. Can you help me fix it?

import os
from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation, Reshape, Permute, RepeatVector, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose, ZeroPadding2D
from keras.layers.pooling import AveragePooling2D, GlobalAveragePooling2D
from keras.layers import Input, Flatten
from keras.layers.merge import concatenate, multiply

from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.layers.wrappers import TimeDistributed
from keras.layers.recurrent import GRU, LSTM
from keras.layers.wrappers import Bidirectional
from keras.engine.topology import Layer, InputSpec
from keras import initializers as initializations
import keras.backend as K
from attention_utils import get_activations, get_data_recurrent
import tensorflow as tf
from keras import regularizers, constraints, initializers, activations
from keras.layers.recurrent import Recurrent
from keras.engine import InputSpec
from tdd import _time_distributed_dense
import numpy as np


def attention_3d_block(inputs):
    input_dim = int(inputs.shape[1])
    a = Permute((2, 1))(inputs)
    a = Dense(input_dim, activation='softmax')(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    # print("a_probs shape :   ",a_probs.shape)
    output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
    return output_attention_mul



def conv_block(input, growth_rate, dropout_rate=None, weight_decay=1e-4):
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(growth_rate, (3, 3), kernel_initializer='he_normal', padding='same')(x)
    if (dropout_rate):
        x = Dropout(dropout_rate)(x)
    return x


def dense_block(x, nb_layers, nb_filter, growth_rate, dropout_rate=0.2, weight_decay=1e-4):
    for i in range(nb_layers):
        cb = conv_block(x, growth_rate, dropout_rate, weight_decay)
        x = concatenate([x, cb], axis=-1)
        nb_filter += growth_rate
    return x, nb_filter


def transition_block(input, nb_filter, dropout_rate=None, pooltype=1, weight_decay=1e-4):
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)
    x = Activation('relu')(x)
    x = Conv2D(nb_filter, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)

    if (dropout_rate):
        x = Dropout(dropout_rate)(x)

    if (pooltype == 2):
        x = AveragePooling2D((2, 2), strides=(2, 2))(x)
    elif (pooltype == 1):
        x = ZeroPadding2D(padding=(0, 1))(x)
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    elif (pooltype == 3):
        x = AveragePooling2D((2, 2), strides=(2, 1))(x)
    return x, nb_filter


def global_average_pooling(x):
    return K.mean(x, axis = (2, 3))

def global_average_pooling_shape(input_shape):
    return input_shape[0:2]

def change_shape1(x):
    x = K.reshape(K.transpose(x),(420,64))
    print("change shape",x.shape)
    return x


def att_shape(input_shape):
    return (input_shape[0][0],14,138,64)

def att_shape2(input_shape):
    return input_shape[0][0:4]

def attention_control(args):
    x,dense_2 = args
    print("attention shape ",x.shape)
    find_att = K.reshape(x,(7,60,64))
    print("find_att0 shape : ",find_att.shape)
    # find_att = K.transpose(find_att[:,:,:])
    find_att = K.mean(find_att,axis=1)
    find_att = find_att/K.sum(find_att,axis=1)
    print("---find_att :",find_att.shape)
    find_att = K.repeat_elements(find_att,300,axis=1)
    print("find_att 1 shape  ",find_att.shape)
    find_att = K.reshape(find_att,(1,7,60,64)) #(?, 16, 140, 64)
    print("find_att2 shape :",find_att.shape)
    return find_att
def dense_cnn(input, nclass):
    rnnunit = 256
    units = 256
    _dropout_rate = 0.2
    _weight_decay = 1e-4

    _nb_filter = 64
    # conv 64  5*5 s=2
    x0 = Conv2D(_nb_filter, (5, 5), strides=(2, 2), kernel_initializer='he_normal', padding='same',
               use_bias=False, kernel_regularizer=l2(_weight_decay))(input)# (?, 16, 140, 64)
    print("x0 shape",x0.shape)
    # 64 +  8 * 8 = 128
    x1, _nb_filter = dense_block(x0, 8, _nb_filter, 8, None, _weight_decay)
    # 128
    x2, _nb_filter = transition_block(x1, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x2, 8, _nb_filter, 8, None, _weight_decay)
    # 192->128
    # x=attenton_cnn(x)
    x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)

    # 128 + 8 * 8 = 192
    x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)

    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    # (None, 4, 35, 192)
    dense_1 = Lambda(global_average_pooling,output_shape=global_average_pooling_shape,name='dense_1')(x)  # (,32)
    dense_2 = Dense(10, activation='softmax', name='dense_2')(dense_1)  # (,10)

    con_shape1 = Lambda(change_shape1, output_shape=(64,), name='change_shape1')(x)
    print(con_shape1.shape)
    find_att = Dense(64, activation='softmax', name='att_con')(con_shape1)
    print("dense find att shape ",find_att.shape)
    find_att = Lambda(attention_control, output_shape=att_shape, name="att_con")([find_att, dense_2])
    zero_3a = ZeroPadding2D((8, 50), name='convzero_3')(find_att)
    print("==find att==",find_att.shape)
    apply_attention = multiply([x0,zero_3a])
    x = Activation('relu')(apply_attention)


    return x


def dense_blstm(input):
    pass


if __name__ == "__main__":
    input = Input(shape=(32, 280, 1), name='the_input')
    y_pred = dense_cnn(input, 15)
    basemodel = Model(inputs=input, outputs=y_pred)
    basemodel.summary()

It's a DenseNet that I want to add attention to.
I would appreciate it if you could fix my code.

Hi,

I have some questions after looking at your model picture and your code. From the picture, it looks like you apply the focused (attended) result at conv2 and pass that focused result on to the next layer, conv3. But in your code, you multiply the conv2 and conv1 results and pass the product to conv3 instead of applying it at conv2. I'm confused about that.

I would appreciate it if you could answer my question. Thank you.
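
To make the two wirings this question contrasts concrete, here is one possible reading in Keras. The layer names, sizes, and the way the attention map is produced are illustrative assumptions, not taken from the repository's code:

# Sketch of two ways an attention map could be wired in (assumed layers; not the
# repository's actual graph).
from keras.layers import Input, Conv2D, multiply
from keras.models import Model

inputs = Input(shape=(32, 32, 1))
conv1_out = Conv2D(32, (3, 3), padding='same', activation='relu', name='conv1')(inputs)
conv2_out = Conv2D(32, (3, 3), padding='same', activation='relu', name='conv2')(conv1_out)
# Assume a one-channel attention map derived from conv2's features.
attention = Conv2D(1, (1, 1), activation='sigmoid', name='attention_map')(conv2_out)

# Reading 1 (apply the focused result at conv2): gate the features, run them
# through another convolution, and only then hand them to conv3.
gated = multiply([conv1_out, attention], name='gate_features')
refined = Conv2D(32, (3, 3), padding='same', activation='relu', name='conv2_refine')(gated)
conv3_a = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv3_a')(refined)

# Reading 2 (multiply the conv1 and conv2 results, pass straight to conv3): the
# gated product itself is what conv3 receives, with no intermediate convolution.
conv3_b = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv3_b')(gated)

Model(inputs=inputs, outputs=[conv3_a, conv3_b]).summary()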

Issue with BMM_attention_model.py

It seems like there's an error in your code. I copied and pasted the code from BMM_attention_model.py into a Jupyter notebook to try it out and got the error message "Exception: Input 0 is incompatible with layer dense_2: expected shape=(None, 4), found shape=(None, 32)", caused by the line find_att = dense_2a(conv_shape1).

Any thoughts?

(Photo of the error attached.)
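
For anyone hitting the same exception: in Keras this class of error generally means a Dense layer was built for one input feature size and then called on a tensor whose last dimension is different. A toy reproduction, with sizes that are illustrative and unrelated to BMM_attention_model.py's actual shapes:

# Reproduce the same class of shape-mismatch error with made-up sizes.
from keras.layers import Input, Dense

dense_2 = Dense(10, activation='softmax', name='dense_2')
dense_2(Input(shape=(4,)))       # first call builds the layer to expect 4 input features

try:
    dense_2(Input(shape=(32,)))  # reusing it on a 32-feature tensor triggers the mismatch
except Exception as err:
    print(err)                   # "Input 0 is incompatible with layer dense_2 ..."

Here the quoted message suggests that conv_shape1 carries 32 features while dense_2a was defined to expect 4.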
