Comments (3)

saitcakmak commented on June 13, 2024

Hi @Runyu-Zhang. We do not have proper support for this use case. We used to have some support, but we removed it because it was very difficult to use correctly. The Ax Model & ModelBridge layer is designed to transform the data into BoTorch datasets and use these to construct and fit the BoTorch models. Why are you trying to use a pre-trained model rather than letting Ax train the model with the same fit_gpytorch_mll method? Do you need to pass certain arguments to the BoTorch model that the API currently doesn't support? If so, I'd like to learn about them and see whether it would make sense for us to support them more generally. In general, I'd strongly recommend following this tutorial (https://botorch.org/tutorials/custom_botorch_model_in_ax) to create a custom BoTorch model class and let Ax do the training.

Runyu-Zhang commented on June 13, 2024

I reformatted the code.

Code

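from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.ax_client import AxClient, ObjectiveProperties
from botorch.acquisition.multi_objective.monte_carlo import qNoisyExpectedHypervolumeImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import normalize
from gpytorch.mlls import SumMarginalLogLikelihood

# Not shown in this snippet: initialize_training_data, evaluate_suggested_values,
# params_bounds_tensor, range_values, all_parameters, all_objectives,
# hypervolume_reference_point, fixed_seed, num_bo_trials, BO_trials.


def create_gp_model_kwargs():
    # train_default_gp_model() is called here, so Surrogate receives
    # a fitted ModelListGP instance as botorch_model_class.
    return {'surrogate': Surrogate(botorch_model_class=train_default_gp_model()),
            'botorch_acqf_class': qNoisyExpectedHypervolumeImprovement}


def train_default_gp_model():
    # Fit one independent SingleTaskGP per objective on pre-existing data.
    train_x, train_y = initialize_training_data()
    train_x = normalize(X=train_x, bounds=params_bounds_tensor)
    models = []
    for i in range(train_y.shape[-1]):
        models.append(SingleTaskGP(train_X=train_x,
                                   train_Y=train_y[:, i].unsqueeze(-1),
                                   outcome_transform=Standardize(m=1)))
    # Build and fit the model list once, after all per-objective models exist.
    model = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    return model


parameters_list = [dict(name=param,
                        type='range',
                        bounds=range_values,
                        value_type='float')
                   for param in all_parameters]
objectives = {obj: ObjectiveProperties(minimize=True, threshold=hypervolume_reference_point)
              for obj in all_objectives}
steps = [
    # One Sobol trial first, to avoid "ax.exceptions.core.DataRequiredError:
    # StandardizeY transform requires non-empty data."
    GenerationStep(model=Models.SOBOL, num_trials=1),
    GenerationStep(model=Models.BOTORCH_MODULAR,
                   num_trials=num_bo_trials,
                   model_kwargs=create_gp_model_kwargs()),
]

# Create the Ax client and optimize:
ax_client = AxClient(generation_strategy=GenerationStrategy(steps=steps),
                     random_seed=fixed_seed,
                     verbose_logging=True)
ax_client.create_experiment(name='Ax_Modular',
                            parameters=parameters_list,
                            objectives=objectives,
                            overwrite_existing_experiment=True)
for i in range(BO_trials):
    suggested_values, trial_index = ax_client.get_next_trial()
    # The error is thrown after the first (Sobol) trial.
    loss_values = evaluate_suggested_values(suggested_values=suggested_values)
    ax_client.complete_trial(trial_index=trial_index, raw_data=loss_values.copy())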

Cesar-Cardoso commented on June 13, 2024

Hello there! In this case you're getting an exception because botorch_model_class should be a class type rather than an instance. Your train_default_gp_model() function returns an instance of a Model (specifically a ModelListGP), so when Ax checks the type of botorch_model_class, it raises an exception.
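A minimal plain-Python sketch of the class-versus-instance distinction follows. `Model` and `ToySurrogate` are hypothetical stand-ins, not Ax or BoTorch classes, and Ax's real validation may be implemented differently. The sketch only shows why a trained instance fails a check that expects a class, and why a surrogate designed this way wants the class: it instantiates the model itself once it has data.

```python
import inspect


class Model:
    """Hypothetical stand-in for a model base class (not BoTorch's Model)."""

    def __init__(self, train_X, train_Y):
        self.train_X = train_X
        self.train_Y = train_Y


class ToySurrogate:
    """Hypothetical surrogate that expects a model *class* and builds the model itself."""

    def __init__(self, botorch_model_class):
        # Accept only a class that subclasses Model; a trained instance is rejected.
        if not (inspect.isclass(botorch_model_class)
                and issubclass(botorch_model_class, Model)):
            raise TypeError(
                f"botorch_model_class must be a Model subclass, got {botorch_model_class!r}"
            )
        self.botorch_model_class = botorch_model_class

    def construct(self, train_X, train_Y):
        # The surrogate, not the caller, instantiates the model with the current data.
        return self.botorch_model_class(train_X, train_Y)
```

`ToySurrogate(Model)` succeeds, while `ToySurrogate(Model([[0.0]], [[1.0]]))` raises `TypeError`, because the argument is an object rather than a class.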

Perhaps what you want to do is pass a callable model argument to GenerationStep as described here?
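One way to picture why a callable helps is a factory that is called only when data exists. In the sketch below, `ToyStep` and `fit_mean_model` are hypothetical and do not reproduce the signature Ax expects for a callable `model` in GenerationStep; check the Ax documentation for that. The point is that the step stores the function and calls it later. By contrast, `model_kwargs=create_gp_model_kwargs()` in the code above runs `train_default_gp_model()` once, while the `steps` list is being built, so that model never sees the trial data Ax collects.

```python
from typing import Callable, Sequence


def fit_mean_model(train_y: Sequence[float]) -> float:
    """Hypothetical 'model fitting': returns the mean of the observed values."""
    return sum(train_y) / len(train_y)


class ToyStep:
    """Hypothetical generation step that holds a model factory, not a fitted model."""

    def __init__(self, model_factory: Callable[[Sequence[float]], float]):
        self.model_factory = model_factory

    def fit(self, observed: Sequence[float]) -> float:
        # The factory runs only now, on whatever data has been observed so far.
        return self.model_factory(observed)


step = ToyStep(fit_mean_model)  # pass the function itself; do not call it here
print(step.fit([1.0, 3.0]))
```

Note that `fit_mean_model` is passed without parentheses. Writing `ToyStep(fit_mean_model(...))` would hand the step an already-computed result instead of a factory, which mirrors the problem in the original code.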
