deep-symbolic-mathematics / multimodal-math-pretraining

[ICLR 2024 Spotlight] This is the official code for the paper "SNIP: Bridging Mathematical Symbolic and Numeric Realms with Unified Pre-training"

Home Page: https://openreview.net/forum?id=KZSEgJGPxu

License: MIT License

Python 98.79% Shell 1.21%
ai4math ai4science deep-learning multi-modal multi-modal-learning representation-learning symbolic-math symbolic-regression transformers

SNIP: A Multimodal Symbolic-Numeric Pretraining for Math (MathCLIP)

Official Implementation of ICLR 2024 Spotlight paper SNIP: Bridging Mathematical Symbolic and Numeric Realms with Unified Pre-training.

Overview

Inspired by the success of CLIP in vision-language representation learning, we introduce SNIP (Symbolic-Numeric Integrated Pre-training), a multi-modal pre-training model for symbolic mathematics. SNIP emphasizes the importance of numeric-augmented representations in math representation learning.


SNIP: A multi-modal transformer model that connects symbolic math equations with numeric data representations using contrastive learning

Installation

The code's dependencies are specified in environment.yml. To install them, run:

conda env create -f environment.yml

This library requires Python > 3.7.

Pretrained Models

We've released two pretrained SNIP models, each designed for a different type of analysis. Download them here. You'll find:

  • SNIP-10dmax: This model handles inputs of up to 10 dimensions. More details are in Section 5 and Appendix D p.3 of the paper.

  • SNIP-1d-normalized: This model is for 1-dimensional inputs with normalized targets, which makes it well suited to studying function patterns. More details are in Section 4 and Appendix D of the paper.

To use them, create a weights/ folder in your project, download the checkpoints into it, and pass the checkpoint path to the --reload_model parameter, e.g. --reload_model ./weights/snip-1d-normalized.pth.
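Here is a minimal sketch of the steps above. The helper creates the weights/ folder and assembles a train.py command that reloads a downloaded checkpoint. The function build_reload_command and the CHECKPOINT_MAX_DIM table are hypothetical. Passing --max_input_dimension alongside --reload_model is also an assumption, based on SNIP-10dmax handling up to 10-dimensional inputs and SNIP-1d-normalized handling 1-dimensional ones. Check parsers.py for the settings each checkpoint actually expects.

```python
from pathlib import Path

# Hypothetical table: input dimensionality each released checkpoint is described
# with in this README. Verify against parsers.py before relying on it.
CHECKPOINT_MAX_DIM = {
    "snip-10dmax.pth": 10,
    "snip-1d-normalized.pth": 1,
}


def build_reload_command(checkpoint_name, weights_dir="./weights", extra_args=(), create=True):
    """Return a train.py argument list that reloads a downloaded SNIP checkpoint.

    If create is True, the weights directory is created when missing.
    Adding --max_input_dimension is an assumption, not a documented requirement.
    """
    weights = Path(weights_dir)
    if create:
        weights.mkdir(parents=True, exist_ok=True)
    cmd = [
        "python", "train.py",
        "--reload_model", str(weights / checkpoint_name),
        "--max_input_dimension", str(CHECKPOINT_MAX_DIM[checkpoint_name]),
    ]
    cmd.extend(extra_args)  # e.g. ["--freeze_encoder", "True"]
    return cmd
```

You can pass the returned list to subprocess.run. Keeping each checkpoint name and its input dimensionality in one table makes it harder to reload a checkpoint with mismatched settings.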

Pretraining Data Generation

For pretraining, we generate synthetic (symbolic, numeric) pairs for math functions, based on the method from SymbolicMathematics. Each pair consists of data points $(x,y)$ and a math function $f$ such that $y=f(x)$. See the generate_datapoints function here for more info. You can also adjust the data generation settings here.
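The following toy generator makes the (symbolic, numeric) pairing concrete. It is not the repo's generate_datapoints. It assumes a tiny operator set and a single input variable named x_0. The sketch samples a random expression in prefix notation and evaluates it at random points. It then drops any point where evaluation fails or yields a non-finite value. All names here (random_prefix, eval_prefix, generate_pair) are hypothetical.

```python
import math
import random

# Hypothetical, much-reduced operator set; the repo's generator is richer.
UNARY = {"sin": math.sin, "cos": math.cos, "exp": math.exp}
BINARY = {"add": lambda a, b: a + b, "sub": lambda a, b: a - b, "mul": lambda a, b: a * b}


def random_prefix(depth, rng):
    """Sample a random expression as a list of prefix (Polish notation) tokens."""
    if depth == 0 or rng.random() < 0.3:
        # Leaf: the input variable or a small integer constant.
        return ["x_0"] if rng.random() < 0.7 else [str(rng.randint(1, 5))]
    if rng.random() < 0.4:
        return [rng.choice(sorted(UNARY))] + random_prefix(depth - 1, rng)
    op = rng.choice(sorted(BINARY))
    return [op] + random_prefix(depth - 1, rng) + random_prefix(depth - 1, rng)


def eval_prefix(tokens, x):
    """Evaluate prefix tokens at scalar x; returns (value, remaining_tokens)."""
    head, rest = tokens[0], tokens[1:]
    if head in UNARY:
        arg, rest = eval_prefix(rest, x)
        return UNARY[head](arg), rest
    if head in BINARY:
        left, rest = eval_prefix(rest, x)
        right, rest = eval_prefix(rest, x)
        return BINARY[head](left, right), rest
    return (x if head == "x_0" else float(head)), rest


def generate_pair(n_points=50, depth=3, seed=0):
    """Return (prefix tokens, [(x, y), ...]) with y = f(x), skipping failed points."""
    rng = random.Random(seed)
    tokens = random_prefix(depth, rng)
    points = []
    for _ in range(n_points):
        x = rng.uniform(-2.0, 2.0)
        try:
            y, _ = eval_prefix(tokens, x)
        except (OverflowError, ValueError):  # e.g. exp overflow, sin(inf)
            continue
        if math.isfinite(y):
            points.append((x, y))
    return tokens, points
```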

The data is generated on the fly during training. If you want to create and analyze it beforehand, use run_export_data.sh:

python train.py --export_data True --dump_path ./dump --max_input_dimension 10

Your exported data will be saved in the data.prefix file.
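A small reader can help you inspect the exported file. This sketch assumes one sample per line. It tries to parse each line as JSON and falls back to the raw string. The record format of data.prefix is not described here, so treat this as a hypothetical helper. peek_records and count_records are not part of the repo.

```python
import json
from itertools import islice


def peek_records(lines, n=3):
    """Parse up to n non-empty lines, assuming one exported sample per line.

    Lines that are valid JSON are returned as Python objects; any other line is
    returned as its stripped string, so the helper works whatever the format is.
    """
    non_empty = (line.strip() for line in lines if line.strip())
    records = []
    for line in islice(non_empty, n):
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            records.append(line)
    return records


def count_records(lines):
    """Count non-empty lines, assuming one sample per line."""
    return sum(1 for line in lines if line.strip())
```

Both helpers accept any iterable of lines, so an open file handle works directly. For example: with open(path, encoding="utf-8") as f: print(peek_records(f)), where path points to the exported data.prefix file.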

SNIP Pretraining

All training settings for SNIP are in parsers.py. SNIP uses Transformer encoders for both the symbolic and numeric heads, which you can find in the encoder_f and encoder_y modules here. For details on contrastive learning and training, see the trainer file. To start training:

python train.py --loss_type CLIP \
                --batch_size 256 \
                --dump_path ./dump \
                --max_input_dimension 10 \
                --exp_id run1-10d \
                --lr 4e-5 \
                --latent_dim 512 \
                --save_periodic 10

You can adjust training and data settings in parsers.py and in environment.py under snip/envs/. With the command above, a checkpoint is saved to the dump/ path every 10 epochs (set by --save_periodic).
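The contrastive objective behind --loss_type CLIP can be sketched as a symmetric InfoNCE loss between the outputs of the two encoders, in the style of CLIP. This NumPy version is only an illustration. It assumes a fixed temperature of 0.07 (CLIP learns the temperature) and no projection heads. It also assumes that row i of each batch comes from the same function. The trainer file is the authority on what SNIP actually computes, and clip_loss is a hypothetical name.

```python
import numpy as np


def l2_normalize(z, eps=1e-12):
    """Scale each row of z to unit L2 norm."""
    return z / (np.linalg.norm(z, axis=1, keepdims=True) + eps)


def log_softmax(logits, axis):
    """Numerically stable log-softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_loss(z_symbolic, z_numeric, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of paired embeddings.

    Row i of z_symbolic and row i of z_numeric describe the same function,
    so the diagonal of the similarity matrix holds the positive pairs.
    """
    a = l2_normalize(z_symbolic)
    b = l2_normalize(z_numeric)
    logits = a @ b.T / temperature  # (batch, batch) scaled cosine similarities
    diag = np.arange(len(a))
    loss_s2n = -log_softmax(logits, axis=1)[diag, diag].mean()  # symbolic -> numeric
    loss_n2s = -log_softmax(logits, axis=0)[diag, diag].mean()  # numeric -> symbolic
    return 0.5 * (loss_s2n + loss_n2s)
```

Each embedding is L2-normalized, so the logits are cosine similarities. Averaging both retrieval directions keeps the loss symmetric in the two modalities. The embedding width here would correspond to --latent_dim (512 in the command above).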

Using SNIP for Cross-modal Property Prediction

Here we provide code to test SNIP representations on cross-modal symbolic-to-numeric property prediction tasks. In these tasks, the input is a symbolic mathematical equation, and the label is a property defined from numeric observations of that equation.

Data Generation

To try it out, start by generating data. For instance, to generate 10k training examples for the Non-Convexity Ratio (NCR) prediction task (as explained in the paper), use this command:

python train.py --export_data True --is_proppred True --property_type ncr --dump_path ./dump --max_input_dimension 1 --n_steps_per_epoch 625  --exp_name data --exp_id ncr

This saves data for ncr property in dump/data/ncr/. To generate data for other properties, just change the --property_type parameter.
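One way to estimate a non-convexity ratio for a 1-D function is to sample random chords and count how often Jensen's inequality fails. This sketch assumes that definition and a default sampling interval of [-1, 1]. The paper and the ncr property code define the exact quantity used for the labels, which may differ. non_convexity_ratio is a hypothetical helper.

```python
import math
import random


def non_convexity_ratio(f, low=-1.0, high=1.0, n_trials=2000, seed=0, tol=1e-9):
    """Estimate the fraction of random chord tests in which f violates convexity.

    Each trial samples x1, x2 in [low, high] and lam in [0, 1] and checks Jensen's
    inequality f(lam*x1 + (1-lam)*x2) <= lam*f(x1) + (1-lam)*f(x2).
    """
    rng = random.Random(seed)
    violations = 0
    for _ in range(n_trials):
        x1, x2 = rng.uniform(low, high), rng.uniform(low, high)
        lam = rng.random()
        lhs = f(lam * x1 + (1 - lam) * x2)
        rhs = lam * f(x1) + (1 - lam) * f(x2)
        if lhs > rhs + tol:  # tol absorbs floating-point rounding
            violations += 1
    return violations / n_trials
```

For example, x*x gives 0.0 and -x*x gives a value close to 1. math.sin on [-3, 3] lands strictly in between, because sin is convex on [-3, 0] and concave on [0, 3].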

Training

For this task, we use a Transformer encoder to encode the symbolic equation inputs, followed by a regression head that predicts the property. Training minimizes the Mean Squared Error (MSE) loss. The commands below train the model variants defined in Section 4 of the paper.

Supervised Model (without Pretraining):

python train.py --is_proppred True \
                --property_type ncr \
                --reload_data functions,dump/data/ncr/train.prefix,dump/data/ncr/train.prefix, \
                --normalize_y True \
                --batch_size 64 \
                --dump_path ./dump \
                --max_input_dimension 1 \
                --exp_name NCR_pred \
                --exp_id run1 \
                --lr 1e-5 \
                --latent_dim 512 \
                --save_periodic 10

SNIP Encoder (frozen):

python train.py --reload_model ./weights/snip-1d-normalized.pth --freeze_encoder True [other parameters] 

SNIP Encoder (finetune):

python train.py --reload_model ./weights/snip-1d-normalized.pth --freeze_encoder False [other parameters] 

With these commands, a checkpoint is saved automatically every 10 epochs. To use SNIP's encoder, pass the path of the model weights to the --reload_model parameter. Set --freeze_encoder True to keep the encoder frozen, or --freeze_encoder False to finetune it.
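With --freeze_encoder True, only the regression head is trained. Assuming a linear head, the MSE objective on fixed encoder features then reduces to linear least squares. The sketch below illustrates that frozen case with random stand-in features. It also standardizes the targets, on the assumption that --normalize_y means zero mean and unit variance. It swaps gradient-based training for a closed-form ridge solution, and the helper names are hypothetical. In the finetuned variant the encoder weights change as well, so no such closed form applies.

```python
import numpy as np


def normalize_targets(y):
    """Standardize targets to zero mean / unit variance; returns (y_norm, mean, std)."""
    mean, std = y.mean(), y.std()
    std = std if std > 0 else 1.0
    return (y - mean) / std, mean, std


def add_bias(features):
    """Append a constant column so the linear head has an intercept."""
    return np.hstack([features, np.ones((len(features), 1))])


def fit_frozen_probe(features, y, ridge=1e-6):
    """Fit only a linear head on fixed encoder features by ridge least squares."""
    X = add_bias(features)
    A = X.T @ X + ridge * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ y)


def predict(features, w):
    return add_bias(features) @ w


def mse(pred, target):
    return float(np.mean((pred - target) ** 2))
```

Predictions made in normalized space are mapped back with pred * std + mean before computing the MSE against the original targets.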

Inference

To test how well your models perform on each property prediction task, use the run_eval_proppred.sh script. For example, to test the NCR property task, use this command:

python eval_proppred.py --is_proppred True \
                        --property_type ncr \
                        --reload_model dump/NCR/model.pth \
                        --reload_data functions,dump/data/ncr/test.prefix,dump/data/ncr/test.prefix,

This command will use the --reload_model parameter to load the weights of your trained model and test it against the dataset specified in the --reload_data path.
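Standard regression metrics are a reasonable starting point for comparing the supervised, frozen, and finetuned variants on the same test set. The helper below computes MSE and the coefficient of determination R². Which metrics eval_proppred.py reports is not stated here, so regression_metrics is a hypothetical, stand-alone sketch.

```python
def regression_metrics(y_true, y_pred):
    """Return MSE and the coefficient of determination R^2 for paired sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    n = len(y_true)
    mean = sum(y_true) / n
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean) ** 2 for t in y_true)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return {"mse": ss_res / n, "r2": r2}
```

An R² of 1 means perfect predictions. An R² of 0 means the model does no better than always predicting the test-set mean.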

Using SNIP for Symbolic Regression

To use SNIP for more complex tasks such as symbolic regression (recovering symbolic math equations from data, a numeric-to-symbolic generation task), see the Multimodal-Symbolic-Regression repository.

Citation

If you find the paper or the repo helpful, please cite it as:

@inproceedings{
anonymous2024snip,
title={{SNIP}: Bridging Mathematical Symbolic and Numeric Realms with Unified Pre-training},
author={Anonymous},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=KZSEgJGPxu}
}

License

This repository is licensed under the MIT License.

This work is built on top of other open source projects, including Deep Learning for Symbolic Mathematics and Contrastive Language-Image Pretraining. We thank the original contributors of these works for open-sourcing their valuable source codes.

Contact Us

For any questions or issues, you are welcome to open an issue in this repo or contact us at [email protected] and [email protected].

multimodal-math-pretraining's Issues

Mismatch Error When Reloading Pre-trained Weights for Inference Testing

Hi,

I am encountering a mismatch error while attempting to reload the pre-trained weights (SNIP-10dmax) for inference testing on the model. Despite following the installation instructions meticulously to set up the environment and packages, the error arises when executing the command:

!python /content/drive/MyDrive/LLM/Multimodal-Math-Pretraining-main/train.py --reload_model /content/drive/MyDrive/LLM/Multimodal-Math-Pretraining-main/weights/snip-10dmax.pth

to load the weights. The specific error message indicates a size mismatch in several layers:

RuntimeError: Error(s) in loading state_dict for LinearPointEmbedder:
size mismatch for hidden_layers.0.weight: copying a param with shape torch.Size([2112, 2112]) from checkpoint, the shape in current model is torch.Size([384, 384]).
size mismatch for hidden_layers.0.bias: copying a param with shape torch.Size([2112]) from checkpoint, the shape in current model is torch.Size([384]).
size mismatch for fc.weight: copying a param with shape torch.Size([512, 2112]) from checkpoint, the shape in current model is torch.Size([512, 384]).

This issue leads me to question whether there might be a discrepancy between the pre-trained model weights provided and the current model architecture in the repository, or if there were any steps I may have overlooked during the setup process.

I would greatly appreciate any guidance or suggestions you could offer to resolve this mismatch issue. Thank you for your time and assistance.
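Here is one hedged way to read these shapes. The README's pretraining command for 10-dimensional inputs passes --max_input_dimension 10, but the command in this issue passes no --max_input_dimension at all. The widths in the error factor as 2112 = 192 × 11 and 384 = 192 × 2. That is consistent with an embedder width proportional to max_input_dimension + 1 (10 vs. 1). This proportionality is an assumption about LinearPointEmbedder and is not confirmed here, but it suggests retrying with --max_input_dimension 10. The hypothetical helpers below list mismatched parameters from two name-to-shape mappings and test that assumption.

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Return (name, checkpoint_shape, model_shape) for shared names whose shapes differ.

    Both arguments map parameter names to shape tuples, e.g. built from a PyTorch
    state dict with {k: tuple(v.shape) for k, v in state_dict.items()}.
    """
    shared = sorted(checkpoint_shapes.keys() & model_shapes.keys())
    return [
        (name, tuple(checkpoint_shapes[name]), tuple(model_shapes[name]))
        for name in shared
        if tuple(checkpoint_shapes[name]) != tuple(model_shapes[name])
    ]


def width_ratio_matches(ckpt_width, model_width, ckpt_dim, model_dim):
    """Hypothetical check: do the widths scale like (max_input_dimension + 1)?"""
    return ckpt_width * (model_dim + 1) == model_width * (ckpt_dim + 1)


# Shapes copied from the error message in this issue.
CKPT = {"hidden_layers.0.weight": (2112, 2112), "hidden_layers.0.bias": (2112,),
        "fc.weight": (512, 2112)}
MODEL = {"hidden_layers.0.weight": (384, 384), "hidden_layers.0.bias": (384,),
         "fc.weight": (512, 384)}

for name, got, expected in shape_mismatches(CKPT, MODEL):
    print(f"{name}: checkpoint {got} vs model {expected}")
print(width_ratio_matches(2112, 384, ckpt_dim=10, model_dim=1))
```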
