Comments (9)
@5uperpalo, can you have a look at this? 👆🏼
from pytorch-widedeep.
sure, I'll look into it and respond by tomorrow lunch
I am sorry for the late response @taokz, I had some personal issues holding me back... Last time I did not upload the latest version of the troubleshooting notebook, and yes, you were right, I used the wrong link. I have updated the troubleshooting notebook I posted earlier. There you can see, in the section ISSUE num.2, that TabTransformer can work without categorical features.
As you are working with private/proprietary data, I would suggest the following:
- In the code you may see that if you set the objective to `binary`, the loss defaults to `BCEWithLogitsLoss`, i.e. here and here. This is the same loss as used in the TabNet model in the issue you resolved earlier.
- Try to use `import pdb; pdb.set_trace()` inside `trainer.fit()` to debug what ground truth and predicted values you are sending to the loss function, e.g. use `pdb.set_trace()` here or here. Maybe the model has NaN/infinity values on the output, so try to (i) normalize the columns with a Scaler from scikit-learn, or just enable scaling in the Preprocessor, i.e. the parameter here; (ii) use a different initializer than `XavierNormal`.
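For reference, here is a plain-NumPy sketch of what `BCEWithLogitsLoss` computes on raw logits (illustrative only, not the library's code). Note how a single non-finite model output makes the whole batch loss NaN:

```python
import numpy as np

def bce_with_logits(logits, targets):
    # Numerically stable binary cross-entropy on raw logits, using the
    # same formulation as torch.nn.BCEWithLogitsLoss:
    #   max(x, 0) - x*z + log(1 + exp(-|x|)), averaged over the batch.
    x = np.asarray(logits, dtype=float)
    z = np.asarray(targets, dtype=float)
    return float(np.mean(np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))))

print(bce_with_logits([0.2, -1.5, 3.0], [1.0, 0.0, 1.0]))  # finite loss
# One NaN logit poisons the loss for the whole batch:
print(np.isfinite(bce_with_logits([np.nan, -1.5, 3.0], [1.0, 0.0, 1.0])))  # False
```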
Note: Please let me know if any of this helped. It could help other people, including us, if we come across the same issue.
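The finite-value and scaling checks suggested above can be sketched with plain NumPy/pandas before training (column names and values here are made up for illustration):

```python
import numpy as np
import pandas as pd

# Hypothetical all-continuous dataframe standing in for the private data;
# column names f1/f2 are made up.
df = pd.DataFrame({"f1": [0.1, 2.3, np.inf, 4.0],
                   "f2": [1.0, np.nan, 3.0, 2.0]})

# (i) Find columns holding NaN or +/-inf values that would yield a NaN loss.
bad_cols = [c for c in df.columns if not np.isfinite(df[c]).all()]
print(bad_cols)  # ['f1', 'f2']

# (ii) After cleaning, standardize each column (zero mean, unit variance),
# which is roughly what enabling scaling in the preprocessor does for you.
clean = df.replace([np.inf, -np.inf], np.nan).dropna()
scaled = (clean - clean.mean()) / clean.std(ddof=0)
print(scaled.round(6))
```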
hi @taokz, here is fully functional code using a dataset with ALL continuous cols. Maybe you could use this as a starting point to fix the issue you are experiencing:
```python
import numpy as np

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabTransformer, WideDeep
from pytorch_widedeep.datasets import load_california_housing
from pytorch_widedeep.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_widedeep.preprocessing import TabPreprocessor

if __name__ == "__main__":
    df = load_california_housing(as_frame=True)

    # project latitude/longitude onto the unit sphere
    df["location_x"] = np.cos(df.Latitude) * np.cos(df.Longitude)
    df["location_y"] = np.cos(df.Latitude) * np.sin(df.Longitude)
    df.drop(["Latitude", "Longitude"], axis=1, inplace=True)

    target_col = "MedHouseVal"
    target = df[target_col].values
    continuous_cols = [c for c in df.columns if c != target_col]

    tab_preprocessor = TabPreprocessor(
        continuous_cols=continuous_cols,
        cols_to_scale=continuous_cols,
        for_transformer=True,
    )
    X_tab = tab_preprocessor.fit_transform(df)

    tab_transformer = TabTransformer(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        embed_continuous=True,
        input_dim=8,
        n_blocks=1,
        n_heads=2,
    )
    model = WideDeep(deeptabular=tab_transformer)

    callbacks = [
        EarlyStopping(patience=2),
        ModelCheckpoint(filepath="model_weights/wd_out"),
    ]
    trainer = Trainer(
        model,
        objective="regression",
        callbacks=callbacks,
    )
    trainer.fit(
        X_tab=X_tab,
        target=target,
        n_epochs=10,
        batch_size=128,
        val_split=0.2,
    )
```
@taokz both issues are likely connected to the data you are using; can you share a sample of the data? To anonymize it, you may change the column names and use the simple `.sample()` method on the dataframe ...
What I did:
- installed a clean pytorch_widedeep repo from master
- updated my libraries with `pip install -r requirements.txt -U`
- took an example binary classification notebook and adjusted it according to your reported issues, see the adjusted notebook

Next steps:
- try a clean install and update your libraries with `pip install -r requirements.txt -U`
- check that you have the correct data types in your dataframe (i.e. the values in X_tab you are passing to the Trainer): are they correct, e.g. no strings, NAs, or objects?
- ISSUE num.1: when you use `verbose=True`, do you have a non-NaN loss? best_epoch is saved only if the monitor is working and the monitored metric improves (by default the validation loss); do you see non-NaN verbose output?
- ISSUE num.2: again, this must be related to the data you are passing into the Trainer; if it is not possible to share the data, could you please try to compare it to the provided example notebook?
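A quick way to run the data-type check above directly on the dataframe (`df` here is a made-up stand-in for the data passed to the Trainer):

```python
import numpy as np
import pandas as pd

# Stand-in dataframe: one proper float column, and one that silently became
# 'object' dtype because a string slipped in during loading.
df = pd.DataFrame({"good": [1.0, 2.0, 3.0], "bad": [1.0, "2.0", 3.0]})

# Non-numeric columns will break the forward pass or produce NaN losses.
non_numeric = df.select_dtypes(exclude=[np.number]).columns.tolist()
print(non_numeric)  # ['bad']

# Coerce everything to float; values that cannot convert become NaN, so any
# remaining NaN pinpoints the offending cells.
numeric = df.apply(pd.to_numeric, errors="coerce")
print(numeric.dtypes.tolist())
print(numeric.isna().any().any())  # False: "2.0" parsed cleanly
```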
Perhaps the reason for the issue is that there are NaN values in your data.
I faced a similar problem, but it was resolved when I used `dropna()` on my data.
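For completeness, a minimal pandas sketch of that `dropna()` fix on a toy frame:

```python
import numpy as np
import pandas as pd

# Toy frame with a single NaN hiding in one row.
df = pd.DataFrame({"a": [1.0, np.nan, 3.0], "b": [4.0, 5.0, 6.0]})
print(len(df))         # 3 rows before cleaning

cleaned = df.dropna()  # drops every row that contains at least one NaN
print(len(cleaned))    # 2 rows remain
print(cleaned.isna().any().any())  # False
```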
> Perhaps the reason for the issue is that there are NaN values in your data. I faced a similar problem, but it was resolved when I used `dropna()` on my data.
@ibowennn Thank you for your reminder. However, I have checked my data and there are no NaN values.
@5uperpalo I kindly appreciate your quick reply.
For issue 1, I noticed that I had used the `fit()` method incorrectly:

```python
trainer.fit(
    X_tab=X_num_train,
    target=y_train,
    X_tab_val=X_num_valid,  # there is no such parameter in the base_trainer
    target_val=y_valid,     # there is no such parameter in the base_trainer
    n_epochs=2,
    batch_size=1024,
)
```
I modified it to the following, and it works (for TabNet):

```python
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=2,
    batch_size=1024,
    val_split=0.2,
)
```
However, I still cannot solve issue 2: there is a NaN loss (for transformers such as TabTransformer). I guess it is because I pass `cat_embed_input=None`, since my data only has continuous features. Is it required to set `cat_embed_input != None` for transformer-based models? The example notebook may have a wrong link. Do you mean this?
BTW, I am sorry, but the data is private and I cannot share it here. You can just take it as a table with only numerical values and no NaNs.
@5uperpalo @jrzaurin Thank you for your detailed guidance! I have recently been focusing on another project, so I did not respond in time. I appreciate your time and effort!