Comments (1)
Never mind, I came up with a more elegant solution for using joblib. I'm posting the code here for people who may want to use this approach. I can also add it as an example to the Optuna examples repo, if there is interest:
import contextlib
import pickle
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, cast

import joblib
import optuna
from optuna.exceptions import DuplicatedStudyError, ExperimentalWarning
from optuna.pruners import BasePruner, HyperbandPruner
from optuna.samplers import BaseSampler, TPESampler
from optuna.storages import JournalFileStorage, JournalStorage
from optuna.study import MaxTrialsCallback, Study
from optuna.trial import Trial, TrialState


@dataclass
class StudyConfig:
    study_name: str
    sampler: BaseSampler
    pruner: BasePruner
    directions: list[Literal["minimize", "maximize"]]
    storage: JournalStorage
    n_trials: int
    n_cores: int = 1
    log_path: Path = Path("optuna_journal.log")
    study_path: Path = Path("optuna_study.pkl")

    @property
    def study_args(self) -> dict[str, Any]:
        return {
            "study_name": self.study_name,
            "sampler": self.sampler,
            "pruner": self.pruner,
            "directions": self.directions,
            "storage": self.storage,
        }


def objective(trial: Trial) -> float:
    x = trial.suggest_float("x", -100, 100)
    y = trial.suggest_categorical("y", [-1, 0, 1])
    return x**2 + y
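

def optimize(study_cfg: StudyConfig, worker_id: int) -> None:
    # Every worker attaches to the same journal-backed study.
    study = optuna.create_study(**study_cfg.study_args, load_if_exists=True)
    # Give each worker slightly more than its even share of trials;
    # MaxTrialsCallback stops the study once enough trials are COMPLETE.
    n_trials = study_cfg.n_trials // study_cfg.n_cores
    n_trials += study_cfg.n_cores - (study_cfg.n_trials % study_cfg.n_cores)
    study.optimize(
        objective,
        n_trials=n_trials,
        callbacks=[MaxTrialsCallback(study_cfg.n_trials, states=(TrialState.COMPLETE,))],
    )
    # Only one worker writes the pickled study to disk.
    if worker_id == 0:
        with study_cfg.study_path.open("wb") as f:
            pickle.dump(study, f)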


n_trials = 6000
n_cores = 12

# Start from a clean slate: remove the journal, its lock file, and any old pickle.
log_path = Path("optuna_journal.log")
log_path.unlink(missing_ok=True)
Path(f"{log_path}.lock").unlink(missing_ok=True)
study_path = Path("optuna_study.pkl")
study_path.unlink(missing_ok=True)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", ExperimentalWarning)
    study_cfg = StudyConfig(
        "test",
        TPESampler(seed=42),
        HyperbandPruner(),
        ["minimize"],
        JournalStorage(JournalFileStorage(str(log_path))),
        n_trials,
        n_cores,
        log_path,
        study_path,
    )

# Create the study once up front so the workers only have to load it.
with contextlib.suppress(DuplicatedStudyError):
    _ = optuna.create_study(**study_cfg.study_args)

# On failure, retry with half the trial target, but never go below this floor.
# (The floor is fixed before the loop; comparing against the shrinking
# n_trials itself would make the condition always true.)
min_trials = min(100, study_cfg.n_trials)
while study_cfg.n_trials >= min_trials:
    try:
        _ = joblib.Parallel(n_jobs=n_cores)(
            joblib.delayed(optimize)(study_cfg, i) for i in range(n_cores)
        )
    except Exception:
        # Clear the stale lock left behind by the failed run before retrying.
        Path(f"{log_path}.lock").unlink(missing_ok=True)
        study_cfg.n_trials //= 2
    else:
        break

with study_cfg.study_path.open("rb") as f:
    study = cast("Study", pickle.load(f))
best_params = study.best_trial.params
best_params
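
For anyone wondering about the per-worker budget in optimize, here is a small sketch of that arithmetic on its own. The helper name per_worker_budget is hypothetical and not part of the code above; it only mirrors the two lines that compute n_trials. The point is that the combined budget always exceeds the target, so it is MaxTrialsCallback, not the per-worker count, that ends the study.

```python
def per_worker_budget(n_trials: int, n_cores: int) -> int:
    """Mirror the arithmetic in optimize(): even share plus some padding.

    Hypothetical helper for illustration only.
    """
    return n_trials // n_cores + n_cores - (n_trials % n_cores)


for total, cores in [(6000, 12), (100, 12), (7, 3)]:
    budget = per_worker_budget(total, cores)
    # The workers together are allowed more trials than the target,
    # so pruned or failed trials do not leave the study short.
    assert budget * cores > total
    print(f"target={total} cores={cores} per_worker={budget} combined={budget * cores}")
```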
Note that, depending on the number of trials, this sometimes fails with "Error: did not possess lock", which is why I added the while-loop.
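
The retry idea behind that while-loop can be isolated from Optuna entirely. The following is a minimal sketch, assuming a callable run(n_trials) that raises on failure; run_with_halving and flaky are hypothetical names used only for illustration, and the floor of 100 is taken from the loop condition above.

```python
from typing import Callable


def run_with_halving(
    run: Callable[[int], None], n_trials: int, min_trials: int = 100
) -> int:
    """Call run(n_trials), halving the target after each failure.

    Returns the target that succeeded. Re-raises the last error once
    the next target would fall below the floor.
    """
    # Fix the floor once, before n_trials starts shrinking.
    floor = max(1, min(min_trials, n_trials))
    while True:
        try:
            run(n_trials)
        except Exception:
            if n_trials // 2 < floor:
                raise
            n_trials //= 2
        else:
            return n_trials


attempts: list[int] = []


def flaky(n: int) -> None:
    """Stand-in for the parallel run: fails for large trial targets."""
    attempts.append(n)
    if n > 2000:
        raise RuntimeError("did not possess lock")


result = run_with_halving(flaky, 6000)
print(result, attempts)
```

In the real script the except branch would also delete the stale .lock file before retrying, as the loop above does.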