ThorstenFalk commented on August 25, 2024

In principle your approach is fine. Caffe stores the weights in (cOut, cIn, y, x) order for convolutions and in (cIn, cOut, y, x) order for up-convolutions. The dimensions probably need to be permuted a little more to match the corresponding Keras layers; I don't know the native dimension order of Keras/TensorFlow by heart, but this is my first guess at what might cause your strange outputs.
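
For illustration, permuting into the Keras layout could look like this (a sketch with made-up shapes; Keras Conv2D stores its kernels as (y, x, cIn, cOut) and Conv2DTranspose as (y, x, cOut, cIn)):

import numpy as np

# Dummy blobs in Caffe's layout; the shapes are only for illustration.
caffe_conv = np.zeros((64, 1, 3, 3))        # convolution: (cOut, cIn, y, x)
caffe_upconv = np.zeros((1024, 512, 2, 2))  # up-convolution: (cIn, cOut, y, x)

# Keras Conv2D expects (y, x, cIn, cOut):
keras_conv = caffe_conv.transpose(2, 3, 1, 0)      # -> (3, 3, 1, 64)
# Keras Conv2DTranspose expects (y, x, cOut, cIn):
keras_upconv = caffe_upconv.transpose(2, 3, 1, 0)  # -> (2, 2, 512, 1024)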

What kind of padding do you use? The plugin processes the image in overlapping tiles and uses mirroring to extrapolate data across the image boundaries. Did you perhaps use zero padding, leading to this wide bright border?
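
Mirroring at the boundaries amounts to something like this (a sketch; the 92 px margin is just an example, not the plugin's actual tile overlap):

import numpy as np

tile = np.random.rand(236, 236)  # a hypothetical input tile
# Reflect the image content across the boundaries instead of zero padding:
padded = np.pad(tile, pad_width=92, mode='reflect')  # -> (420, 420)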


maxclac commented on August 25, 2024

In TensorFlow I use the Conv2D layer with padding='same'. I changed it to 'valid', but this now causes problems with the layer concatenations. In my network it is done like this:

from tensorflow import keras as tfk
from tensorflow.keras.layers import (Conv2D, Dropout, MaxPooling2D,
                                     UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def UNet_Freiburg(shape):
    inputs = tfk.layers.Input(shape=shape, name='input')
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d0a-b')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d0b-c')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2), name='maxpool1')(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d1a-b')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d1b-c')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2), name='maxpool2')(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d2a-b')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d2b-c')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2), name='maxpool3')(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d3a-b')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d3b-c')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2), name='maxpool4')(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d4a-b')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_d4b-c')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_d4c_u3a')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u3b-c')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u3c-d')(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u3d_u2a')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u2b-c')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u2c-d')(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u2d_u1a')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u1b-c')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u1c-d')(conv8)

    up9 = Conv2D(128, 2, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='upconv_u1d_u0a')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0b-c')(merge9)
    conv9 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0c-d')(conv9)
    conv9 = Conv2D(2, 1, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal', name='conv_u0d-score')(conv9)
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs, conv10)

    model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
    return model

and I get this error message:

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 56, 56, 512), (None, 47, 47, 512)]

I guess I need to rearrange the whole network in order to adapt to the padding change.
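
The sizes can be traced explicitly. Here is a sketch, assuming the original architecture (two valid 3x3 convolutions per level, 2x2 pooling, and an up-convolution that exactly doubles the spatial size):

    def valid_unet_size(in_size, n_levels=4):
        """Trace the spatial size through a valid-padding U-Net."""
        s = in_size
        for _ in range(n_levels):
            s -= 4               # two valid 3x3 convs shrink by 2 px each
            assert s % 2 == 0, "feature map must be even before pooling"
            s //= 2              # 2x2 max pooling
        s -= 4                   # bottleneck convolutions
        for _ in range(n_levels):
            s = 2 * s - 4        # up-convolution doubles, two convs shrink by 4
        return s

    # valid_unet_size(572) -> 388, the input/output pair of the original paper.
    # Most input sizes fail the evenness assertion, so the input tile shape
    # must be chosen to fit the network.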


ThorstenFalk commented on August 25, 2024

Ah, OK, you used 'same' padding. You're right, U-Net uses valid convolutions. The left (contracting-path) blob must be cropped to match the spatial shape of the right (expanding-path) blob before concatenation.

This is my Keras implementation of U-Net:

from time import time

import h5py
import numpy as np
import keras
import tensorflow as tf
from keras.callbacks import TensorBoard

class Unet2D:

  def __init__(self, snapshot=None, n_channels=1, n_classes=2, n_levels=4,
               n_features=64, name="U-Net"):

    self.concat_blobs = []

    self.n_channels = n_channels
    self.n_classes = n_classes
    self.n_levels = n_levels
    self.n_features = n_features
    self.name = name

    self.trainModel, self.padding = self._createModel(True)
    self.testModel, _ = self._createModel(False)

    if snapshot is not None:
      self.trainModel.load_weights(snapshot)
      self.testModel.load_weights(snapshot)

  def _weighted_categorical_crossentropy(self, y_true, y_pred, weights):
    return tf.losses.softmax_cross_entropy(
      y_true, y_pred, weights=weights, reduction=tf.losses.Reduction.MEAN)

  def _createModel(self, training):

    data = keras.layers.Input(shape=(None, None, self.n_channels), name="data")

    concat_blobs = []

    if training:
      labels = keras.layers.Input(
        shape=(None, None, self.n_classes), name="labels")
      weights = keras.layers.Input(shape=(None, None), name="weights")

    # Modules of the analysis path consist of two convolutions and max pooling
    for l in range(self.n_levels):
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal",
          name="conv_d{}a-b".format(l))(data if l == 0 else t))
      concat_blobs.append(
        keras.layers.LeakyReLU(alpha=0.1)(
          keras.layers.Conv2D(
            2**l * self.n_features, 3, padding="valid",
            kernel_initializer="he_normal", name="conv_d{}b-c".format(l))(t)))
      t = keras.layers.MaxPooling2D(pool_size=(2, 2))(concat_blobs[-1])

    # Deepest layer has two convolutions only
    t = keras.layers.LeakyReLU(alpha=0.1)(
      keras.layers.Conv2D(
        2**self.n_levels * self.n_features, 3, padding="valid",
        kernel_initializer="he_normal",
        name="conv_d{}a-b".format(self.n_levels))(t))
    t = keras.layers.LeakyReLU(alpha=0.1)(
      keras.layers.Conv2D(
        2**self.n_levels * self.n_features, 3, padding="valid",
        kernel_initializer="he_normal",
        name="conv_d{}b-c".format(self.n_levels))(t))
    pad = 8  # encoder blob is 8 px larger than the upsampled blob at the deepest level

    # Modules in the synthesis path consist of up-convolution,
    # concatenation and two convolutions
    for l in range(self.n_levels - 1, -1, -1):
      name = "upconv_{}{}{}_u{}a".format(
        *(("d", l+1, "c", l) if l == self.n_levels - 1 else ("u", l+1, "d", l)))
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 2, padding="same",
          kernel_initializer="he_normal", name=name)(
            keras.layers.UpSampling2D(size = (2,2))(t)))
      t = keras.layers.Concatenate()(
        [keras.layers.Cropping2D(cropping=int(pad / 2))(concat_blobs[l]), t])
      pad = 2 * (pad + 8)  # size difference between encoder and decoder blobs at the next shallower level
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal", name="conv_u{}b-c".format(l))(t))
      t = keras.layers.LeakyReLU(alpha=0.1)(
        keras.layers.Conv2D(
          2**l * self.n_features, 3, padding="valid",
          kernel_initializer="he_normal", name="conv_u{}c-d".format(l))(t))
    pad /= 2  # total input/output size difference (184 px for 4 levels, i.e. 92 px per border)

    score = keras.layers.Conv2D(
      self.n_classes, 1, kernel_initializer = 'he_normal',
      name="conv_u0d-score")(t)
    softmax_score = keras.layers.Softmax()(score)

    if training:
      model = keras.Model(inputs=[data, labels, weights], outputs=softmax_score)
      model.add_loss(
        self._weighted_categorical_crossentropy(labels, score, weights))
      adam = keras.optimizers.Adam(
        lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0,
        amsgrad=False)
      model.compile(optimizer=adam, loss=None)
    else:
      model = keras.Model(inputs=data, outputs=softmax_score)

    return model, int(pad)

  def loadCaffeModelH5(self, path):
    train_layer_dict = dict([(layer.name, layer)
                             for layer in self.trainModel.layers])
    test_layer_dict = dict([(layer.name, layer)
                            for layer in self.testModel.layers])
    pre = h5py.File(path, 'a')
    l = list(pre['data'].keys())
    for i in l:
      kernel = pre['data'][i]['0'][()]
      bias = pre['data'][i]['1'][()]
      train_layer_dict[i].set_weights([kernel,bias])
      test_layer_dict[i].set_weights([kernel,bias])
    pre.close()

  def train(self, sample_generator, validation_generator=None,
            n_epochs=100, snapshot_interval=1, snapshot_prefix=None):

    callbacks = [TensorBoard(log_dir="logs/{}-{}".format(self.name, time()))]
    if snapshot_prefix is not None:
      callbacks.append(keras.callbacks.ModelCheckpoint(
        (snapshot_prefix if snapshot_prefix is not None else self.name) +
        ".{epoch:04d}.h5", mode='auto', period=snapshot_interval))
    self.trainModel.fit_generator(
      sample_generator, epochs=n_epochs, validation_data=validation_generator,
      verbose=1, callbacks=callbacks)

  def predict(self, tile_generator):

    smscores = []
    segmentations = []

    for tileIdx in range(len(tile_generator)):
      tile = tile_generator[tileIdx]
      outIdx = tile[0]["image_index"]
      outShape = tile[0]["image_shape"]
      outSlice = tile[0]["out_slice"]
      inSlice = tile[0]["in_slice"]
      softmax_score = self.testModel.predict(tile[0]["data"], verbose=1)
      if len(smscores) < outIdx + 1:
        smscores.append(np.empty((*outShape, self.n_classes)))
        segmentations.append(np.empty(outShape))
      smscores[outIdx][outSlice] = softmax_score[0][inSlice]
      segmentations[outIdx][outSlice] = np.argmax(
        softmax_score[0], axis=-1)[inSlice]

    return smscores, segmentations

I think this should make everything clear. It even includes loading the weights.
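
Usage would then be along these lines (a sketch; the weight file name and the tile generator are placeholders):

net = Unet2D(n_channels=1, n_classes=2)
net.loadCaffeModelH5("caffemodel.h5")  # placeholder path to the exported Caffe weights
smscores, segmentations = net.predict(tile_generator)  # any generator matching predict()'s tile dict format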

All the best,
Thorsten


maxclac commented on August 25, 2024

Thank you! I will have a look.


maxclac commented on August 25, 2024

This is unfortunately not TensorFlow 2; I will have to do some work to adapt it to my environment.


maxclac commented on August 25, 2024

Hi again!
What I did now was simply go back to TensorFlow 1.14 rather than porting the code to TF2.
I now have a problem that I already had before, namely that in

    def loadCaffeModelH5(self, path):
        train_layer_dict = dict([(layer.name, layer)
                                 for layer in self.trainModel.layers])
        test_layer_dict = dict([(layer.name, layer)
                                for layer in self.testModel.layers])
        pre = h5py.File(path, 'a')
        l = list(pre['data'].keys())
        for i in l:
            print(i)
            print(pre['data'][i].keys())
            try:
                kernel = pre['data'][i]['0'][()]
                bias = pre['data'][i]['1'][()]
                train_layer_dict[i].set_weights([kernel, bias])
                test_layer_dict[i].set_weights([kernel, bias])
            except KeyError:
                pass  # this group holds no kernel/bias datasets; skip it
        pre.close()

there is a conflict between the shapes of the layers in Caffe and the shapes expected by TF:

ValueError: Layer weight shape (3, 3, 1, 64) not compatible with provided weight shape (64, 1, 3, 3)

This could maybe be solved by taking the transpose of the arrays (see the sketch at the end of this comment), but then there is still the issue with the upconv layers:

ValueError: Layer weight shape (3, 3, 128, 64) not compatible with provided weight shape (3, 3, 192, 128)

These are the problems I already had before opening this issue. Am I missing something?
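
The transpose I mean would be something like this (a sketch; it would fix the dimension order of the first error, but not the second, where the input channel counts themselves, 192 vs. 128, disagree):

    import numpy as np

    def caffe_to_keras(kernel):
        """Permute a Caffe (cOut, cIn, y, x) kernel to Keras (y, x, cIn, cOut)."""
        return np.transpose(kernel, (2, 3, 1, 0))

    # e.g. caffe_to_keras(np.zeros((64, 1, 3, 3))).shape == (3, 3, 1, 64)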

