Comments (4)
Hi! Thank you for such a detailed report.
What I see from the error traceback:
- the error happens when key_compression is called
- key_compression is a linear layer
- the problem is that the output size of key_compression is zero

This happens because kv_compression_ratio=0.004 is too low for this number of features. In a nutshell, key_compression is created as follows:
key_compression = nn.Linear(
num_input_features,
int(kv_compression_ratio * num_input_features), # int(0.004 * 152) = 0
bias=False,
)
The purpose of this is to reduce the number of features to make the attention faster.
However, there is no point in reducing the number of features below a certain threshold. This threshold is purely heuristic and should be chosen based on your budget and the downstream performance. In particular, it depends on the number of features.
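To make the failure concrete, here is a minimal sketch in plain PyTorch (the 152 features and the 0.004 ratio come from this issue; everything else is illustrative):

import torch
import torch.nn as nn

num_input_features = 152
kv_compression_ratio = 0.004
d_compressed = int(kv_compression_ratio * num_input_features)  # int(0.608) == 0

# nn.Linear accepts out_features=0 without complaint, so construction
# succeeds and the problem only surfaces later, inside the attention computation.
key_compression = nn.Linear(num_input_features, d_compressed, bias=False)
x = torch.randn(8, num_input_features)
print(key_compression(x).shape)  # torch.Size([8, 0]) -- an empty projection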
The best scenario is when you don't need compression (i.e. kv_compression_ratio=None). If this does not fit into your budget, then choose kv_compression_ratio from the range [a, b] based on your intuition and preference, where a is the smallest value that still provides good performance in terms of metrics, and b is the largest value that still fits into your budget.
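For intuition on where the lower end of that range lives, note that the same ratio can be fine or fatal depending on the feature count (152 is from this issue; the other counts are made up for illustration):

# int(ratio * n_features) is the compressed size; it must stay >= 1,
# so the smallest usable ratio is roughly 1 / n_features.
for n_features in (152, 1000, 15000):
    print(n_features, int(0.004 * n_features))  # 0 (breaks), 4, 60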
Does this help?
Dear @Yura52,
Thanks a lot for your answer, it does help!
I had to put such a low value for the kv_compression_ratio parameter when the training was done on the full dataset (because of my available hardware and time budget), but it's true that it doesn't make a lot of sense to use the same value for this parameter when the first group of features (representing almost 99% of all of the input features) is removed during an ablation study.
So for the moment I simply changed my initialization of the model by modifying three lines:
import rtdl

class FTT(rtdl.FTTransformer):
    def __init__(
        self,
        n_num_features=None,
        cat_cardinalities=None,
        d_token=16,
        n_blocks=1,
        attention_n_heads=4,
        attention_dropout=0.3,
        attention_initialization='kaiming',
        attention_normalization='LayerNorm',
        ffn_d_hidden=16,
        ffn_dropout=0.1,
        ffn_activation='ReGLU',
        ffn_normalization='LayerNorm',
        residual_dropout=0.0,
        prenormalization=True,
        first_prenormalization=False,
        last_layer_query_idx=[-1],
        n_tokens=None,
        kv_compression_ratio=0.004,
        kv_compression_sharing='headwise',
        head_activation='ReLU',
        head_normalization='LayerNorm',
        d_out=None,
    ):
        feature_tokenizer = rtdl.FeatureTokenizer(
            n_num_features=n_num_features,
            cat_cardinalities=cat_cardinalities,
            d_token=d_token,
        )
        # Disable compression entirely when the compressed size would collapse to 0.
        no_compression = (kv_compression_ratio is None
                          or int(kv_compression_ratio * n_num_features) == 0)
        transformer = rtdl.Transformer(
            d_token=d_token,
            n_blocks=n_blocks,
            attention_n_heads=attention_n_heads,
            attention_dropout=attention_dropout,
            attention_initialization=attention_initialization,
            attention_normalization=attention_normalization,
            ffn_d_hidden=ffn_d_hidden,
            ffn_dropout=ffn_dropout,
            ffn_activation=ffn_activation,
            ffn_normalization=ffn_normalization,
            residual_dropout=residual_dropout,
            prenormalization=prenormalization,
            first_prenormalization=first_prenormalization,
            last_layer_query_idx=last_layer_query_idx,
            n_tokens=None if no_compression else n_num_features + 1,  # Modified line
            kv_compression_ratio=None if no_compression else kv_compression_ratio,  # Modified line
            kv_compression_sharing=None if no_compression else kv_compression_sharing,  # Modified line
            head_activation=head_activation,
            head_normalization=head_normalization,
            d_out=d_out,
        )
        super().__init__(feature_tokenizer, transformer)
It's clearly not optimal and can be improved (e.g., by automatically setting the value of kv_compression_ratio with respect to the number of input features instead of always using 0.004), but for the moment it is sufficient, as the code runs.
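For what it's worth, here is a hedged sketch of that automatic adjustment (the helper name and the minimum compressed size are my own assumptions, not part of rtdl):

def adapt_kv_compression_ratio(ratio, n_num_features, min_compressed=1):
    """Hypothetical helper: keep `ratio` only if it yields a usable
    compressed size for this feature count; otherwise return None,
    which disables compression entirely."""
    if ratio is None or int(ratio * n_num_features) < min_compressed:
        return None
    return ratio

ratio = adapt_kv_compression_ratio(0.004, 152)   # None -> compression disabled
ratio = adapt_kv_compression_ratio(0.004, 1000)  # 0.004 -> compression kept

The resulting ratio (together with matching n_tokens and kv_compression_sharing values) can then be passed to rtdl.Transformer.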
However, I'm having trouble understanding why the code was working when I was running it on a single GPU (1x RTX 2080 Ti) but not on two (2x RTX 2080 Ti). Could you explain this? I understand your answer, but I don't get why the code still runs locally with 152 input features and a kv_compression_ratio of 0.004 (by "locally" I mean using my own computer that has one GPU).
I should admit I don't have a good explanation for that :)
Feel free to reopen the issue if needed!