Comments (6)
Hi,
Does it achieve >80% also in the test set?
I have seen this before, and I usually deal with it by tweaking the following:
- entropy regularization
- sample softmax smoothing
- the attention network, so that it does not produce artifacts at the edges of the full image (basically, removing padding in the convolutions)
- clipnorm in the optimizer
- l2 regularization for the attention network
- removing black-to-image transitions (one could fill the black regions with the average color of the dataset)
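As a rough illustration of the last point, here is a minimal NumPy sketch that replaces black background pixels with a mean color. The function names, the `threshold` parameter, and the choice to average only over non-black pixels are assumptions made for this example; none of them come from the attention-sampling library.

```python
import numpy as np


def mean_foreground_color(images, threshold=0):
    """Average RGB color over all non-black pixels of a list of HxWx3 images."""
    total = np.zeros(3, dtype=np.float64)
    count = 0
    for image in images:
        image = np.asarray(image)
        foreground = np.any(image > threshold, axis=-1)  # HxW boolean mask
        total += image[foreground].sum(axis=0)
        count += int(foreground.sum())
    if count == 0:
        raise ValueError("no non-black pixels found")
    return total / count


def fill_black_with_color(image, color, threshold=0):
    """Return a copy of image with (near-)black pixels replaced by color."""
    image = np.asarray(image)
    black = np.all(image <= threshold, axis=-1)  # HxW boolean mask
    filled = image.copy()
    filled[black] = np.asarray(color).astype(image.dtype)
    return filled
```

Applied as a preprocessing step before the images reach the network, this removes the sharp black-to-content edge, although the boundary of the filled region may still be visible if the mean color differs a lot from the local content.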
Just to make sure that I understand: the attention focuses on parts of the image at the sides, where there is some artifact. This is basically attention overfitting. The black-to-color transition causes large activations. Either the attention has not yet learned to ignore them, or they influence the feature network to use them, which creates a positive feedback loop that ends with the model using only those patches.
Let me know if you need more help or if anything I mentioned above does not make sense.
Cheers,
Angelos
from attention-sampling.
Yes, it does.
You are correct, the attention model is focusing on the edge between background (black) and foreground (real content, not black).
I will try some of the points you have mentioned here. I already tried tweaking the regularization strength for the attention model, controlling its exploration vs exploitation. When adding more regularization, the model seemed to distribute the patches more (do more exploration), but eventually ended up focusing on the same edges.
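To make the exploration vs. exploitation trade-off concrete, here is a small NumPy sketch of the entropy of an attention map. It only illustrates the quantity being regularized; the exact form of `multinomial_entropy` inside the library is not shown in this thread, and the function name and the `eps` parameter below are assumptions for this example.

```python
import numpy as np


def attention_entropy(attention, eps=1e-12):
    """Shannon entropy (in nats) of an attention map treated as a distribution."""
    a = np.asarray(attention, dtype=np.float64).ravel()
    a = a / a.sum()  # make sure the map sums to one
    return float(-np.sum(a * np.log(a + eps)))


# A uniform map spreads the samples everywhere (maximum exploration).
uniform = np.full((4, 4), 1.0 / 16)

# A peaked map puts almost all samples on one location (exploitation).
peaked = np.full((4, 4), 0.01 / 15)
peaked[0, 0] = 0.99

print(attention_entropy(uniform), attention_entropy(peaked))
```

A stronger entropy regularizer pushes the attention towards the uniform case, which matches the observation that the patches get more spread out early in training.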
Also, I tried to save the models' weights using the Keras ModelCheckpoint and load them again after initializing the model, but then the accuracy drops significantly.
Here is my code for building the models, saving and loading weights:
# Defining the attention model
def getAttentionModel(input_shape):
    attention = Sequential([
        Conv2D(8, kernel_size=3, activation="relu", padding="same",
               input_shape=input_shape),
        Conv2D(16, kernel_size=3, activation="relu", padding="same"),
        Conv2D(32, kernel_size=3, activation="relu", padding="same"),
        Conv2D(64, kernel_size=3, activation="relu", padding="same"),
        Conv2D(128, kernel_size=3, activation="relu", padding="same"),
        Conv2D(1, kernel_size=3, padding="same"),
        SampleSoftmax(squeeze_channels=True, smooth=1e-5)
    ])
    return attention

# Defining the feature extraction model
def getVGGModel(input_shape, pre_trained=True):
    if pre_trained:
        model = VGG16(include_top=False, weights='imagenet', input_shape=input_shape, pooling='max')
    else:
        model = VGG16(include_top=False, weights=None, input_shape=input_shape, pooling='max')

    # We only fine-tune the last CONV layer and the Dense layers
    for layer in model.layers[:15]:
        layer.trainable = False

    last = model.output
    x = L2Normalize()(last)
    x = Dropout(0.5)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    feature_model = Model(inputs=model.input, outputs=x)
    return feature_model

# Making the attention sampling network
def get_model(outputs, width, height, scale, n_patches, patch_size, reg):
    # Define the shapes
    x_in = Input(shape=(height, width, 3))
    x_high = x_in
    x_low = ResizeImages((int(height*scale), int(width*scale)))(x_high)
    shape_high = (height, width, 3)
    shape_low = (int(height*scale), int(width*scale), 3)

    # Make the attention and feature models
    attention = getAttentionModel(shape_low)
    feature = getVGGModel(shape_high, pre_trained=True)

    # Let's build the attention sampling network
    features, attention, patches = attention_sampling(
        attention,
        feature,
        patch_size,
        n_patches,
        replace=False,
        attention_regularizer=multinomial_entropy(reg)
    )([x_low, x_high])
    y = Dense(outputs, activation="softmax")(features)

    return (
        Model(inputs=x_in, outputs=[y]),
        Model(inputs=x_in, outputs=[attention, patches, x_low])
    )

# Instantiating the models
model, att_model = get_model(
    outputs=2,
    width=1444,
    height=1184,
    scale=0.2,
    n_patches=args["n_patches"],
    patch_size=args["patch_size"],
    reg=args["regularizer_strength"]
)

# Instantiating callbacks, including ModelCheckpoint, which saves weights after each epoch
callbacks = [
    lr_sched,
    AttentionSaver(args["output"], att_model, training_set),
    ModelCheckpoint(
        os.path.join(args["output"], "weights.{epoch:02d}.h5"),
        save_weights_only=True
    ),
    CSVLogger(filename=os.path.join(args["output"], "train_history.csv"))
]

Then I fit the model.
When I then recreate the models in the same way as above and load the weights:
model.load_weights(args["weights_path"])
Then evaluating the model on the same data as before yields a low accuracy.
Any ideas?
Hi,
The fact that the test accuracy does not drop means that the patches are informative, so it will be harder to get rid of them.
Regarding saving and loading, that is weird. I just copy/pasted your code into a shell, and saving and loading works fine. I would start by comparing outputs for a single image. If possible, load your code in a shell, train for a few gradient updates, then save and reload. Check that the attention is exactly the same, and that for a given image the two models give approximately equal results. It helps to check the deterministic parts of the code first, because they should match exactly!
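The comparison described above could be sketched with a small helper like the one below. The helper name and the tolerance values are assumptions for this example. The commented usage assumes the `model` and `att_model` objects from the code earlier in the thread, plus a hypothetical `image` batch and hypothetical `reloaded_*` models built the same way and loaded from the checkpoint.

```python
import numpy as np


def compare_outputs(before, after, atol=0.0):
    """Return (matches, max_abs_diff) for two arrays of model outputs."""
    before = np.asarray(before, dtype=np.float64)
    after = np.asarray(after, dtype=np.float64)
    if before.shape != after.shape:
        return False, float("inf")
    max_diff = float(np.max(np.abs(before - after))) if before.size else 0.0
    return max_diff <= atol, max_diff


# Usage sketch (hypothetical variable names):
#
#   att_before = att_model.predict(image)[0]
#   att_after = reloaded_att_model.predict(image)[0]
#   print(compare_outputs(att_before, att_after))         # deterministic part
#
#   y_before = model.predict(image)
#   y_after = reloaded_model.predict(image)
#   print(compare_outputs(y_before, y_after, atol=0.05))  # sampled part
```

The attention map is computed deterministically from the low resolution image, so it is the part where any difference points directly at the weights. The final prediction depends on randomly sampled patches, so only an approximate match can be expected there.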
Let me know if you still have problems saving and loading the models or if I can help with the weird patches being selected.
If you really want to get rid of them no matter what, you could generate a mask and apply it to the attention so that patches that contain some black are never selected.
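A minimal NumPy sketch of such a mask follows. It assumes the attention map has the same height and width as the low resolution image, and that `radius` is roughly half the patch size expressed in low resolution pixels; the function name and both parameters are made up for this example. In the real model the masking would have to happen inside the network, before the sampling step, which is not shown here.

```python
import numpy as np


def mask_attention(attention, image_low, radius, threshold=0.0):
    """Zero the attention near black pixels and renormalize it to sum to one."""
    attention = np.asarray(attention, dtype=np.float64)
    black = np.all(np.asarray(image_low) <= threshold, axis=-1)  # HxW
    h, w = black.shape

    # Dilate the black mask by `radius` pixels so that any location whose
    # patch would overlap a black pixel is excluded as well.
    padded = np.pad(black, radius, mode="constant", constant_values=False)
    near_black = np.zeros_like(black)
    for dy in range(2 * radius + 1):
        for dx in range(2 * radius + 1):
            near_black |= padded[dy:dy + h, dx:dx + w]

    masked = np.where(near_black, 0.0, attention)
    total = masked.sum()
    if total <= 0:
        raise ValueError("the mask removed all of the attention mass")
    return masked / total
```

Locations with zero attention are never sampled, so the remaining patches come only from regions that are fully inside the foreground.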
Cheers,
Angelos
Hi Angelos,
So I finally made saving and loading work, and I can now reproduce the results after saving and loading the model and its weights.
I used the built-in methods model.to_json() and model_from_json() to save and restore the architecture, along with the ModelCheckpoint to save the weights and load_weights() to load them in my test script.
I also had to define some get_config() methods in the custom layers to make it work.
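For readers hitting the same problem, here is a minimal sketch of the general pattern, not the actual changes made to the library. `ScaleBy` is a hypothetical custom layer invented for this example, and the code assumes the `tensorflow.keras` API. It uses `get_weights()`/`set_weights()` instead of a checkpoint file only to stay self-contained.

```python
import numpy as np
from tensorflow import keras


class ScaleBy(keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Without this, `factor` is not stored in the JSON and the layer
        # would be rebuilt with its default value.
        config = super().get_config()
        config["factor"] = self.factor
        return config


model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(3),
    ScaleBy(factor=2.0),
])

# Serialize the architecture, then rebuild it and copy the weights over.
architecture_json = model.to_json()
restored = keras.models.model_from_json(
    architecture_json, custom_objects={"ScaleBy": ScaleBy}
)
restored.set_weights(model.get_weights())

x = np.ones((1, 4), dtype="float32")
same = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))
print(same)
```

The key points are that every constructor argument of a custom layer appears in `get_config()`, and that the custom classes are passed through `custom_objects` when the model is rebuilt.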
I'm still confused as to how the patches containing the edges help in the classification. I'll continue tweaking some hyperparameters and see if I can make it better.
Thank you for your help!
This sounds awesome, do you mind sharing your additions? I would gladly merge a pull request.
Let me know if I can do anything more.
Angelos
Hi,
Yeah definitely. It's not much, but I'll share it for sure.
Anders