
tabpfn's Introduction

TabPFN

TabPFN is a neural network that has learned to do tabular data prediction. This is the original CUDA-supporting PyTorch implementation.

We created a Colab that lets you play with our scikit-learn interface.

Installation

pip install tabpfn

If you want to train and evaluate our method as we did in the paper (including baselines), please install with

pip install tabpfn[full]

To run the autogluon and autosklearn baselines, please create a separate environment and install autosklearn==0.14.5 / autogluon==0.4.0; installing them in the same environment as our other baselines is not possible.

Getting started

A simple example of using our sklearn interface:

from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_eval))
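The cap mentioned in the comments above follows from counting: if each ensemble member uses one feature rotation and one class rotation, there are only #features * #classes distinct pairs. The sketch below enumerates such pairs and applies one of them with NumPy. The helper names are hypothetical and this is not TabPFN's internal code; in particular, the real ensemble also varies the pre-processing per member, which is ignored here.

```python
import itertools

import numpy as np


def ensemble_configurations(n_features, n_classes, n_ensemble):
    """Hypothetical helper: up to n_ensemble (feature_shift, class_shift) pairs.

    Only n_features * n_classes distinct pairs exist, so a larger n_ensemble
    cannot add new configurations.
    """
    pairs = itertools.product(range(n_features), range(n_classes))
    return list(itertools.islice(pairs, n_ensemble))


def apply_configuration(X, y, feature_shift, class_shift, n_classes):
    """Rotate the feature columns and relabel the classes for one member."""
    X_rot = np.roll(X, shift=feature_shift, axis=1)  # cyclic column rotation
    y_rot = (y + class_shift) % n_classes  # cyclic label rotation
    return X_rot, y_rot


def unrotate_proba(proba, class_shift):
    """Map probabilities over rotated labels back to the original label order."""
    return np.roll(proba, -class_shift, axis=1)
```

For the breast-cancer data (30 features, 2 classes) there are 60 such pairs, so N_ensemble_configurations=32 stays below the cap. Un-rotating each member's probabilities before averaging is what makes the members comparable.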

TabPFN Usage

TabPFN is different from other methods you might know for tabular classification. Here, we list some tips and tricks that might help you understand how to use it best.

  • Do not preprocess inputs to TabPFN. TabPFN pre-processes inputs internally: it applies a z-score normalization ((x - train_x.mean()) / train_x.std()) per feature, fitted on the training set, and log-scales outliers heuristically. Finally, TabPFN applies a PowerTransform to all features for every second ensemble member. Pre-processing is important to make sure that the real-world dataset lies within the distribution of the synthetic datasets seen during training. So to get the best results, do not apply a PowerTransformation to the inputs yourself.
  • TabPFN expects scalar values only. If your categoricals are already encoded as floats, leave them as they are; if they are not (e.g. str or object), encode them first, e.g. with OrdinalEncoder. TabPFN works best on data that does not contain any categorical or NaN data (see Appendix B.1).
  • TabPFN ensembles multiple input encodings by default: it feeds different index rotations of the features and labels to the model for each ensemble member. You can control the ensembling with TabPFNClassifier(..., N_ensemble_configurations=?).
  • TabPFN does not use any statistics from the test set. That means predicting each test example one-by-one will yield the same result as feeding the whole test set together.
  • TabPFN is differentiable in principle; only the pre-processing is not, as it relies on NumPy.
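To make the normalization and test-set tips concrete, here is a small NumPy sketch of a per-feature z-score fitted on the training set only. It is an illustration, not TabPFN's code (the outlier log-scaling and PowerTransform steps are left out), and, per the first tip, you should not apply it yourself before calling TabPFN. Because the statistics come only from the training set, normalizing test rows one at a time gives the same values as normalizing the whole test set at once.

```python
import numpy as np


def fit_zscore(X_train):
    """Per-feature mean and standard deviation, computed on the training set only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # our assumption: leave constant features unscaled
    return mean, std


def apply_zscore(X, mean, std):
    """(x - train_x.mean()) / train_x.std(), applied feature by feature."""
    return (X - mean) / std


rng = np.random.default_rng(0)
X_train = rng.normal(5.0, 2.0, size=(50, 3))
X_test = rng.normal(5.0, 2.0, size=(10, 3))
mean, std = fit_zscore(X_train)

batch = apply_zscore(X_test, mean, std)  # whole test set at once
one_by_one = np.vstack([apply_zscore(row[None, :], mean, std) for row in X_test])
print(np.allclose(batch, one_by_one))  # True: no test-set statistics are used
```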

Our Paper

Read our paper for more information about the setup (or contact us ☺️). If you use our method, please cite us using

@inproceedings{
  hollmann2023tabpfn,
  title={Tab{PFN}: A Transformer That Solves Small Tabular Classification Problems in a Second},
  author={Noah Hollmann and Samuel M{\"u}ller and Katharina Eggensperger and Frank Hutter},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=cp5PvcI6w8_}
}

License

Copyright 2022 Noah Hollmann, Samuel Müller, Katharina Eggensperger, Frank Hutter

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

tabpfn's People

Contributors

amueller, david-schnurr, eddiebergman, fangwei123456, frank-hutter, liuquangao, noahho, samuelgabriel, tabpfn-anonym

tabpfn's Issues

Data Pre-Checks

Dear authors,

Once again, great contribution to the AutoML area ✅

Brief of the issue

Surprisingly, the following code snippet reveals that sklearn's check_array method is present in your code to verify whether the training X matrix/array is valid. However, it is commented out (i.e., not currently used). I foolishly followed the README even though I knew there would be errors, but nothing told me that my data contained NaN values and therefore could not be processed.

Naive and quick proposal

When an error occurs in the predict_proba method, you could run check_array to point the user to the actual problem. I completely understand that checking the X matrix may be pointless if the data fed to the system is already compute-ready; yet, if there is an issue, which can happen, it should be reported to the user of your system in a clear manner. Finally, because this naive and quick proposal only validates the input after a crash, it would not affect computation time when the data is already suitable, so benchmarks, etc. are unaffected (in theory).

Current code snippet:

    def predict_proba(self, X, normalize_with_test=False):
        # Check is fit had been called
        check_is_fitted(self)

        # Input validation
        # X = check_array(X)

A naive proposal, as I do not yet have the time to try it due to research-related duties:

        try:
            X_full = np.concatenate([self.X_, X], axis=0)
            X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
            y_full = np.concatenate([self.y_, np.zeros_like(X[:, 0])], axis=0)
            y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)
    
            eval_pos = self.X_.shape[0]
    
            prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos,
                                             device=self.device,
                                             style=self.style,
                                             inference_mode=True,
                                             preprocess_transform='none' if self.no_preprocess_mode else 'mix',
                                             normalize_with_test=normalize_with_test,
                                             N_ensemble_configurations=self.N_ensemble_configurations,
                                             softmax_temperature=self.temperature,
                                             combine_preprocessing=self.combine_preprocessing,
                                             multiclass_decoder=self.multiclass_decoder,
                                             feature_shift_decoder=self.feature_shift_decoder,
                                             differentiable_hps_as_style=self.differentiable_hps_as_style
                                             , **get_params_from_config(self.c))
            prediction_, y_ = prediction.squeeze(0), y_full.squeeze(1).long()[eval_pos:]
    
            return prediction_.detach().cpu().numpy()
        
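        except Exception:
            # Re-run sklearn's input validation so the user sees the real cause
            # (e.g. NaN values); if X passes validation, re-raise the original error.
            check_array(X)
            raise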

I hope this helps! To be clear, I knew there would be an error: I carelessly submitted my data without imputing it beforehand, which may be a common occurrence for some users. The idea is to report such problems to users correctly.

Cheers,
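A self-contained sketch of the same "validate only after a failure" idea follows. It uses NumPy instead of sklearn's check_array so that it runs on its own; validate_features and predict_proba_checked are hypothetical names, and the wrapped predict_proba can be any callable.

```python
import numpy as np


def validate_features(X):
    """Hypothetical pre-check that turns bad input into a readable error."""
    X = np.asarray(X, dtype=float)  # raises ValueError for non-numeric strings
    if X.ndim != 2:
        raise ValueError(f"Expected a 2D array, got shape {X.shape}")
    bad = ~np.isfinite(X)
    if bad.any():
        rows, cols = np.nonzero(bad)
        raise ValueError(
            f"X contains {int(bad.sum())} NaN/inf value(s) (first at row {rows[0]}, "
            f"column {cols[0]}); impute them before calling predict_proba."
        )
    return X


def predict_proba_checked(predict_proba, X):
    """Run the model first; only if it fails, validate X to report the real cause."""
    try:
        return predict_proba(X)
    except Exception:
        validate_features(X)  # raises a clearer ValueError if X is the problem
        raise  # otherwise surface the original error unchanged
```

As in the proposal, valid inputs pay no extra validation cost; the check only runs on the error path.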

setup.py dependencies

setup.py seems to install more dependencies than should be necessary for this model to function. Would it make sense to instead have a tabpfn[benchmark] extra dependencies option akin to tabpfn[baselines]?

      install_requires=[
        'gpytorch>=1.5.0',
        'torch>=1.9.0',
        'scikit-learn>=0.24.2',
        'pyyaml>=5.4.1',
        'seaborn>=0.11.2',
        'xgboost>=1.4.0',
        'tqdm>=4.62.1',
        'numpy>=1.21.2',
        'openml>=0.12.2',
        'catboost>=0.26.1',
        'auto-sklearn>=0.14.5',
        'hyperopt>=0.2.5',
        'configspace>=0.4.21',
      ],

Currently it is unclear how to clone TabPFN and run it from a source install without these dependencies.

Perhaps it would be ideal to instead have an entirely separate repository for benchmarking TabPFN so that CatBoost etc. have nothing to do with this repo. This would help a lot in terms of code cleanliness and separation of concerns.
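One possible split, sketched as the two lists a setup.py would pass to setuptools.setup(install_requires=..., extras_require=...). Which packages the inference path really imports is an assumption here (gpytorch is kept in the core list only because the priors module may import it), and the extra name benchmark is taken from the suggestion above.

```python
# Hypothetical dependency split; setup.py would pass these as
# setuptools.setup(install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE).
INSTALL_REQUIRES = [
    'torch>=1.9.0',
    'scikit-learn>=0.24.2',
    'numpy>=1.21.2',
    'pyyaml>=5.4.1',
    'tqdm>=4.62.1',
    'gpytorch>=1.5.0',  # assumption: may be needed at import time by the priors
]

EXTRAS_REQUIRE = {
    # Only needed to reproduce the benchmarks and baselines.
    'benchmark': [
        'seaborn>=0.11.2',
        'xgboost>=1.4.0',
        'openml>=0.12.2',
        'catboost>=0.26.1',
        'auto-sklearn>=0.14.5',
        'hyperopt>=0.2.5',
        'configspace>=0.4.21',
    ],
}
```

With such a split, pip install tabpfn would pull in only the core list, and pip install tabpfn[benchmark] would add the baseline libraries.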

Online/Continuous Learning Support

I am curious whether TabPFN's performance has been tested in an online learning setting, where batch_size = 1. Some papers I checked found that other AutoML frameworks deteriorated in inference performance once concept drift set in. Would TabPFN need significant modifications, or, given its fast training/inference, would a simple retrain be feasible?
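A "simple retrain" could look like the sliding-window sketch below: keep the most recent labelled samples and refit after each new one. The predict_proba excerpt in the Data Pre-Checks issue above concatenates the stored self.X_ with the test data, which suggests fit() mainly stores the training set, so refitting should be cheap; that, the class name, and the window size are assumptions, and this has not been evaluated under concept drift.

```python
from collections import deque


class SlidingWindowLearner:
    """Hypothetical online wrapper: keep the most recent `window` labelled
    samples and refit the wrapped model on them after every new sample."""

    def __init__(self, model, window=500, min_samples=2):
        self.model = model  # any object with fit(X, y) and predict(X)
        self.X = deque(maxlen=window)  # oldest samples drop out first
        self.y = deque(maxlen=window)
        self.min_samples = min_samples
        self.fitted = False

    def predict_one(self, x):
        """Predict a single sample; None until the model has been fitted once."""
        if not self.fitted:
            return None
        return self.model.predict([list(x)])[0]

    def learn_one(self, x, y):
        """Add one labelled sample (batch_size = 1) and refit on the window."""
        self.X.append(list(x))
        self.y.append(y)
        if len(self.y) >= self.min_samples:
            self.model.fit(list(self.X), list(self.y))
            self.fitted = True
```

The sketch reuses one model instance rather than constructing a new classifier per step, since constructing TabPFNClassifier appears to reload the model (see the "verbose level" issue below).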

How to disable overwrite_warning=True without putting it in fit

Is there a way to disable this warning? I want to use TabPFN as a stacking estimator (using TPOT's easy stacking estimator functionality). I've tried a bunch of ways, but I think the only option is to disable this bs in the TabPFN script. Any advice on how to do that?
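One workaround that leaves the TabPFN script untouched is a thin sklearn-compatible wrapper that always forwards fixed keyword arguments to fit(), so tools that only call fit(X, y) still pass them. The sketch assumes that TabPFNClassifier.fit accepts overwrite_warning as a keyword argument (as the issue title implies) and that the classifier can be cloned with sklearn.base.clone; FitParamsClassifier is a hypothetical name.

```python
from sklearn.base import BaseEstimator, ClassifierMixin, clone


class FitParamsClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical wrapper that forwards fixed keyword arguments to the
    wrapped classifier's fit(), e.g. fit_params={'overwrite_warning': True}."""

    def __init__(self, estimator=None, fit_params=None):
        self.estimator = estimator
        self.fit_params = fit_params

    def fit(self, X, y):
        # Work on a fresh copy so the unfitted template stays untouched.
        self.estimator_ = clone(self.estimator)
        self.estimator_.fit(X, y, **(self.fit_params or {}))
        self.classes_ = self.estimator_.classes_  # sklearn classifiers set this in fit
        return self

    def predict(self, X):
        return self.estimator_.predict(X)

    def predict_proba(self, X):
        return self.estimator_.predict_proba(X)
```

Usage would be along the lines of FitParamsClassifier(TabPFNClassifier(device='cpu'), {'overwrite_warning': True}); for a TPOT config dictionary the wrapper would additionally have to live in an importable module.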

more nitpicks on end-to-end

Another nitpick on the end-to-end training and evaluation.
Currently the TabPFNClassifier is hard-coded to read from f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl' but I don't see any code that would write a model to that path.

verbose level

Hi team,

How do I set the verbose level of the classifier? The message below pops up every time I create an instance:

"Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters"
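Until there is a verbosity option, and assuming the message comes from plain print() calls (the "excessive print statements" issue below suggests so), it can be silenced from the caller's side with the standard library's contextlib.redirect_stdout. The helper name quietly is hypothetical; note that this does not catch output written to stderr or directly by C extensions.

```python
import contextlib
import io


def quietly(fn, *args, **kwargs):
    """Call fn while capturing everything it prints to stdout.

    Returns (result, captured_text) so the messages can still be inspected.
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = fn(*args, **kwargs)
    return result, buffer.getvalue()
```

For example, classifier, _ = quietly(TabPFNClassifier, device='cpu') would construct the classifier without the loading messages appearing on the console.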

Requirements

Dear authors,

Wonderful system! I have heard numerous discussions about using transformers for AutoML but seen no concrete systems; finally, here is one. Very interesting. I am fairly new to AutoML as part of my Ph.D. research, and I don't work on transformers, but I'll keep an eye on this repository ✅

Brief of the issue

Surprisingly, your requirements include installing everything; yet, when I installed TabPFN using pip in my conda environment, the system would not work. I had to re-install PyTorch and/or AutoGluon (i.e., using pip) for the setup to work. To be clear, no console log asked me to install them; it was just my instinct, since the system was not working.

To reproduce, I reckon the following would be worthwhile.

  • Set up a brand-new OSX machine (I work on a new OSX machine, but this might also be the case on Windows/Linux);
  • Install Python 3.x and Conda
  • Install TabPFN
  • Run the basic recommended snippet from the README and see whether it does work

This procedure was not functioning on my end; as mentioned, I needed to redownload PyTorch and/or AutoGluon. Perhaps you simply need to resolve this quick issue for further users or indicate in the readme that PyTorch and/or AutoGluon must be downloaded manually.

Hope this helps! Otherwise, great contribution,
Cheers,

batch size?

Hey!
So in PriorFittingCustomPrior.ipynb you first set a batch size of 64, and then overwrite it with 4, but the paper specifies it as 512, right?
I assume overwriting the 64 with 4 was an accident left over from adding the plotting code, but I'm not sure whether I misunderstood the meaning of batch here or whether the notebook simply runs a different config than the one reported in the paper.

collaboration

Hi. I maintain a site where people launch fully autonomous algorithms and they compete. I'm curious about the application of tabpfn to univariate time-series (probably after generating a few features, or as a way of providing a meta-model). It seems that the single pass could be very fast and effective in this setting. LMK if you think unleashing your algo in this fashion is interesting, and we can probably meet halfway between our APIs.
-Peter

Several for-loops in evaluation don't seem to do anything

There's some for-loops in the evaluation that only work if they run exactly once, which seems pretty confusing.
One is in https://github.com/automl/TabPFN/blob/main/tabpfn/RunFullDatasetAnalyses.ipynb, in "eval_methods".
If there's more than one did, it only returns the results of the last iteration. The code always calls it with only one did, so that's fine, but then maybe remove the for-loop?

The other one is here:
https://github.com/automl/TabPFN/blob/main/tabpfn/scripts/baseline_prediction_interface.py

That one has a TQDM attached, but it returns in the first iteration. Again, I think it's only called with lists of length one, but maybe then removing the loop would be better?

Btw, I rewrote eval_methods to take datasets, not "dids" since the dids rely on global state about how things are numbered and add some extra indirection. If you want a PR for the refactor, let me know.
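The pattern described in this issue can be shown with a toy version: a loop that overwrites one result variable keeps only the last dataset, while collecting results per dataset keeps all of them. The function names and the datasets dict are hypothetical and are not the notebook's actual signatures.

```python
def eval_methods_overwriting(datasets, evaluate):
    """Mirrors the reported pattern: only the last iteration's result survives."""
    for name, data in datasets.items():
        result = evaluate(data)
    return result


def eval_methods(datasets, evaluate):
    """Hypothetical refactor: take datasets directly and keep every result."""
    results = {}
    for name, data in datasets.items():
        results[name] = evaluate(data)
    return results
```

With a single dataset both functions agree, which is why the loop looks harmless as long as it is only ever called with one did.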

NameError: name 'Module' is not defined

Hello,
I wanted to test your library on some of my data but I'm having trouble when I try to import the classifier with:

from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

The full Traceback is:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In [1], line 1
----> 1 from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\scripts\transformer_prediction_interface.py:20
     18 from sklearn.utils import column_or_1d
     19 from pathlib import Path
---> 20 from tabpfn.scripts.model_builder import load_model
     21 import os
     22 import pickle

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\scripts\model_builder.py:2
      1 from functools import partial
----> 2 from tabpfn.train import train, Losses
      3 import tabpfn.priors as priors
      4 import tabpfn.encoders as encoders

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\train.py:14
     11 from torch import nn
     13 import tabpfn.utils as utils
---> 14 from tabpfn.transformer import TransformerModel
     15 from tabpfn.utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
     16 import tabpfn.priors as priors

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\transformer.py:9
      6 from torch import Tensor
      7 from torch.nn import Module, TransformerEncoder
----> 9 from tabpfn.layer import TransformerEncoderLayer, _get_activation_fn
     10 from tabpfn.utils import SeqBN, bool_mask_to_att_mask
     14 class TransformerModel(nn.Module):

File ~\OneDrive\Documents\PhD Jesse\TabPFN-main\tabpfn\layer.py:10
      5 from torch.nn.modules.transformer import _get_activation_fn
      7 from torch.utils.checkpoint import checkpoint
---> 10 class TransformerEncoderLayer(Module):
     11     r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
     12     This standard encoder layer is based on the paper "Attention Is All You Need".
     13     Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
   (...)
     36         >>> out = encoder_layer(src)
     37     """
     38     __constants__ = ['batch_first']

NameError: name 'Module' is not defined

If it helps: the list of installed pip packages is:

Package                            Version
---------------------------------- -----------------
alabaster                          0.7.12
anaconda-client                    1.7.2
anaconda-navigator                 2.0.3
anaconda-project                   0.9.1
anyio                              2.2.0
appdirs                            1.4.4
argh                               0.26.2
argon2-cffi                        20.1.0
asn1crypto                         1.4.0
astroid                            2.5
astropy                            4.2.1
async-generator                    1.10
atomicwrites                       1.4.0
attrs                              20.3.0
autopep8                           1.5.6
Babel                              2.9.0
backcall                           0.2.0
backports.functools-lru-cache      1.6.4
backports.shutil-get-terminal-size 1.0.0
backports.tempfile                 1.0
backports.weakref                  1.0.post1
bcrypt                             3.2.0
beautifulsoup4                     4.9.3
bitarray                           1.9.2
bkcharts                           0.2
black                              19.10b0
bleach                             3.3.0
bokeh                              2.3.2
boto                               2.49.0
Bottleneck                         1.3.2
brotlipy                           0.7.0
certifi                            2020.12.5
cffi                               1.14.5
chardet                            4.0.0
click                              7.1.2
cloudpickle                        1.6.0
clyent                             1.2.2
colorama                           0.4.4
comtypes                           1.1.9
conda                              4.13.0
conda-build                        3.21.4
conda-content-trust                0+unknown
conda-package-handling             1.8.1
conda-repo-cli                     1.0.4
conda-token                        0.3.0
conda-verify                       3.4.2
ConfigSpace                        0.6.0
contextlib2                        0.6.0.post1
cryptography                       3.4.7
cycler                             0.10.0
Cython                             0.29.23
cytoolz                            0.11.0
dask                               2021.4.0
decorator                          5.0.6
defusedxml                         0.7.1
diff-match-patch                   20200713
distributed                        2021.4.0
docutils                           0.17
entrypoints                        0.3
et-xmlfile                         1.0.1
fastcache                          1.1.0
filelock                           3.0.12
flake8                             3.9.0
Flask                              1.1.2
fsspec                             0.9.0
future                             0.18.2
gevent                             21.1.2
glob2                              0.7
gpytorch                           1.9.0
greenlet                           1.0.0
h5py                               2.10.0
HeapDict                           1.0.1
html5lib                           1.1
hyperopt                           0.2.7
idna                               2.10
imagecodecs                        2021.3.31
imageio                            2.9.0
imagesize                          1.2.0
importlib-metadata                 3.10.0
iniconfig                          1.1.1
intervaltree                       3.1.0
ipykernel                          5.3.4
ipython                            7.22.0
ipython-genutils                   0.2.0
ipywidgets                         7.6.3
isort                              5.8.0
itsdangerous                       1.1.0
jdcal                              1.4.1
jedi                               0.17.2
Jinja2                             2.11.3
joblib                             1.0.1
json5                              0.9.5
jsonschema                         3.2.0
jupyter                            1.0.0
jupyter-client                     6.1.12
jupyter-console                    6.4.0
jupyter-core                       4.7.1
jupyter-packaging                  0.7.12
jupyter-server                     1.4.1
jupyterlab                         3.0.14
jupyterlab-pygments                0.1.2
jupyterlab-server                  2.4.0
jupyterlab-widgets                 1.0.0
keyring                            22.3.0
kiwisolver                         1.3.1
lazy-object-proxy                  1.6.0
liac-arff                          2.5.0
libarchive-c                       2.9
linear-operator                    0.1.1
llvmlite                           0.36.0
locket                             0.2.1
lxml                               4.6.3
MarkupSafe                         1.1.1
matplotlib                         3.3.4
mccabe                             0.6.1
menuinst                           1.4.16
minio                              7.1.12
mistune                            0.8.4
mkl-fft                            1.3.0
mkl-random                         1.2.1
mkl-service                        2.3.0
mock                               4.0.3
more-itertools                     8.7.0
mpmath                             1.2.1
msgpack                            1.0.2
multipledispatch                   0.6.0
mypy-extensions                    0.4.3
navigator-updater                  0.2.1
nbclassic                          0.2.6
nbclient                           0.5.3
nbconvert                          6.0.7
nbformat                           5.1.3
nest-asyncio                       1.5.1
networkx                           2.5
nltk                               3.6.1
nose                               1.3.7
notebook                           6.3.0
numba                              0.53.1
numexpr                            2.7.3
numpy                              1.22.4
numpydoc                           1.1.0
olefile                            0.46
openml                             0.12.2
openpyxl                           3.0.7
packaging                          20.9
pandas                             1.2.4
pandocfilters                      1.4.3
paramiko                           2.7.2
parso                              0.7.0
partd                              1.2.0
path                               15.1.2
pathlib2                           2.3.5
pathspec                           0.7.0
patsy                              0.5.1
pep8                               1.7.1
pexpect                            4.8.0
pickleshare                        0.7.5
Pillow                             8.2.0
pip                                21.0.1
pkginfo                            1.7.0
pluggy                             0.13.1
ply                                3.11
prometheus-client                  0.10.1
prompt-toolkit                     3.0.17
psutil                             5.8.0
ptyprocess                         0.7.0
py                                 1.10.0
py4j                               0.10.9.7
pyarrow                            10.0.0
pycodestyle                        2.6.0
pycosat                            0.6.3
pycparser                          2.20
pycurl                             7.43.0.6
pydocstyle                         6.0.0
pyerfa                             1.7.3
pyflakes                           2.2.0
Pygments                           2.8.1
pylint                             2.7.4
pyls-black                         0.4.6
pyls-spyder                        0.3.2
PyNaCl                             1.4.0
pyodbc                             4.0.0-unsupported
pyOpenSSL                          20.0.1
pyparsing                          2.4.7
pyreadline                         2.1
pyrsistent                         0.17.3
PySocks                            1.7.1
pytest                             6.2.3
python-dateutil                    2.8.1
python-jsonrpc-server              0.4.0
python-language-server             0.36.2
pytz                               2021.1
PyWavelets                         1.1.1
pywin32                            227
pywin32-ctypes                     0.2.0
pywinpty                           0.5.7
PyYAML                             5.4.1
pyzmq                              20.0.0
QDarkStyle                         2.8.1
QtAwesome                          1.0.2
qtconsole                          5.0.3
QtPy                               1.9.0
regex                              2021.4.4
requests                           2.25.1
rope                               0.18.0
Rtree                              0.9.7
ruamel-yaml-conda                  0.15.100
scikit-image                       0.18.1
scikit-learn                       1.1.3
scipy                              1.6.2
seaborn                            0.12.1
Send2Trash                         1.5.0
setuptools                         58.2.0
simplegeneric                      0.8.1
singledispatch                     0.0.0
sip                                4.19.13
six                                1.15.0
sniffio                            1.2.0
snowballstemmer                    2.1.0
sortedcollections                  2.1.0
sortedcontainers                   2.3.0
soupsieve                          2.2.1
Sphinx                             4.0.1
sphinxcontrib-applehelp            1.0.2
sphinxcontrib-devhelp              1.0.2
sphinxcontrib-htmlhelp             1.0.3
sphinxcontrib-jsmath               1.0.1
sphinxcontrib-qthelp               1.0.3
sphinxcontrib-serializinghtml      1.1.4
sphinxcontrib-websupport           1.2.4
spyder                             4.2.5
spyder-kernels                     1.10.2
SQLAlchemy                         1.4.7
statsmodels                        0.12.2
sympy                              1.8
tables                             3.6.1
tabpfn                             0.1.5
tblib                              1.7.0
terminado                          0.9.4
testpath                           0.4.4
textdistance                       4.2.1
threadpoolctl                      2.1.0
three-merge                        0.1.1
tifffile                           2021.4.8
toml                               0.10.2
toolz                              0.11.1
torch                              1.13.0
tornado                            6.1
tqdm                               4.64.1
traitlets                          5.0.5
twine                              3.4.2
typed-ast                          1.4.2
typing-extensions                  3.7.4.3
ujson                              4.0.2
unicodecsv                         0.14.1
urllib3                            1.26.4
watchdog                           1.0.2
wcwidth                            0.2.5
webencodings                       0.5.1
Werkzeug                           1.0.1
wheel                              0.37.0
widgetsnbextension                 3.5.1
win-inet-pton                      1.1.0
win-unicode-console                0.5
wincertstore                       0.2
wrapt                              1.12.1
xlrd                               2.0.1
XlsxWriter                         1.3.8
xlwings                            0.23.0
xlwt                               1.3.0
xmltodict                          0.12.0
yapf                               0.31.0
zict                               2.0.0
zipp                               3.4.1
zope.event                         4.5.0
zope.interface                     5.3.0

Maybe I'm missing something simple, but I can't figure out why it's throwing this error...

Problems saving a trained model

Hi, I have been playing with your model, with frankly impressive results. But I struggle when trying to save a trained model. I tried to use pickle, but I got an error; it seems that one component is not serializable.

Is there another way to save a trained model?

Thanks.

By the way, this is the exception I get:

AttributeError: Can't pickle local object 'load_model.<locals>.<lambda>'
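This class of error is generic to pickle: functions are stored by their qualified name, and a lambda created inside another function (here inside load_model) has a name like 'load_model.<locals>.<lambda>' that cannot be looked up again on loading. The sketch below reproduces the error with a hypothetical factory and shows one standard fix, functools.partial over a module-level function. It only illustrates the cause; it is not a patch for TabPFN itself.

```python
import functools
import operator
import pickle


def make_scaler_lambda(scale):
    # A lambda defined inside a function: pickle cannot find it by name.
    return lambda x: x * scale


def make_scaler_picklable(scale):
    # partial over a module-level function (operator.mul) pickles by reference.
    return functools.partial(operator.mul, scale)


def can_pickle(obj):
    """True if pickle.dumps succeeds; the exact exception type varies by Python version."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
```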

SCM vs. BNN

Thank you very much for this paper and sharing the code!

A small question please:
In the paper you show the results of the SCM prior and the BNN prior.
Which configurations in the code correspond to these results?
(e.g., I assume the 'mlp.py' prior gives the SCM; is this right?
Which hyperparameters can I use to replicate the SCM-only results? Which file gives the BNN prior?)

Thanks once again!
Niv.

GPU Warning on breast cancer example from git

When I run the code on the home page, https://github.com/automl/TabPFN, I get this warning:

/opt/conda/lib/python3.7/site-packages/torch/autocast_mode.py:162: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')

This happens even though 'cpu' is passed as the device when instantiating the classifier. The fit/predict also seemed rather slow.

Regression?

Thank you very much for your work!
Also, can you use TabPFN for regression tasks? And if not, do you intend to provide it with such a capability soon?

Thank you!

Licensing

Hey folks!
This is amazing work!

There's currently no license in this repo. Is that on purpose? If there's a chance to use Apache, that would be great!

How to include TabPFN in TPOT's custom dictionary

I am trying to include TabPFN in TPOT's custom dictionary like so:

'tabpfn.TabPFNClassifier':{
    'device':['cpu'], 
    'N_ensemble_configurations': list(range(2, 39))
},

But I just get this warning:

Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters

How can I fix this?

Time series regression?

Hi, I am also very much looking forward to using TabPFN for regression tasks.
However, I am also wondering if TabPFN can be modified to do time series prediction?
My goal is to do time series regression in finance.
Thanks!

AttributeError: 'TabPFNClassifier' object has no attribute 'no_grad'

Hi!

Thanks for a fantastic tool.

I fitted some models a while ago, and now when trying to deploy them on another computer I get this error in the predict_proba method:

AttributeError: 'TabPFNClassifier' object has no attribute 'no_grad'

May I know when the self.no_grad attribute was introduced? I would like to fall back to the old tabpfn version in this particular case.

What's the checkpoint file in models_diff? And why rename epoch42 to epoch100 in the loop?

Thanks for your excellent work. I'm confused about the questions above and about the code snippet below.

    def get_file(e):
        """
        Returns the different paths of model_file, model_path and results_file
        """
        model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
        model_path = os.path.join(base_path, model_file)
        # print('Evaluate ', model_path)
        results_file = os.path.join(base_path,
                                    f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
        return model_file, model_path, results_file

    def check_file(e):
        model_file, model_path, results_file = get_file(e)
        if not Path(model_path).is_file():  # or Path(results_file).is_file():
            print('We have to download the TabPFN, as there is no checkpoint at ', model_path)
            print('It has about 100MB, so this might take a moment.')
            import requests
            url = 'https://github.com/automl/TabPFN/raw/main/tabpfn/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt'
            r = requests.get(url, allow_redirects=True)
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            open(model_path, 'wb').write(r.content)
        return model_file, model_path, results_file

    model_file = None
    if e == -1:
        for e_ in range(100, -1, -1):
            model_file_, model_path_, results_file_ = check_file(e_)
            if model_file_ is not None:
                e = e_
                model_file, model_path, results_file = model_file_, model_path_, results_file_
                break
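Here is one reading of the snippet as quoted, offered as an interpretation rather than a confirmed answer. `check_file` always returns a non-None tuple, so the `e == -1` loop stops at the first epoch it tries, which is 100. If no file exists at that path, the epoch-42 checkpoint is downloaded and written under the epoch_100 file name, which would explain the apparent rename. Below is a stripped-down simulation that uses a dict to stand in for the file system. `pick_checkpoint` is hypothetical, `add_name` is assumed to be empty, and `i` is assumed to be 0.

```python
def pick_checkpoint(files, download):
    """Mimic the e == -1 branch: try epochs 100, 99, ..., 0 and keep the first check_file result.

    files: dict standing in for the file system (path -> contents).
    download: callable returning the downloaded bytes (always the epoch-42 checkpoint).
    """
    def check_file(e):
        path = f"models_diff/prior_diff_real_checkpoint_n_0_epoch_{e}.cpkt"
        if path not in files:
            # As in the original: the download is saved under the *requested* epoch's name.
            files[path] = download()
        return path  # never None, just like the original tuple

    for e in range(100, -1, -1):
        path = check_file(e)
        if path is not None:  # always true, so the loop exits at e = 100
            return e, path
    return None, None
```

Running it with an empty `files` dict selects epoch 100 and stores the epoch-42 bytes under the epoch_100 name.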

excessive print statements

The implementation contains various print statements, some commented out and some not.
It would be nice to have a verbose setting that controls the level of output.
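One way to implement such a setting is to route the messages through Python's `logging` module and map an integer `verbose` flag to a log level. Below is a minimal sketch. The logger name and `load_model` are hypothetical stand-ins, not TabPFN's actual code.

```python
import logging

# Hypothetical logger name; a library would typically use logging.getLogger(__name__).
logger = logging.getLogger("tabpfn_sketch")


def set_verbosity(verbose):
    """Map an integer verbose flag to a logging level: 0 -> WARNING, 1 -> INFO, 2+ -> DEBUG."""
    if verbose <= 0:
        level = logging.WARNING
    elif verbose == 1:
        level = logging.INFO
    else:
        level = logging.DEBUG
    logger.setLevel(level)
    return level


def load_model(n_params_millions, verbose=0):
    """Stand-in for a model loader that logs instead of printing."""
    set_verbosity(verbose)
    logger.info("Loading model that can be used for inference only")
    logger.info("Using a Transformer with %.2f M parameters", n_params_millions)
    logger.debug("Extra diagnostics only shown at verbose >= 2")
```

The application still needs a handler to see INFO output, for example `logging.basicConfig()`. With the default `verbose=0`, nothing below WARNING is emitted.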

silly question: learning curves

This is probably a stupid question, but I was wondering if/how you stored learning curves during training. I can't find anything in the code that records the losses. Maybe there's some built-in torch mechanism that I'm unaware of?

Obviously I can build a list myself and serialize that but I figure you must have done it some way as well...
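PyTorch does not record losses on its own. Besides a plain list, `torch.utils.tensorboard.SummaryWriter.add_scalar` is one built-in option. The post doesn't say how the authors tracked losses. Below is a framework-free sketch of the plain-list approach. It uses a toy one-parameter regression so it runs anywhere; in a PyTorch loop you would append `loss.item()` instead.

```python
import json


def train_with_history(xs, ys, lr=0.01, epochs=50):
    """Fit y ~ w * x with gradient descent on mean squared error, logging the loss every epoch."""
    w = 0.0
    history = []
    n = len(xs)
    for epoch in range(epochs):
        preds = [w * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        w -= lr * grad
        history.append({"epoch": epoch, "loss": loss})  # loss before this epoch's update
    return w, history


def save_history(history, path):
    """Serialize the learning curve so it can be plotted later."""
    with open(path, "w") as f:
        json.dump(history, f)
```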

Hyperparameters for XGBoost models in Table 2

The HPO functionality is included in the codebase but I am unable to find the final hyperparameters used for each dataset in Table 2.

These would be very valuable - specifically the XGBoost parameters.

Are you able to provide these?

Feature importance

Hi, amazing job! I was testing your model on my data and was wondering if there is any way to get feature importance scores (as with other ML models). Is it possible to get them from the model as it is, or do you plan to add a feature importance attribute in the future?

Thank you!
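A model-agnostic option is permutation importance. Shuffle one feature column at a time and measure how much the score drops, without refitting the model. Below is a minimal pure-Python sketch that takes any `predict` callable. `permutation_importance_sketch` is hypothetical.

```python
import random


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def permutation_importance_sketch(predict, X, y, n_repeats=10, seed=0):
    """Mean drop in accuracy when each feature column is shuffled (the model is not refit).

    predict: callable mapping a list of rows (lists) to a list of labels.
    """
    rng = random.Random(seed)
    baseline = accuracy(y, predict(X))
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - accuracy(y, predict(X_perm)))
        importances.append(sum(drops) / n_repeats)
    return importances
```

scikit-learn's `sklearn.inspection.permutation_importance(estimator, X, y, n_repeats=..., random_state=...)` does the same for fitted estimators. Since TabPFNClassifier follows the scikit-learn interface, it should work there too, though each repeat reruns prediction, which can be slow.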

train_mixed_precision is not passed in model_builder?

It looks to me like train_mixed_precision is not passed to train in model_builder.get_model so no matter what the config_sample says, there's no mixed precision from what I can tell. Doing mixed precision training seems to give a speedup of about 30%.
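The post implies the fix is to forward the config value into the training call. The exact signatures of `model_builder.get_model` and `train` aren't shown here, so the sketch below uses hypothetical stand-ins (`train_stub`, `get_model_stub`) to show only the forwarding. In PyTorch, the flag would typically gate `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` inside the training loop.

```python
def train_stub(epochs=1, train_mixed_precision=False):
    """Hypothetical stand-in for the training entry point; reports what it received."""
    return {"epochs": epochs, "train_mixed_precision": train_mixed_precision}


def get_model_stub(config):
    """Hypothetical model builder that forwards the flag instead of dropping it."""
    return train_stub(
        epochs=config.get("epochs", 1),
        # Without this argument, train_stub silently falls back to its default (False).
        train_mixed_precision=config.get("train_mixed_precision", False),
    )
```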

Random patches for naive scaling?

Hi,
Thanks for this very interesting and unorthodox approach.
I wonder if you have tried to scale it up with random patches:
https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.BaggingClassifier.html
Basically, use TabPFN as the base estimator with max_samples=1000 and max_features=100, and tune n_estimators with e.g. early stopping. Note that n_jobs needs to be 1, or Google Colab crashes.
I know it is a naive and mundane approach, but it could be good enough or competitive in some cases.
Btw., even with data within the constraints, subsampling the features could help a little with uninformative ones.
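In random patches, each ensemble member sees at most 1000 rows and 100 features (TabPFN's limits), and the members' predictions are combined. In scikit-learn, this is `BaggingClassifier(estimator=TabPFNClassifier(device='cpu'), max_samples=1000, max_features=100, n_estimators=..., n_jobs=1)`. In scikit-learn releases before 1.2, the argument is named `base_estimator`. The pure-Python sketch below shows only the sampling and combining steps, with hypothetical helper names. It draws rows without replacement (`bootstrap=False` in BaggingClassifier terms) and uses hard majority voting. BaggingClassifier instead averages predicted probabilities when the base estimator provides them.

```python
import random
from collections import Counter


def random_patches(n_rows, n_features, max_samples, max_features, n_estimators, seed=0):
    """Draw (row indices, feature indices) for each ensemble member, without replacement."""
    rng = random.Random(seed)
    patches = []
    for _ in range(n_estimators):
        rows = sorted(rng.sample(range(n_rows), min(max_samples, n_rows)))
        features = sorted(rng.sample(range(n_features), min(max_features, n_features)))
        patches.append((rows, features))
    return patches


def majority_vote(member_predictions):
    """Combine per-member label lists by taking the most common label per test row."""
    return [Counter(votes).most_common(1)[0][0] for votes in zip(*member_predictions)]
```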

How to suppress message?

When I train a TabPFN model, the notebook shows the following message:

"Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters"

What can I do to suppress display of this message?
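If these lines are written with `print` to standard output (an assumption about how tabpfn emits them), you can capture them with `contextlib.redirect_stdout` around the call that creates the classifier. Below is a minimal sketch with a hypothetical helper `call_quietly`.

```python
import contextlib
import io


def call_quietly(fn, *args, **kwargs):
    """Call fn with sys.stdout redirected to a buffer; return (result, captured_text)."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        result = fn(*args, **kwargs)
    return result, buffer.getvalue()
```

Usage would be `classifier, _ = call_quietly(TabPFNClassifier, device='cpu', N_ensemble_configurations=32)`. This only catches Python-level writes to `sys.stdout`. Warnings and output sent to stderr are not affected.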

F1 score is 0

I trained on a classification dataset for a financial task, and all predictions were class 1, with an F1 score of 0. For reference, XGBoost yielded an accuracy of approximately 0.55 and an F1 score of 0.33. What could be the reason for this?
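An F1 of 0 with every prediction in one class is what you'd expect when the model predicts a constant label and F1 is computed for a different class. With no true positives for that class, precision and recall are both 0. Which class counts as "positive" depends on the labels and on the metric call; scikit-learn's `f1_score` uses `pos_label=1` by default. The post doesn't show either, so this is a hypothesis. The small worked check below uses a hypothetical helper `f1_for_class`.

```python
def f1_for_class(y_true, y_pred, positive):
    """F1 score treating `positive` as the positive class."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0  # precision or recall is 0 (or undefined), so F1 is taken as 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
```

For `y_true = [0, 1, 1, 0, 1, 0]` and a constant prediction of 0, F1 is 0.0 for class 1 but 2/3 for class 0. On a hard, near-balanced task (XGBoost reaching about 0.55 accuracy), a model collapsing to one class is plausible. Inspecting `predict_proba` rather than only the hard labels may show whether the probabilities carry any signal.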

Can we use TabPFN for regression tasks?

Hi TabPFN developer,

Thanks for this amazing algorithm! The algorithm works well for classification tasks.

But I am wondering how to modify the script to perform regression tasks.

I assume we only need to change the output layer and loss function. Could you please provide some instructions?

Many thanks in advance!
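This TabPFN version only ships a classifier, so a common workaround (not an official method) is to discretize the target into bins and classify the bin. The predicted class probabilities are then converted back to a number via the expected bin center. This is similar in spirit to the discretized output distribution the original PFN work uses for regression, but it is only a sketch. The paper's limit of 10 classes caps the number of bins. The helper names below are hypothetical.

```python
def make_bins(y, n_bins):
    """Equal-width bins over the training targets; returns (low, width, centers)."""
    low, high = min(y), max(y)
    width = (high - low) / n_bins or 1.0  # guard against a constant target
    centers = [low + width * (k + 0.5) for k in range(n_bins)]
    return low, width, centers


def to_bin(value, low, width, n_bins):
    """Class label (bin index) for a target value, clamped to [0, n_bins - 1]."""
    k = int((value - low) / width)
    return min(max(k, 0), n_bins - 1)


def expected_value(class_probs, centers):
    """Point prediction: probability-weighted average of bin centers."""
    return sum(p * c for p, c in zip(class_probs, centers))
```

If some bin has no training samples, `predict_proba` will have no column for it. In that case the probability columns must be mapped back to bin indices before calling `expected_value`.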

Difficulties with reproducing the training

Hey 👋

Thanks for the awesome project! I'm very excited with the universal tabular model idea and your results are quite impressive!

I was trying to reproduce the training but ran into some trouble. I guess that the train.py file is not the main entry point for training the TabPFN transformer. I found some configs in the PriorFittingCustomPrior.ipynb notebook, along with the get_model function, which appears to train the model, but there are still some things I couldn't understand.

Could you please provide some details about the config which was used to produce the checkpoint?

Could you also provide additional details on the priors:

  • What is prior_bag (I guess it sequentially samples or something?)?
  • What is fast_gp prior and how it is used in the context of the tabular prior (there are no mentions of the Gaussian Processes in the TabPFN paper, only in the original PFN paper)?
  • Am I understanding correctly that the DifferentiableHyperparameter code is used only to sample dataset parameters from ranges set in the configs (and described in table 5 of the appendix)?
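The sketch below illustrates the reading in the last two bullets. It is a generic sketch, not TabPFN's code. Each synthetic dataset first picks a prior family from a weighted "bag", then draws its hyperparameters from configured ranges, some on a log scale. The prior names and ranges are placeholders, not the values from the paper's appendix.

```python
import math
import random


def sample_prior(prior_names, weights, rng):
    """Pick one prior family per dataset according to the weights (a 'bag' of priors)."""
    return rng.choices(prior_names, weights=weights, k=1)[0]


def sample_hyperparameters(ranges, rng):
    """Draw one value per hyperparameter; ranges maps name -> (low, high, log_scale)."""
    sample = {}
    for name, (low, high, log_scale) in ranges.items():
        if log_scale:
            sample[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        else:
            sample[name] = rng.uniform(low, high)
    return sample
```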

Synthetic Data Generator Issue

I'm trying to play around with the synthetic data generator, but I'm running into an issue. I tried to run train.py but it gives me the following error:

Traceback (most recent call last):
  File "train.py", line 382, in <module>
    y_encoder_generator=y_encoder_generator, pos_encoder_generator=pos_encoder_generator, **args.__dict__)
  File "train.py", line 216, in train
    train_epoch()
  File "train.py", line 135, in train_epoch
    for batch, (data, targets, single_eval_pos) in enumerate(dl):
  File "TabPFN/tabpfn/priors/utils.py", line 46, in <genexpr>
    return iter(self.gbm(**self.get_batch_kwargs, epoch=self.epoch_count - 1, model=self.model) for _ in range(self.num_steps))
  File "/TabPFN/tabpfn/priors/utils.py", line 33, in gbm
    batch = get_batch_method_(*args, **kwargs)
  File "anaconda3/envs/tabpfn/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
TypeError: get_batch() missing 1 required positional argument: 'get_batch'

When I tried to fix this by passing in the get_batch method, I got a recursion error saying the program reached the maximum recursion depth. Can anyone help me here?
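The traceback suggests the configured `get_batch` is a wrapper prior that expects the underlying prior's sampler as its `get_batch` argument. If so, the fix would be to bind a base prior, not the wrapper itself. Binding the wrapper to itself would plausibly explain the recursion. Below is a minimal sketch of that composition pattern with hypothetical functions, using `functools.partial` to bind the inner sampler.

```python
from functools import partial


def base_get_batch(batch_size, seq_len):
    """Hypothetical base prior: returns a description of a sampled batch."""
    return {"batch_size": batch_size, "seq_len": seq_len, "source": "base"}


def wrapper_get_batch(batch_size, seq_len, get_batch):
    """Hypothetical wrapper prior: samples via the inner get_batch, then post-processes."""
    batch = get_batch(batch_size, seq_len)
    batch["wrapped"] = True
    return batch


# Correct composition: the wrapper delegates to a *different* (base) sampler.
get_batch = partial(wrapper_get_batch, get_batch=base_get_batch)
```

Calling `wrapper_get_batch(8, 100)` without the inner sampler raises the same kind of `TypeError` as in the traceback.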

Version issues for training

I've tried running the PriorFittingCustomPrior.ipynb and run into some difficulties.
It seems lightgbm is imported, but it's not listed in pyproject.toml. Also, seaborn 0.12.2 raises an error when plotting:

ValueError: The following variable cannot be assigned with wide-form data: `hue`

It would be awesome to get a conda environment with a working config. I also didn't notice the Python 3.7 requirement at first, since it's only mentioned in requirements.txt.

running on 'glass' gives index error in cross_entropy

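from tabpfn import TabPFNClassifier
from tabpfn.scripts import tabular_metrics

# valid_datasets: list of (name, X, y, ...) entries loaded earlier in the notebook (not shown)
ds = valid_datasets[53]  # glass, id 41
print(ds[0], ds[1].shape)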
xs, ys = ds[1].clone(), ds[2].clone()
eval_position = xs.shape[0] // 2
train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]
test_xs, test_ys = xs[eval_position:], ys[eval_position:]
classifier = TabPFNClassifier()
print(classifier)
classifier.fit(train_xs, train_ys)
prediction_ = classifier.predict_proba(test_xs)
roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)
print('AUC', float(roc), 'Cross Entropy', float(ce))

There seem to be 7 classes, but prediction_ only has 6 columns.
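A likely cause, not verified against the data: `train_xs` is simply the first half of the rows. If the rows are ordered by class, one class never appears in training, and `predict_proba` then returns one column per class seen during `fit`. Shuffling or a stratified split avoids this. When a class is genuinely absent, the probabilities can be padded so their columns line up with all labels. Below is a minimal sketch with hypothetical helpers. It assumes you know which classes the classifier saw, for example `np.unique(train_ys)`.

```python
def missing_classes(train_labels, all_labels):
    """Labels present in the full data but absent from the training split."""
    return sorted(set(all_labels) - set(train_labels))


def align_proba(proba_rows, seen_classes, all_classes):
    """Re-index probability rows (columns = seen_classes) to columns = all_classes, padding with 0.0."""
    column = {c: i for i, c in enumerate(seen_classes)}
    return [[row[column[c]] if c in column else 0.0 for c in all_classes] for row in proba_rows]
```

Padding makes the shapes match, but a zero probability for a true class makes cross-entropy infinite. Fixing the split is the better option.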

ValueError: ('The number of features for this classifier is restricted to ', 100)

I am trying to build a text classifier with TabPFN using an sklearn Pipeline.
The data is prepared like this:

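import time

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from tabpfn import TabPFNClassifier

# data: a pandas DataFrame with 'text' and 'category' columns
X = data[['text']].values
y = data[['category']].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf_tap = Pipeline(steps=[
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('classifier', TabPFNClassifier(device='cpu'))])

start_tap = time.time()
clf_tap.fit(X_train.ravel(), y_train.ravel())
print("--- %s seconds took with tap classifier to train ---" % (time.time() - start_tap))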

The error:

ValueError                                Traceback (most recent call last)
<ipython-input-12-b7fd890448d4> in <module>
     34 
     35 start_tap = time.time ()
---> 36 clf_tap.fit (X_train.ravel(), y_train.ravel())
     37 print ("--- %s seconds took with tap classifier to train ---" % (time.time () - start_tap))

/usr/local/lib/python3.7/dist-packages/tabpfn/scripts/transformer_prediction_interface.py in fit(self, X, y)
    168 
    169         if X.shape[1] > self.max_num_features:
--> 170             raise ValueError("The number of features for this classifier is restricted to ", self.max_num_features)
    171         if len(np.unique(y)) > self.max_num_classes:
    172             raise ValueError("The number of classes for this classifier is restricted to ", self.max_num_classes)

ValueError: ('The number of features for this classifier is restricted to ', 100)
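CountVectorizer builds one column per vocabulary term, so any realistic corpus exceeds TabPFN's 100-feature limit. One way to stay within it is to cap the vocabulary with `max_features`. Reducing to 100 components with `TruncatedSVD` would be an alternative. The sketch below assumes TabPFNClassifier may need dense input, hence the `to_dense` step, which may be unnecessary. A `DummyClassifier` stands in for `TabPFNClassifier(device='cpu')` so the sketch runs without TabPFN installed.

```python
from sklearn.dummy import DummyClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

MAX_FEATURES = 100  # limit reported by the ValueError above


def to_dense(X):
    """Convert a SciPy sparse matrix to a dense array (assumes the classifier wants dense input)."""
    return X.toarray() if hasattr(X, "toarray") else X


def build_text_pipeline(classifier, max_features=MAX_FEATURES):
    return Pipeline(steps=[
        # TfidfVectorizer = CountVectorizer + TfidfTransformer; max_features keeps the most frequent terms.
        ("tfidf", TfidfVectorizer(max_features=max_features)),
        ("dense", FunctionTransformer(to_dense)),
        ("classifier", classifier),
    ])


# Stand-in classifier; swap in TabPFNClassifier(device='cpu') in practice.
pipeline = build_text_pipeline(DummyClassifier(strategy="most_frequent"))
```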

Takes forever to load pretrained checkpoint on MacBook M1

Currently I am using Python 3.9 on a MacBook M1. When I initialize TabPFN, it downloads the pretrained checkpoint but then takes forever to load the model (even with a small N_ensemble_configurations). Has anyone had the same issue and fixed it? It doesn't output anything, so I am not sure where to look.
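Since nothing is printed, the first step is finding where the process is stuck. Python's standard-library `faulthandler` can dump the stack of every thread after a timeout without stopping the program. Below is a minimal sketch with a hypothetical helper `run_with_hang_report`. It is best run from a terminal script, because in notebooks the process's original stderr may not appear in the cell output.

```python
import faulthandler
import sys


def run_with_hang_report(fn, *args, timeout_s=120, **kwargs):
    """Run fn; while it is still running after timeout_s seconds, dump every
    thread's Python stack to the process's original stderr (repeating every timeout_s)."""
    # faulthandler needs a real file descriptor; sys.__stderr__ is the original stderr.
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=sys.__stderr__)
    try:
        return fn(*args, **kwargs)
    finally:
        faulthandler.cancel_dump_traceback_later()
```

Usage would be `classifier = run_with_hang_report(TabPFNClassifier, device='cpu', N_ensemble_configurations=4, timeout_s=60)`. The dumped stack shows which call (download, unpickling, or model construction) is blocking.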
