Git Product home page Git Product logo

Comments (5)

zsk-tech avatar zsk-tech commented on August 23, 2024

确实是,我也觉得这个卷积层的参数无法训练,请问您解决这个问题了吗?

from danet-keras.

wujiayi avatar wujiayi commented on August 23, 2024

我改写成这样可以了:

from keras.layers import Activation, Conv2D
import keras.backend as K
import tensorflow as tf
from keras.layers import Layer

class PAM(Layer):
    """Position attention layer for a 4-D ``(batch, h, w, channels)`` input.

    Three trainable kernels project the channel axis: ``kernel_b`` and
    ``kernel_c`` map to ``channels // 8`` features and ``kernel_d`` maps to
    ``channels`` features. A softmax over the pairwise products of the ``b``
    and ``c`` projections weights the ``d`` projection; the result is scaled
    by the trainable scalar ``beta`` and added to the input.

    The ``kernal_*`` spelling of the keyword arguments is kept unchanged so
    that existing callers continue to work.
    """

    def __init__(self,
                 beta_initializer=tf.zeros_initializer(),
                 beta_regularizer=None,
                 beta_constraint=None,
                 kernal_initializer='he_normal',
                 kernal_regularizer=None,
                 kernal_constraint=None,
                 **kwargs):
        """Store the initializer/regularizer/constraint settings.

        Args:
            beta_initializer: initializer for the scalar ``beta`` weight.
            beta_regularizer: optional regularizer for ``beta``.
            beta_constraint: optional constraint for ``beta``.
            kernal_initializer: initializer shared by the three kernels.
            kernal_regularizer: optional regularizer shared by the kernels.
            kernal_constraint: optional constraint shared by the kernels.
            **kwargs: forwarded to ``Layer.__init__``.
        """
        super(PAM, self).__init__(**kwargs)

        self.beta_initializer = beta_initializer
        self.beta_regularizer = beta_regularizer
        self.beta_constraint = beta_constraint

        self.kernal_initializer = kernal_initializer
        self.kernal_regularizer = kernal_regularizer
        self.kernal_constraint = kernal_constraint

    def build(self, input_shape):
        """Create ``beta`` and the three projection kernels."""
        # Only the channel count is needed here; the unpacking still requires
        # a 4-element input shape.
        _, _, _, filters = input_shape

        self.beta = self.add_weight(shape=(1,),
                                    initializer=self.beta_initializer,
                                    name='beta',
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint,
                                    trainable=True)

        self.kernel_b = self.add_weight(shape=(filters, filters // 8),
                                        initializer=self.kernal_initializer,
                                        name='kernel_b',
                                        regularizer=self.kernal_regularizer,
                                        constraint=self.kernal_constraint,
                                        trainable=True)

        self.kernel_c = self.add_weight(shape=(filters, filters // 8),
                                        initializer=self.kernal_initializer,
                                        name='kernel_c',
                                        regularizer=self.kernal_regularizer,
                                        constraint=self.kernal_constraint,
                                        trainable=True)

        self.kernel_d = self.add_weight(shape=(filters, filters),
                                        initializer=self.kernal_initializer,
                                        name='kernel_d',
                                        regularizer=self.kernal_regularizer,
                                        constraint=self.kernal_constraint,
                                        trainable=True)

        self.built = True

    def compute_output_shape(self, input_shape):
        """Return ``input_shape``; the layer does not change the shape."""
        return input_shape

    def call(self, inputs):
        """Apply the attention and return ``beta * attended + inputs``."""
        _, h, w, filters = inputs.get_shape().as_list()
        reduced = filters // 8

        if h is None or w is None:
            # The static spatial size is unknown, so ``h * w`` cannot be
            # computed in Python. Take the batch size from the runtime shape
            # and let reshape infer the flattened spatial axis instead.
            runtime_shape = K.shape(inputs)
            flat_prefix = (runtime_shape[0], -1)
            restored_shape = runtime_shape
        else:
            flat_prefix = (-1, h * w)
            restored_shape = (-1, h, w, filters)

        # Per-position linear projections of the channel axis.
        b = K.dot(inputs, self.kernel_b)
        c = K.dot(inputs, self.kernel_c)
        d = K.dot(inputs, self.kernel_d)

        vec_b = K.reshape(b, flat_prefix + (reduced,))
        vec_cT = K.permute_dimensions(
            K.reshape(c, flat_prefix + (reduced,)), (0, 2, 1))

        # (batch, h*w, h*w) affinities, normalised over the last axis.
        # K.softmax avoids instantiating a new Activation layer on each call.
        bcT = K.batch_dot(vec_b, vec_cT)
        softmax_bcT = K.softmax(bcT)

        vec_d = K.reshape(d, flat_prefix + (filters,))
        bcTd = K.batch_dot(softmax_bcT, vec_d)
        bcTd = K.reshape(bcTd, restored_shape)

        return self.beta * bcTd + inputs

class CAM(Layer):
    """Channel attention layer for a 4-D ``(batch, h, w, channels)`` input.

    A ``(channels, channels)`` affinity matrix is computed from the flattened
    input, passed through a softmax over its last axis, and multiplied back
    onto the flattened input. The result is scaled by the trainable scalar
    ``gamma`` and added to the input.
    """

    def __init__(self,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        """Store the initializer/regularizer/constraint settings.

        Args:
            gamma_initializer: initializer for the scalar ``gamma`` weight.
            gamma_regularizer: optional regularizer for ``gamma``.
            gamma_constraint: optional constraint for ``gamma``.
            **kwargs: forwarded to ``Layer.__init__``.
        """
        super(CAM, self).__init__(**kwargs)
        self.gamma_initializer = gamma_initializer
        self.gamma_regularizer = gamma_regularizer
        self.gamma_constraint = gamma_constraint

    def build(self, input_shape):
        """Create the scalar ``gamma`` weight."""
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)

        self.built = True

    def compute_output_shape(self, input_shape):
        """Return ``input_shape``; the layer does not change the shape."""
        return input_shape

    def call(self, inputs):
        """Apply the attention and return ``gamma * attended + inputs``."""
        _, h, w, filters = inputs.get_shape().as_list()

        if h is None or w is None:
            # The static spatial size is unknown, so ``h * w`` cannot be
            # computed in Python. Take the batch size from the runtime shape
            # and let reshape infer the flattened spatial axis instead.
            runtime_shape = K.shape(inputs)
            flat_shape = (runtime_shape[0], -1, filters)
            restored_shape = runtime_shape
        else:
            flat_shape = (-1, h * w, filters)
            restored_shape = (-1, h, w, filters)

        vec_a = K.reshape(inputs, flat_shape)
        # vec_a is already flattened, so only the transpose is needed.
        vec_aT = K.permute_dimensions(vec_a, (0, 2, 1))

        # (batch, channels, channels) affinities, normalised over the last
        # axis. K.softmax avoids instantiating a new Activation layer on
        # each call.
        aTa = K.batch_dot(vec_aT, vec_a)
        softmax_aTa = K.softmax(aTa)

        # NOTE(review): the softmax normalises each row of the affinity
        # matrix, while the product below combines input channels through
        # its columns, so the weights feeding one output channel need not
        # sum to 1. Confirm this is the intended channel-attention form.
        aaTa = K.batch_dot(vec_a, softmax_aTa)
        aaTa = K.reshape(aaTa, restored_shape)

        return self.gamma * aaTa + inputs

from danet-keras.

AlanLu0808 avatar AlanLu0808 commented on August 23, 2024

from danet-keras.

wujiayi avatar wujiayi commented on August 23, 2024

> 看起来挺好的!有测试吗?性能有提升吗? 我之前测试发现了这个问题,一直没有解决,就把这个问题搁置了。

我试了,是可以train的,而且参数我也算了是对的。但是效果提升并不是很明显。不知道是不是还有别的什么问题。

from danet-keras.

karryor avatar karryor commented on August 23, 2024

你好,你改了之后有提升吗?

from danet-keras.

Related Issues (8)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents code.