rtfm

rtfm is a Python library for research on tabular foundation models (RTFM).

rtfm is the library used to train TabuLa-8B, a state-of-the-art model for zero- and few-shot tabular data prediction described in our paper "Large Scale Transfer Learning for Tabular Data via Language Modeling".

[Figure: few-shot results curve]

You can also use rtfm to train your own tabular language models.

rtfm has been used to train 7B- and 8B-parameter Llama 2 and Llama 3 language models, and supports efficient training techniques such as fully sharded data parallelism (FSDP), multinode training, and bf16 mixed-precision training. In the future, we plan to support additional base language models and larger scales; support for larger Llama models already exists but should be considered experimental. We do not currently support other (non-Llama) language models.
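For instance, a single-node multi-GPU FSDP run might be launched with torchrun. This is a minimal sketch that assumes rtfm.finetune reads the launcher's rank environment variables; the training flags stand in for the real arguments shown in the Quickstart below:

torchrun --nnodes 1 --nproc_per_node 8 -m rtfm.finetune \
  --model_name "meta-llama/Meta-Llama-3-8B" \
  ...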

Environment setup

We recommend use of the provided conda environment. You can set it up with:

conda env create -f environment.yml
pip install --no-deps git+https://github.com/mlfoundations/tableshift.git

Quickstart - Inference

If you want to explore the model interactively or try it on your own unlabeled data, the best way to do so is the inference.ipynb notebook in notebooks. This notebook shows how to create simple DataFrames and use them for inference.

The notebook above is our recommended default for users interested in trying out TabuLa-8B. For more fine-grained control over your inference (e.g. changing the system prompt used at inference time), you can use inference_utils.infer_on_example().
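As a rough sketch, such usage might look like the following. The infer_on_example() signature below is an assumption (check inference_utils for the real argument names); only the model and tokenizer loading uses standard Hugging Face APIs:

import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from rtfm import inference_utils

model = AutoModelForCausalLM.from_pretrained("mlfoundations/tabula-8b")
tokenizer = AutoTokenizer.from_pretrained("mlfoundations/tabula-8b")

# A few labeled rows to serve as in-context examples, plus one unlabeled row to predict.
labeled_examples = pd.DataFrame(
    {"age": [39, 50, 38],
     "education": ["Bachelors", "HS-grad", "Masters"],
     "income_over_50k": ["no", "yes", "yes"]}
)
target_example = pd.DataFrame({"age": [28], "education": ["Bachelors"]})

# Hypothetical call; the argument names here are illustrative, not the documented API.
prediction = inference_utils.infer_on_example(
    model=model,
    tokenizer=tokenizer,
    labeled_examples=labeled_examples,
    target_example=target_example,
    target_colname="income_over_50k",
)
print(prediction)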

Quickstart - Training

Once you have set up your environment, you can train models using the Python script at scripts/finetune.py. As an example, to conduct a training run with a small toy model, run:

python -m rtfm.finetune \
  --train-task-file "./sampledata/v6.0.3-serialized/test/test-files.txt" \
  --eval-task-file "./sampledata/v6.0.3-serialized/train/train-files.txt" \
  --run_validation "False" \
  --use_wandb "False" \
  --warmup_steps 1 \
  --num_workers_dataloader 0 \
  --max_steps 10 \
  --model_name "yujiepan/llama-2-tiny-random" \
  --save_checkpoint_root_dir "checkpoints" \
  --run_name "my_model_dir" \
  --save_model \
  --save_optimizer

This will conduct a short training run and save the model and optimizer state to checkpoints/my_model_dir-llama-2-tiny-random. To train a Llama3-8B model instead of the toy model in this example, replace yujiepan/llama-2-tiny-random with meta-llama/Meta-Llama-3-8B.

Quickstart - Evaluation

To evaluate a model, we recommend using the Python script at scripts/evaluate_checkpoint.py. To evaluate against any dataset, we recommend first preparing it using the provided script scripts/utils/prepare_csv_for_eval.py:

python scripts/utils/prepare_csv_for_eval.py --output_dir ./eval_tasks/my_task

prepare_csv_for_eval.py prescribes how the data in the CSV is serialized at evaluation time by defining a FeatureList and a YAML file, which describe the features and the prediction task, respectively. By default, prepare_csv_for_eval.py processes data the same way as the TabuLa-8B evaluations. If desired, you can write a custom script that controls the creation of the FeatureList and YAML files to change how data is serialized.
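For illustration only, such a YAML file might look roughly like the following; the keys shown are assumptions, not the script's actual schema (inspect the files written by prepare_csv_for_eval.py for the real format):

# hypothetical task description; real keys are defined by prepare_csv_for_eval.py
task: my_task
target_colname: income_over_50k
task_context: "Predict whether a person's income exceeds $50K/yr."
labels:
  - "yes"
  - "no"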

Once the data is prepared, run evaluation via:

USER_CONFIG_DIR=./eval_tasks/ \
  python scripts/evaluate_checkpoint.py \
  --eval-task-names "my_task" \
  --model_name "yujiepan/llama-2-tiny-random" \
  --resume "checkpoints/my_model_dir-llama-2-tiny-random" \
  --eval_max_samples 16 \
  --context_length 2048 \
  --pack_samples "False" \
  --num_shots 1 \
  --outfile "tmp.csv"

This will write a file to tmp.csv containing evaluation results.

If you want to evaluate the released pretrained TabuLa-8B model, set --model_name to mlfoundations/tabula-8b and remove the --resume flag or set it to the empty string.
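For example, adapting the command above:

USER_CONFIG_DIR=./eval_tasks/ \
  python scripts/evaluate_checkpoint.py \
  --eval-task-names "my_task" \
  --model_name "mlfoundations/tabula-8b" \
  --eval_max_samples 16 \
  --context_length 2048 \
  --pack_samples "False" \
  --num_shots 1 \
  --outfile "tmp.csv"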

Environment Setup and Test

You can create an environment to reproduce or run experiments via rtfm by using conda:

conda env create -f environment.yml

Once you've set up your environment, you need to add your Hugging Face token in order to access the Llama weights. To do this, you can run

huggingface-cli login

or manually set the token via

export HF_TOKEN=your_token

To test your setup (on any machine, no GPUs required), you can run the following command:

sh scripts/tests/train_tiny_local_test.sh

End-To-End Training Example

This section gives an example of how to train a model from a set of parquet files.

1. Prepare training data.

The model expects sets of serialized records stored in .tar files, which are in webdataset format. To serialize data, we provide the script serialize_interleave_and_shuffle.py (located at rtfm/pipelines/serialize_interleave_and_shuffle.py) to serialize a set of parquet files:

python -m rtfm.pipelines.serialize_interleave_and_shuffle \
    --input-dir /glob/containing/parquet/files/ \
    --output-dir ./serialized/v6.0.3/ \
    --max_tables 64 \
    --serializer_cls "BasicSerializerV2"

Training data is referenced via newline-delimited text files listing webdataset shards. The above command will automatically generate training, validation (train-eval), and test file lists, where the train-eval split comprises unseen rows from tables in the training split, and the test split comprises only unseen tables.
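Each generated text file is simply a newline-delimited list of shard paths; illustrative contents of train-files.txt might look like:

./serialized/v6.0.3/train/train-000000.tar
./serialized/v6.0.3/train/train-000001.tar
./serialized/v6.0.3/train/train-000002.tar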

Using data hosted on S3 (recommended)

Some datasets may be too large to store on disk during training. rtfm supports reading files stored on AWS S3. To use files hosted on S3, move the training data there and update the text files produced by rtfm/pipelines/serialize_interleave_and_shuffle.py to point to the new location. You can do this with sed; for example, the command below replaces the local training location with the S3 path on every line of a text file:

sed 's|/path/to/sampledata/|s3://rtfm-hub/tablib/serialized/v0-testdata/|g' traineval-files.txt > traineval-files-s3.txt
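The rewritten list (traineval-files-s3.txt here) can then be passed to --train-task-file or --eval-task-file in place of a local file list.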

Using local training data

If you plan to use local data, you can use the files produced as the output of serialize_interleave_and_shuffle.py (train-files.txt, traineval-files.txt, test-files.txt).

2. Launch a training job.

The recommended way to launch a training job is via finetune.py. You can do this, for example, via:

python -m rtfm.finetune \
  --train-task-file "./sampledata/v6.0.3-serialized/test/test-files.txt" \
  --eval-task-file "./sampledata/v6.0.3-serialized/train/train-files.txt" \
  --run_validation "False" \
  --use_wandb "False" \
  --warmup_steps 1 \
  --num_workers_dataloader 0 \
  --max_steps 10 \
  --model_name "yujiepan/llama-2-tiny-random" \
  --save_checkpoint_root_dir "checkpoints" \
  --run_name "my_model_dir" \
  --save_model \
  --save_optimizer

See finetune.py and the associated configuration classes in rtfm.configs and rtfm.arguments for more options to control the details of training.

rtfm's Issues

Unusable absolute paths

In

https://github.com/mlfoundations/rtfm/blob/main/sampledata/v6.0.3-serialized/test/test-files.txt

there are absolute paths specific to one user like

/Users/jpgard/Documents/github/tablm/sampledata/v6.0.3-serialized/test/test-000002.tar

However, they should be changed to relative paths, e.g.

rtfm/sampledata/v6.0.3-serialized/test/test-000002.tar

as it is also done in

https://github.com/mlfoundations/rtfm/blob/main/sampledata/v6.0.3-serialized/train/train-files.txt

But since all other paths are relative to rtfm, it should actually be

sampledata/v6.0.3-serialized/test/test-000002.tar

and then also changed in the train files.

RuntimeError: Invalid device string: 'cuda:None'

During training (on a Tesla V100-PCIE-16GB) I get the following error:

Train:   0%|          | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/anaconda/envs/rtfm/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/anaconda/envs/rtfm/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/dev-medekm-gpu/code/Users/michael.medek/rtfm/rtfm/finetune.py", line 451, in <module>
    main(
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/dev-medekm-gpu/code/Users/michael.medek/rtfm/rtfm/finetune.py", line 408, in main
    results = train(
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/dev-medekm-gpu/code/Users/michael.medek/rtfm/rtfm/train_utils.py", line 274, in train
    batch[key] = batch[key].to(f"cuda:{local_rank}")
RuntimeError: Invalid device string: 'cuda:None'
Train:   0%| 

This traces to

batch[key] = batch[key].to(f"cuda:{local_rank}")

where local_rank is None, hence Invalid device string: 'cuda:None'. How is this supposed to work? The function's default is local_rank=None, which should be invalid since it must be an int, right? In evaluate() the parameter is annotated local_rank: int.

By adding

local_rank = 0
rank = 0
print("WARNING! Overwriting local_rank and rank to 0!")

the issue can be worked around.
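A less invasive workaround (a sketch; it assumes a torchrun-style launcher that sets the standard rank environment variables) would be to fall back to those variables instead of hard-coding zero:

import os

# Fall back to launcher-provided ranks; both default to 0 for single-process runs.
if local_rank is None:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
if rank is None:
    rank = int(os.environ.get("RANK", 0))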

Using wrong files for training

In the readme (rtfm/README.md, line 41 at commit 9884a6b) there is

...
  --train-task-file "./sampledata/v6.0.3-serialized/test/test-files.txt" \
  --eval-task-file "./sampledata/v6.0.3-serialized/train/train-files.txt" \
...

however, --train-task-file should presumably use train/train-files.txt; the two files appear to be swapped, right?

TrainConfig does not contain serializer_cls

Hello! I was trying to run the inference.ipynb notebook, and I get an AttributeError in the first cell, because the 'TrainConfig' object has no attribute 'serializer_cls'.
[Screenshot: AttributeError in the first notebook cell, 2024-08-05]

What is the correct way to run serialization?

Best regards,
Anna Badalyan

Inconsistent parameters

The readme (rtfm/README.md, line 67 at commit 9884a6b) states

python scripts/utils/prepare_csv_for_eval.py --output_dir ./eval_tasks/my_task

which gives

ERROR: The function received no value for the required argument: target_colname
Usage: prepare_csv_for_eval.py CSV OUT_DIR TARGET_COLNAME TO_REGRESSION

so output_dir should be out_dir, to_regression is missing, and, most importantly, target_colname is missing. However, even if I provide it, it gets ignored in

generate_files_from_csv(csv, out_dir, to_regression=to_regression)

And then overwritten and inferred here

target_colname = df.columns[-1]

To solve this, target_colname should be optional and, only if it is not given, inferred from df.columns[-1]. Also, to_regression should have a default: to_regression: bool = False. What do you think?
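Concretely, the proposed signature might look like this sketch (body abbreviated):

import pandas as pd

def generate_files_from_csv(csv, out_dir, target_colname=None, to_regression: bool = False):
    df = pd.read_csv(csv)
    if target_colname is None:
        # Infer the target only when the caller did not supply one.
        target_colname = df.columns[-1]
    ...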

Deprecated readme path of evaluate_checkpoint.py

In the readme there still is

python scripts/evaluate_checkpoint.py ...

however the file was moved, and it should now be

python rtfm/evaluation/evaluate_checkpoint.py

or

python -m rtfm.evaluation.evaluate_checkpoint ...
