
Comments (6)

angeloskath commented on July 28, 2024

Hi,

Does it also achieve >80% on the test set?

So I have seen this before, and I usually deal with it by tweaking the following:

  • entropy regularization
  • sample softmax smoothing
  • the attention network, so that it does not produce artifacts at the edges of the full image (basically, removing the padding in the convolutions)
  • clipnorm in the optimizer
  • L2 regularization for the attention network
  • removing black-to-image transitions (e.g., filling the black regions with the average color of the dataset)
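One common way to smooth a spatial softmax is to take the softmax over every location of the attention map and then mix it with a uniform distribution, so that no location ever gets exactly zero probability and the sampler keeps exploring. The plain-Python sketch below does exactly that; it is an assumption about what "smoothing" means here, not a transcription of SampleSoftmax from attention-sampling, and smoothed_spatial_softmax is a hypothetical helper.

```python
import math
from typing import List


def smoothed_spatial_softmax(scores: List[List[float]], smooth: float = 0.0) -> List[List[float]]:
    """Softmax over all locations of one attention map, mixed with a uniform
    distribution: p = (1 - smooth) * softmax(scores) + smooth / N."""
    if not 0.0 <= smooth <= 1.0:
        raise ValueError("smooth must be in [0, 1]")
    flat = [s for row in scores for s in row]
    top = max(flat)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in flat]
    total = sum(exps)
    n = len(flat)
    # Every location keeps at least smooth / n probability.
    probs = [(1.0 - smooth) * e / total + smooth / n for e in exps]
    width = len(scores[0])
    return [probs[k:k + width] for k in range(0, n, width)]


# A map whose scores strongly favor one location (e.g. an edge artifact).
print(smoothed_spatial_softmax([[10.0, 0.0], [0.0, 0.0]], smooth=0.5))
```

Larger smooth values keep more mass on every location; the smooth=1e-5 used in the code later in this thread is a very light touch.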

Just to make sure that I understand: the attention focuses on parts of the image at the sides, where there is some artifact? This is basically attention overfitting. The black-to-color transition causes large activations that the attention has not yet learned to ignore, or they simply push the feature network to rely on them. That creates a positive feedback loop, and the model ends up using only those regions.
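A toy illustration of why the black-to-content transition stands out, using one row of made-up pixel intensities and a [-1, 1] difference filter as a stand-in for an edge-sensitive convolution kernel. The fill value plays the role of "the average color of the dataset" from the list above; edge_response and fill_background are hypothetical helpers and all numbers are invented for illustration.

```python
from typing import List


def edge_response(row: List[float]) -> List[float]:
    """Absolute output of a [-1, 1] difference filter along one row of pixels
    (a stand-in for an edge-sensitive conv kernel)."""
    return [abs(right - left) for left, right in zip(row, row[1:])]


def fill_background(row: List[float], fill: float, threshold: float = 0.0) -> List[float]:
    """Replace 'black' pixels (intensity <= threshold) with a fill value,
    e.g. the dataset's average intensity."""
    return [fill if p <= threshold else p for p in row]


# Hypothetical row: three black background pixels, then real content.
row = [0.0, 0.0, 0.0, 0.62, 0.58, 0.60, 0.55]
raw = edge_response(row)
filled = edge_response(fill_background(row, fill=0.59))
print(max(raw), max(filled))
```

The strongest raw response sits exactly at the black/content boundary, while after filling it is on the order of the content's own variation. On real images the threshold should sit a little above zero to catch compression noise, and dark pixels inside the content would be filled too, so a mask of the border region is safer.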

Let me know if you need more help or if anything I mentioned above does not make sense.

Cheers,
Angelos

from attention-sampling.

andersbhc-mmmi commented on July 28, 2024

Yes, it does.

You are correct: the attention model is focusing on the edge between the background (black) and the foreground (the real, non-black content).

I will try some of the points you mentioned. I have already tried tweaking the regularization strength of the attention model, which controls its exploration vs. exploitation. With more regularization, the model seemed to spread the patches out more (more exploration), but it eventually ended up focusing on the same edges.
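For reference, an entropy regularizer works with the entropy of the attention distribution, H(p) = -sum_i p_i log p_i, which is largest for uniform attention and smallest when the mass sits on a few locations. The plain-Python sketch below computes it. The sign convention of the penalty (subtracting strength * H from the loss so that spread-out attention is rewarded) is an assumption, not necessarily how multinomial_entropy in attention-sampling is written; attention_entropy and regularized_loss are hypothetical helpers.

```python
import math
from typing import Sequence


def attention_entropy(probs: Sequence[float], eps: float = 1e-12) -> float:
    """Entropy H(p) = -sum p log p of a flattened attention distribution.
    eps avoids log(0) for locations with zero probability."""
    return -sum(p * math.log(p + eps) for p in probs)


def regularized_loss(task_loss: float, probs: Sequence[float], strength: float) -> float:
    """Task loss minus strength * H(p) (assumed sign convention): with
    strength > 0, flatter attention lowers the total loss."""
    return task_loss - strength * attention_entropy(probs)


uniform = [0.25, 0.25, 0.25, 0.25]  # attention spread over 4 locations
peaked = [0.97, 0.01, 0.01, 0.01]   # attention collapsed onto one location
print(attention_entropy(uniform), attention_entropy(peaked))
print(regularized_loss(0.3, uniform, 0.1), regularized_loss(0.3, peaked, 0.1))
```

One possible reading of the behavior described above: a larger strength makes flat attention cheaper, but if the classifier gains enough from the edge patches, the task loss can still pull the attention back onto them.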

Also, I tried saving the model weights with the Keras ModelCheckpoint callback and loading them again after re-creating the model, but then the accuracy drops significantly.
Here is my code for building the models and for saving and loading the weights:

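# Imports (Keras 2.x). The ats.* module paths follow the attention-sampling
# repository's example scripts. args, lr_sched, AttentionSaver and
# training_set are defined elsewhere in the script.
import os

from keras.applications import VGG16
from keras.callbacks import CSVLogger, ModelCheckpoint
from keras.layers import Conv2D, Dense, Dropout, Input
from keras.models import Model, Sequential

from ats.core import attention_sampling
from ats.utils.layers import L2Normalize, ResizeImages, SampleSoftmax
from ats.utils.regularizers import multinomial_entropy


# Defining the attention model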
def getAttentionModel(input_shape):
    attention = Sequential([
        Conv2D(8, kernel_size=3, activation="relu", padding="same",
               input_shape=input_shape),
        Conv2D(16, kernel_size=3, activation="relu", padding="same"),
        Conv2D(32, kernel_size=3, activation="relu", padding="same"),
        Conv2D(64, kernel_size=3, activation="relu", padding="same"), 
        Conv2D(128, kernel_size=3, activation="relu", padding="same"),
        Conv2D(1, kernel_size=3, padding="same"),
        SampleSoftmax(squeeze_channels=True, smooth=1e-5)
    ])

    return attention


# Defining the feature extraction model
def getVGGModel(input_shape, pre_trained=True):
    if pre_trained:
        model = VGG16(include_top=False, weights='imagenet', input_shape=input_shape, pooling='max')
    else:
        model = VGG16(include_top=False, weights=None, input_shape=input_shape, pooling='max')

    # Freeze VGG16 up to block4_pool (layers 0-14); only the last conv block
    # (block5, three conv layers) and the Dense layers are fine-tuned
    for layer in model.layers[:15]:
        layer.trainable = False

    last = model.output
    x = L2Normalize()(last)
    x = Dropout(0.5)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)

    feature_model = Model(inputs=model.input, outputs=x)

    return feature_model


# Making the attention sampling network
def get_model(outputs, width, height, scale, n_patches, patch_size, reg):
    # Define the shapes
    x_in = Input(shape=(height, width, 3))
    x_high = x_in
    x_low = ResizeImages((int(height*scale), int(width*scale)))(x_high)
    shape_high = (height, width, 3)
    shape_low = (int(height*scale), int(width*scale), 3)

    # Make the attention and feature models
    attention = getAttentionModel(shape_low)
    feature = getVGGModel(shape_high, pre_trained=True)

    # Let's build the attention sampling network
    features, attention, patches = attention_sampling(
        attention,
        feature,
        patch_size,
        n_patches,
        replace=False,
        attention_regularizer=multinomial_entropy(reg)
    )([x_low, x_high])
    y = Dense(outputs, activation="softmax")(features)

    return (
        Model(inputs=x_in, outputs=[y]),
        Model(inputs=x_in, outputs=[attention, patches, x_low])
    )


# Instantiating the models
model, att_model = get_model(
    outputs=2,
    width=1444,
    height=1184,
    scale=0.2,
    n_patches=args["n_patches"],
    patch_size=args["patch_size"],
    reg=args["regularizer_strength"]
)

# Instantiating callbacks, including ModelCheckpoint, which saves the weights after each epoch
callbacks = [
    lr_sched,
    AttentionSaver(args["output"], att_model, training_set),
    ModelCheckpoint(
        os.path.join(args["output"], "weights.{epoch:02d}.h5"),
        save_weights_only=True
    ),
    CSVLogger(filename=os.path.join(args["output"], "train_history.csv"))
]

Then I fit the model.


Afterwards, I recreate the models exactly as above and load the weights:

model.load_weights(args["weights_path"])
Evaluating the reloaded model on the same data as before then yields low accuracy.

Any ideas?


angeloskath commented on July 28, 2024

Hi,

The fact that the test accuracy does not drop means that the patches are informative, so it will be harder to get rid of them.

Regarding saving and loading, that is weird. I just copy-pasted your code into a shell, and saving and loading works fine. I would start by comparing the outputs for a single image. If possible, load your code in a shell, train for a few gradient updates, then save and reload, and check that the attention is exactly the same and that, for a given image, the two models give approximately equal results (the patches are sampled at random, so the predictions will not be bit-identical). It helps to check the deterministic parts of the code first, because those should match exactly!
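A minimal sketch of that check, in plain Python so it does not depend on a particular Keras version. It assumes the attention map and the class probabilities have already been flattened into lists of floats; max_abs_diff, mean_prediction and check_reload are hypothetical helpers, and the Keras calls in the trailing comments are only an illustration based on the model/att_model pair from the question.

```python
from statistics import mean
from typing import Callable, List, Sequence, Tuple


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise difference between two flattened outputs."""
    if len(a) != len(b):
        raise ValueError("outputs have different sizes")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def mean_prediction(predict: Callable[[], Sequence[float]], n_runs: int) -> List[float]:
    """Average several forward passes, because the sampled patches change per call."""
    runs = [list(predict()) for _ in range(n_runs)]
    return [mean(column) for column in zip(*runs)]


def check_reload(
    attention_before: Sequence[float],
    attention_after: Sequence[float],
    predict_before: Callable[[], Sequence[float]],
    predict_after: Callable[[], Sequence[float]],
    n_runs: int = 20,
    tol: float = 0.05,
) -> Tuple[bool, bool]:
    """Return (attention identical, averaged predictions within tol)."""
    # Deterministic part: same weights + same image -> same attention map.
    attention_identical = max_abs_diff(attention_before, attention_after) == 0.0
    # Stochastic part: compare class probabilities averaged over several runs.
    drift = max_abs_diff(
        mean_prediction(predict_before, n_runs),
        mean_prediction(predict_after, n_runs),
    )
    return attention_identical, drift <= tol


# Illustrative use with the Keras models from the question (not executed here):
#   attention_before = att_model.predict(img)[0].ravel().tolist()
#   predict_before = lambda: model.predict(img)[0].tolist()
#   ...save, rebuild, load_weights(), then collect the "after" values the same way.
```

If the first flag is False, the reloaded weights (or layer configurations) differ; if only the second is False, the problem is more likely in the stochastic or training-only parts, such as dropout or the sampling itself.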

Let me know if you still have problems saving and loading the models or if I can help with the weird patches being selected.

If you really want to get rid of them no matter what, you could generate a mask and apply it to the attention so that patches that contain some black are never selected.
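A sketch of that masking idea on plain nested lists. It assumes each attention location corresponds to a square patch centered on the same pixel of the low-resolution image, which is a simplification: in attention-sampling the attention lives on the low-resolution image while patches are taken from the high-resolution one, so the real mapping may differ. valid_patch_mask and mask_attention are hypothetical helpers.

```python
from typing import List

Grid = List[List[float]]


def valid_patch_mask(image: Grid, half: int, black: float = 0.0) -> List[List[bool]]:
    """True where a (2*half+1)-square patch centered on (i, j) lies fully inside
    the image and contains no black pixel (intensity <= black)."""
    h, w = len(image), len(image[0])
    mask = []
    for i in range(h):
        row = []
        for j in range(w):
            inside = half <= i < h - half and half <= j < w - half
            # Short-circuit: the pixel check only runs for patches inside the image.
            ok = inside and all(
                image[y][x] > black
                for y in range(i - half, i + half + 1)
                for x in range(j - half, j + half + 1)
            )
            row.append(ok)
        mask.append(row)
    return mask


def mask_attention(attention: Grid, mask: List[List[bool]]) -> Grid:
    """Zero the attention where the mask is False and renormalize it to sum to 1."""
    masked = [[a if m else 0.0 for a, m in zip(a_row, m_row)]
              for a_row, m_row in zip(attention, mask)]
    total = sum(sum(row) for row in masked)
    if total == 0.0:
        raise ValueError("the mask removes every location")
    return [[a / total for a in row] for row in masked]
```

In the actual model the masking would have to happen inside the graph (for example, with a Lambda layer that multiplies the SampleSoftmax output by a precomputed mask and renormalizes); that wiring is not shown here.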

Cheers,
Angelos


andersbhc-mmmi commented on July 28, 2024

Hi Angelos,

So I finally got saving and loading to work, and I can now reproduce the results after saving and reloading the model and its weights.
I used the built-in model.to_json() and model_from_json() to save and rebuild the architecture, along with ModelCheckpoint to save the weights and load_weights() to load them in my test script.
I also had to define get_config() methods in the custom layers to make it work.
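A minimal sketch of the get_config() pattern that lets a custom layer survive to_json()/model_from_json(). ScaleLayer is a hypothetical stand-in for layers such as SampleSoftmax that take constructor arguments (squeeze_channels, smooth); the sketch assumes a recent keras package (2.x or 3.x) and is not the code from this thread. Depending on the Keras version, a missing get_config() either raises an error or silently rebuilds the layer with its default arguments.

```python
import numpy as np
import keras


class ScaleLayer(keras.layers.Layer):
    """Hypothetical custom layer with a constructor argument."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add every
        # constructor argument so that from_config() rebuilds the same layer.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


def roundtrip(model, custom_objects):
    """Serialize the architecture to JSON and rebuild it from that string."""
    json_string = model.to_json()
    return keras.models.model_from_json(json_string, custom_objects=custom_objects)


model = keras.Sequential([keras.Input(shape=(4,)), ScaleLayer(factor=2.0)])
restored = roundtrip(model, {"ScaleLayer": ScaleLayer})
```

In the workflow described above, the rebuilt model would then receive its trained weights through load_weights() on the ModelCheckpoint file.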

I'm still confused as to how the patches containing the edges help in the classification. I'll continue tweaking some hyperparameters and see if I can make it better.

Thank you for your help!


angeloskath commented on July 28, 2024

This sounds awesome! Would you mind sharing your additions? I would gladly merge a pull request.

Let me know if I can do anything more.

Angelos


andersbhc-mmmi commented on July 28, 2024

Hi,
Yeah, definitely. It's not much, but I'll share it for sure.

Anders
