
slurm_sweeps's Introduction

slurm sweeps

A simple tool to perform parameter sweeps on SLURM clusters.


The main motivation was to provide a lightweight ASHA implementation for SLURM clusters that is fully compatible with PyTorch Lightning's DDP.

It is heavily inspired by tools like Ray Tune and Optuna. However, on a SLURM cluster, these tools can be complicated to set up and introduce considerable overhead.

Slurm sweeps is simple, lightweight, and has few dependencies. It uses SLURM Job Steps to run the individual trials.

Installation

pip install slurm-sweeps

Dependencies

  • cloudpickle
  • numpy
  • pandas
  • pyyaml

Usage

You can just run this example on your laptop. By default, the maximum number of parallel trials equals the number of CPUs on your machine.

""" Content of test_ss.py """
from time import sleep
import slurm_sweeps as ss


# Define your train function
def train(cfg: dict):
    for epoch in range(cfg["epochs"]):
        sleep(0.5)
        loss = (cfg["parameter"] - 1) ** 2 / (epoch + 1)
        # log your metrics
        ss.log({"loss": loss}, epoch)


# Define your experiment
experiment = ss.Experiment(
    train=train,
    cfg={
        "epochs": 10,
        "parameter": ss.Uniform(0, 2),
    },
    asha=ss.ASHA(metric="loss", mode="min"),
)


# Run your experiment
result = experiment.run(n_trials=1000)

# Show the best performing trial
print(result.best_trial())

Or submit it to a SLURM cluster. Write a small SLURM script test_ss.slurm that runs the code above:

#!/bin/bash -l
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=18
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=1GB

python test_ss.py

By default, this will run $SLURM_NTASKS trials in parallel. In the case above: 2 nodes * 18 tasks per node = 36 parallel trials.

Then submit it to the queue:

sbatch test_ss.slurm

See the tests folder for an advanced example of training a PyTorch model with Lightning's DDP.

API Documentation

CLASS slurm_sweeps.Experiment

class Experiment(
    train: Callable,
    cfg: Dict,
    name: str = "MySweep",
    local_dir: Union[str, Path] = "./slurm-sweeps",
    asha: Optional[ASHA] = None,
    slurm_cfg: Optional[SlurmCfg] = None,
    restore: bool = False,
    overwrite: bool = False,
)

Set up an HPO experiment.

Arguments:

  • train - A train function that takes as input the cfg dict.
  • cfg - A dict passed on to the train function. It must contain the search spaces via slurm_sweeps.Uniform, slurm_sweeps.Choice, etc.
  • name - The name of the experiment.
  • local_dir - Where to store and run the experiments. In this directory, we will create the database slurm_sweeps.db and a folder with the experiment name.
  • asha - An optional ASHA instance to cancel less promising trials.
  • slurm_cfg - The configuration of the Slurm backend responsible for running the trials. We automatically choose this backend when slurm sweeps is used within an sbatch script.
  • restore - Restore an experiment with the same name?
  • overwrite - Overwrite an existing experiment with the same name?
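
For example, restarting an interrupted sweep might look like this (a sketch using only the arguments documented above; train is the function from the usage example):

import slurm_sweeps as ss

experiment = ss.Experiment(
    train=train,
    cfg={"epochs": 10, "parameter": ss.Uniform(0, 2)},
    name="MySweep",
    local_dir="./slurm-sweeps",
    asha=ss.ASHA(metric="loss", mode="min"),
    restore=True,  # pick up the existing experiment instead of starting fresh
)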

Experiment.name

@property
def name() -> str

The name of the experiment.

Experiment.local_dir

@property
def local_dir() -> Path

The local directory of the experiment.

Experiment.run

def run(
    n_trials: int = 1,
    max_concurrent_trials: Optional[int] = None,
    summary_interval_in_sec: float = 5.0,
    nr_of_rows_in_summary: int = 10,
    summarize_cfg_and_metrics: Union[bool, List[str]] = True
) -> pd.DataFrame

Run the experiment.

Arguments:

  • n_trials - Number of trials to run. For grid searches, this parameter is ignored.
  • max_concurrent_trials - The maximum number of trials running concurrently. By default, we set this to the number of CPUs available, or to the total number of Slurm tasks divided by the number of tasks requested per trial.
  • summary_interval_in_sec - Print a summary of the experiment every x seconds.
  • nr_of_rows_in_summary - How many rows of the summary table should we print?
  • summarize_cfg_and_metrics - Should we include the cfg and the metrics in the summary table? You can also pass in a list of strings to only select a few cfg and metric keys.

Returns:

A summary of the trials in a pandas DataFrame.
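
For example, to cap concurrency and slim down the periodic summary (a sketch; all arguments are the ones documented above):

summary = experiment.run(
    n_trials=100,
    max_concurrent_trials=8,
    summary_interval_in_sec=10.0,
    nr_of_rows_in_summary=5,
    summarize_cfg_and_metrics=["parameter", "loss"],
)
print(summary.head())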

CLASS slurm_sweeps.ASHA

class ASHA(
    metric: str,
    mode: str,
    reduction_factor: int = 4,
    min_t: int = 1,
    max_t: int = 50,
)

Basic implementation of the Asynchronous Successive Halving Algorithm (ASHA) to prune unpromising trials.

Arguments:

  • metric - The metric you want to optimize.
  • mode - Should the metric be minimized or maximized? Allowed values: ["min", "max"]
  • reduction_factor - The reduction factor of the algorithm.
  • min_t - Minimum number of iterations before we consider pruning.
  • max_t - Maximum number of iterations.
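
For example, the ASHA instance from the usage example, spelled out with its defaults. In standard ASHA, rungs would fall at min_t * reduction_factor^k iterations (1, 4, 16, ...), capped at max_t; take that as our reading of the algorithm rather than a guarantee of this implementation:

asha = ss.ASHA(
    metric="loss",        # prune based on the logged "loss"
    mode="min",           # lower is better
    reduction_factor=4,   # keep roughly the best quarter of trials at each rung
    min_t=1,
    max_t=50,
)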

ASHA.metric

@property
def metric() -> str

The metric to optimize.

ASHA.mode

@property
def mode() -> str

The 'mode' of the metric, either 'max' or 'min'.

ASHA.find_trials_to_prune

def find_trials_to_prune(database: "pd.DataFrame") -> List[str]

Check the database and find trials to prune.

Arguments:

  • database - The experiment's metrics table of the database as a pandas DataFrame.

Returns:

List of trial ids that should be pruned.

CLASS slurm_sweeps.SlurmCfg

@dataclass
class SlurmCfg:
  exclusive: bool = True
  nodes: int = 1
  ntasks: int = 1
  args: str = ""

A configuration class for the SlurmBackend.

Arguments:

  • exclusive - Add the --exclusive switch.
  • nodes - How many nodes do you request for your srun?
  • ntasks - How many tasks do you request for your srun?
  • args - Additional command line arguments for srun, formatted as a string.
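
For example, to give each trial two tasks (e.g. for DDP training) and pass extra flags to srun (a sketch; train and cfg as in the usage example):

slurm_cfg = ss.SlurmCfg(
    exclusive=True,
    nodes=1,
    ntasks=2,                  # two processes per trial, e.g. for DDP
    args="--cpus-per-task=4",  # forwarded to srun as-is
)

experiment = ss.Experiment(
    train=train,
    cfg=cfg,
    slurm_cfg=slurm_cfg,
)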

CLASS slurm_sweeps.Result

class Result(
    experiment: str,
    local_dir: Union[str, Path] = "./slurm-sweeps",
)

The result of an experiment.

Arguments:

  • experiment - The name of the experiment.
  • local_dir - The directory where we find the slurm-sweeps.db database.

Result.experiment

@property
def experiment() -> str

The name of the experiment.

Result.trials

@property
def trials() -> List[Trial]

A list of the trials of the experiment.

Result.best_trial

def best_trial(
    metric: Optional[str] = None,
    mode: Optional[str] = None
) -> Trial

Get the best performing trial of the experiment.

Arguments:

  • metric - The metric. By default, we take the one defined by ASHA.
  • mode - The mode of the metric, either 'min' or 'max'. By default, we take the one defined by ASHA.

Returns:

The best trial.
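
For example, to inspect a finished sweep from a separate script (assuming the default name and local_dir from the usage example):

import slurm_sweeps as ss

result = ss.Result("MySweep", local_dir="./slurm-sweeps")
print(result.best_trial(metric="loss", mode="min"))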

CLASS slurm_sweeps.trial.Trial

@dataclass
class Trial:
    cfg: Dict
    process: Optional[subprocess.Popen] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    status: Optional[Union[str, Status]] = None
    metrics: Optional[Dict[str, Dict[int, Union[int, float]]]] = None

A trial of an experiment.

Arguments:

  • cfg - The config of the trial.
  • process - The subprocess that runs the trial.
  • start_time - The start time of the trial.
  • end_time - The end time of the trial.
  • status - Status of the trial. If process is not None, we will always query the process for the status.
  • metrics - Logged metrics of the trial.

Trial.trial_id

@property
def trial_id() -> str

The trial ID is a 6-digit hash from the config.

Trial.runtime

@property
def runtime() -> Optional[timedelta]

The runtime of the trial.

Trial.is_terminated

def is_terminated() -> bool

Return True if the trial has completed or was pruned.
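
Continuing the Result example above, a quick post-mortem over all trials might look like:

for trial in result.trials:
    if trial.is_terminated():
        print(trial.trial_id, trial.status, trial.runtime)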

FUNCTION slurm_sweeps.log

def log(metrics: Dict[str, Union[float, int]], iteration: int)

Log metrics to the database.

If ASHA is configured, this also checks if the trial needs to be pruned.

Arguments:

  • metrics - A dictionary containing the metrics.
  • iteration - Iteration of the metrics. Most of the time this will be the epoch.

Raises:

  • TrialPruned if the holy ASHA says so!
  • TypeError if a metric is not of type float or int.
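
In a training loop you normally let TrialPruned propagate so the trial stops, but you can catch it to clean up first. A sketch (that TrialPruned is importable from the slurm_sweeps package is an assumption based on the Raises section; compute_loss and save_checkpoint are hypothetical helpers):

import slurm_sweeps as ss

def train(cfg: dict):
    for epoch in range(cfg["epochs"]):
        loss = compute_loss(cfg, epoch)   # hypothetical
        try:
            ss.log({"loss": loss}, epoch)
        except ss.TrialPruned:            # assumed import path
            save_checkpoint()             # hypothetical cleanup
            raise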

Contact

David Carreto Fidalgo ([email protected])

slurm_sweeps's Issues

Move storage to database

The database should function as storage. For each experiment, create three tables:

  • {experiment}: holds the trials; number of rows == number of trials
  • {experiment}_metrics: holds the logged metrics
  • {experiment}_storage: holds the pickled train function and ASHA class as blobs

Database class should not create a file on init

  • __init__ should not create anything
  • Move the creation of the file and the storage table to the create method
  • Optionally check for the db file in the self._connection context
  • Maybe give each experiment its own storage table?
