Search with Optuna

Optuna is a flexible, full-featured hyperparameter search library that pairs nicely with the MosaicML platform. Optuna makes it easy to configure your hyperparameter search space and optimization algorithms. Here we’ll briefly describe an example that uses Optuna with the MCLI Python API to search over both the learning rate and label smoothing’s “smoothing” parameter. Rather than walk through the entire example line by line, we’ll first cover its important components and then provide a full, working script at the bottom of this document.

First, a note on retrieving run metrics

The MosaicML platform does not currently support storage of metrics data from individual runs. Because of this, any optimization strategy that requires feedback from the submitted runs (i.e., anything beyond grid or random search) will need a separate interface. In this tutorial, we’ll describe how to do this using an S3 bucket and Composer’s composer.loggers.object_store_logger.ObjectStoreLogger. We have chosen to describe a more generic use case, but we would highly recommend pairing this with an experiment tracker like W&B or CometML.
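
Concretely, this pairing amounts to two composer loggers in the run’s parameters: a FileLogger that writes metrics to a local file, and an ObjectStoreLogger that uploads that file to your bucket. The snippet below is just a preview of the configuration used in the full script at the bottom of this document; the bucket name and metrics path are placeholders.

# Preview of the logger settings added to the run's parameters in the full script below.
# The bucket name and metrics path are placeholders.
parameters = {}  # stand-in for the run's parameters dictionary built later in this document
loggers = parameters.setdefault('loggers', {})
loggers['object_store'] = {'object_store_hparams': {'s3': {'bucket': '<my-S3-bucket-name>'}}}
loggers['file'] = {'filename': 'my-study/my-run/metrics.txt'}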

Setting up Optuna and launching runs

The core component of an Optuna search script is the objective function. Since we are performing our training on the MosaicML platform, our objective function will just do the following:

  1. Create or fetch a default run configuration

  2. Get new hyperparameters and update the run configuration

  3. Create the run

  4. Wait for the run to be scheduled and start

  5. Monitor the run’s metrics and update Optuna with them

  6. Optionally, stop the run if it is not performing well enough

Create our base configuration

For the first step, we need to create a default RunConfig. First, define your chosen cluster:

cluster = "<your-cluster>"

Next, fetch your base parameters from the composer MNIST defaults:

import requests
import yaml

# Download the MNIST YAML from Composer
req = requests.get('https://raw.githubusercontent.com/mosaicml/composer/v0.9.0/composer/yamls/models/classify_mnist.yaml')

print(f'Downloaded Composer\'s MNIST configuration:\n\n{req.text}')

# Load the YAML into a parameters dictionary
parameters = yaml.safe_load(req.text)

Then, set your command to use a composer example entrypoint:

# Define our command
command = """
wget -O entrypoint.py https://raw.githubusercontent.com/mosaicml/composer/v0.9.0/examples/run_composer_trainer.py

composer entrypoint.py -f /mnt/config/parameters.yaml
"""

With that, we can define our base run configuration:

from mcli import RunConfig, create_run, stop_runs, wait_for_run_status

config = RunConfig(name='mnist-classification',
                   image='mosaicml/composer:0.9.0',
                   compute={'gpus': 1, 'cluster': cluster},
                   command=command,
                   parameters=parameters)

Alternatively, it’s often convenient to define your base RunConfig in YAML format. Take a look at the Run schema page to see how that is done. If you go that route, you can instead use:

config = RunConfig.from_file('/path/to/base.yaml')

Get new hyperparameters and update our config

With our base config in hand, we can now modify it using the hyperparameters taken from Optuna’s Trial object. A Trial instance will be provided to the objective function we are building and can be used to “suggest” new hyperparameters. The full objective function is defined in the script at the bottom of this document. First, we set hyperparameters using the Optuna Trial:

# Get hyperparameters
smoothing = trial.suggest_float('smoothing', 0, 1)
lr = trial.suggest_float('lr', 0.0001, 1, log=True)
print(f'Chose smoothing: {smoothing} and lr: {lr}')

These values can now be used to update config.parameters and give our run a custom name:

# Update parameters
config.parameters['algorithms'] = {'label_smoothing': {'smoothing': smoothing}}
config.parameters['optimizers']['sgd']['lr'] = lr

# Also, we can update the run name to include these values
# We also replace '.' with 'p', since '.' is not a valid run name character
config.name = f'{config.name}-lr-{lr:1.2e}-sm-{smoothing:.2f}'.replace('.', 'p')
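
For example, lr=0.0125 and smoothing=0.1 would produce the run name mnist-classification-lr-1p25e-02-sm-0p10.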

Once our configuration is fully updated (see also the changes for S3 logging below), we can create our run and wait for it to start:

# Submit the run
run = create_run(config)
print(f'Submitted run {run.name}')

# Wait for the run to start running
print('Waiting for run to start')
run = wait_for_run_status(run.name, status='running')

Finally, to provide Optuna with feedback as the run progresses, we’ll use the follow_run_in_s3 function that we’ll define below. If you are not using S3, you’ll want a similar function that takes the Run and any other inputs and returns an iterator over metrics values. Here, TARGET_METRIC should be defined as the name of the metric output by your composer logger. For MNIST, this is:

TARGET_METRIC = 'metrics/eval/Accuracy'
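
If you are not using S3, your replacement for follow_run_in_s3 only needs the same shape: poll until the run finishes and yield metrics dictionaries as they appear. Here is a minimal sketch of that interface, where fetch_new_metrics is a hypothetical hook you would implement against your experiment tracker; the mcli calls mirror the ones used in follow_run_in_s3 below.

from concurrent.futures import TimeoutError
from typing import Any, Callable, Dict, Generator, List

from mcli.api.runs import Run, RunStatus, wait_for_run_status


def follow_run_metrics(run: Run,
                       fetch_new_metrics: Callable[[], List[Dict[str, Any]]]
                       ) -> Generator[Dict[str, Any], None, None]:
    """Yield metrics dictionaries for a run until it completes

    Args:
        run: The submitted run that should be followed
        fetch_new_metrics: Hypothetical hook returning any metrics dictionaries
            logged since the last call (e.g. pulled from your experiment tracker)

    Yields:
        Generator[Dict[str, Any], None, None]: A generator yielding metrics from the run
    """
    while True:
        try:
            # Wait up to 10 seconds for completion; this doubles as the polling interval
            run = wait_for_run_status(run, status='completed', timeout=10)
        except TimeoutError:
            pass

        # Yield whatever metrics have appeared since the last poll
        for step_metrics in fetch_new_metrics():
            yield step_metrics

        # Once the run has finished, stop following it
        if run.status.after(RunStatus.COMPLETED, inclusive=True):
            return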

Now, just follow the metrics and report them to Optuna using trial.report:

for metrics in follow_run_in_s3(run,
                                bucket=BUCKET_NAME,
                                results_path=artifact,
                                log_level=log_level):
    if metrics and TARGET_METRIC in metrics:
        value = metrics[TARGET_METRIC]
        step = metrics['epoch']
        trial.report(value=value, step=step)
        print(f'Step: {step}, Value: {value}')

        # Stop the run if the study says we should
        if trial.should_prune():
            stop_runs([run])
            print(f'Run {run.name} was pruned. Stopping...')
            raise optuna.TrialPruned()
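
A note on pruning: trial.report feeds each intermediate value back to the study, and trial.should_prune() asks the study’s pruner (the PercentilePruner configured in the full script below) whether those values suggest the trial is unlikely to be competitive. When it returns True, we stop the platform run with stop_runs and raise optuna.TrialPruned() so that Optuna records the trial as pruned rather than failed.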

Fetch run feedback through files stored in S3

As mentioned, fetching run feedback can be a little tricky. We highly recommend using an experiment tracker, as any of those will provide a much simpler interface for fetching the performance of an ongoing run. Using S3 is an option, however, so we’ll show how you can periodically download new versions of a file from S3 and extract metrics from it.

Configuring your S3 credentials

First things first, you’ll need access to an S3 bucket both within the MosaicML platform and on your local machine. If, for instance, you are using AWS for your data storage, follow these instructions to set up your local machine to read from your buckets. You’ll then need to create an S3 secret to allow you to upload metrics from within your MosaicML platform cluster.
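
As a quick sanity check that your local credentials are in place, you can try reaching the bucket with boto3. This is a minimal sketch, assuming the default boto3 credential chain and a bucket you have already created; the bucket name is a placeholder.

import boto3

# Fails (for example with a ClientError) if your credentials cannot reach the bucket
s3 = boto3.client('s3')
s3.head_bucket(Bucket='<my-S3-bucket-name>')
print('Local S3 access looks good')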

Fetching metrics from S3

To fetch metrics from S3 we’ll need a few things:

  1. Declare a bucket name

  2. Define a function that regularly checks and downloads metrics from S3

  3. Define a class to parse that file and return new metrics

At the top of the script below, you’ll see a spot to declare your bucket name:

# The name of your S3 bucket where you want to store results
BUCKET_NAME = '<my-S3-bucket-name>'

Now we’ll define our function to check and download metrics from S3:

def follow_run_in_s3(run: Run,
                     bucket: str,
                     results_path: str,
                     log_level: str = 'epoch') -> Generator[Dict[str, Any], None, None]:
    """Follow the metrics for a specific run in S3

    Args:
        run: The submitted run that should be followed
        bucket: The bucket in which the file lives
        results_path: Path to the metrics file in the bucket
        log_level: The `composer` log level. Can be either 'epoch' or 'batch'. Default: 'epoch'

    Yields:
        Generator[Dict[str, Any], None, None]: A generator yielding metrics from the run
    """

    metrics_file = MetricsFile(log_level=log_level)
    last_file_size = 0
    while True:
        try:
            run = wait_for_run_status(run, status='completed', timeout=10)
            print(f'Run {run.name} has completed with status: {run.status}')
        except TimeoutError:
            pass

        file_size = check_file_size(bucket, results_path)
        if file_size != last_file_size:
            last_file_size = file_size
            text = download_s3_file(bucket, results_path)
            metrics = metrics_file.get_new_metrics(text)
            for step_metrics in metrics:
                yield step_metrics

        if run.status.after(RunStatus.COMPLETED, inclusive=True):
            return

This function takes a Run along with the file details, and yields individual metrics dictionaries from the file. The basic steps for each iteration are:

  1. Wait a bit so we don’t poll too often. We have this set to 10 seconds; the wait_for_run_status() call will time out if the run does not complete in that time

  2. Check if the file has changed at all by querying the file size

  3. Download the new file, if necessary

  4. Extract any new metrics

  5. If the run has completed, exit

Steps 2 and 3 are pretty straightforward, so we’ll leave their implementation to the script. Let’s now turn our focus to step 4. We accomplish this by defining a MetricsFile class whose main role is to parse the metrics file uploaded by composer while keeping track of which metrics we’ve already seen.

The key method of this class is parse_line_metrics, which inspects a single log line and returns the metrics dictionary extracted from it (or None if the line contains no metrics); get_new_metrics then applies it to each new line and collects the results.

class MetricsFile():
    """Parser of metrics files output from `composer`'s `FileLogger`

    Args:
        log_level: The `composer` log level. Can be either 'epoch' or 'batch'. Default: 'epoch'
    """

    def __init__(self, log_level: str = 'epoch'):
        self.log_level = log_level
        self._last_line_no = 0

    @property
    def pattern(self) -> str:
        """Regex pattern to match metrics lines in the `composer` `FileLogger`
        """

        matching_logic = r"=(\d+)\]\[batch=(\d+)[^:]*:\s{(.*)}"
        out = r"\[stderr\]: \[{}\]\[epoch".format(self.log_level.upper()) + matching_logic

        return out

    def contains_metrics_keys(self, line) -> bool:
        """Quickly check if common metrics keys are present in the line
        """
        for pattern in ["metrics/", "loss/"]:
            if pattern in line:
                return True

        return False

    def parse_line_metrics(self, line: str) -> Optional[Dict[str, Any]]:
        """Parse an individual line of metrics

        Args:
            line: A log line from a file

        Returns:
            Optional[Dict[str, Any]]: If the line is a metrics line, returns the metrics as a dictionary.
                Returns None otherwise.
        """

        match = re.match(self.pattern, line)
        if match:
            assert len(match.groups()) == 3, f'Invalid match: {line}'
            epoch, batch, json_metrics = match.groups()

            if self.contains_metrics_keys(json_metrics):
                json_metrics = '{' + json_metrics.strip().rstrip(',') + '}'
                metrics = json.loads(json_metrics)
                metrics['epoch'] = int(epoch)
                metrics['batch'] = int(batch)
                metrics['log_level'] = self.log_level
                return metrics

    def get_new_metrics(self, text: str) -> List[Dict[str, Any]]:
        """Get any new metrics dictionaries from the log text

        Args:
            text: Log text that may include metrics that have already been parsed

        Returns:
            List[Dict[str, Any]]: A list of new metrics dictionaries
        """

        lines = text.splitlines()[self._last_line_no:]
        metrics = []
        for line in lines:
            line_metrics = self.parse_line_metrics(line)
            if line_metrics:
                metrics.append(line_metrics)
        self._last_line_no += len(lines)
        return metrics
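
To make the parsing concrete, here is what parse_line_metrics extracts from a single log line. The line below is hand-constructed to match the regex above; the exact formatting composer writes may differ slightly in your logs.

parser = MetricsFile(log_level='epoch')

# A hand-constructed line in the shape the regex above expects
line = '[stderr]: [EPOCH][epoch=1][batch=469]: {"metrics/eval/Accuracy": 0.9231, "loss/train": 0.1842}'

print(parser.parse_line_metrics(line))
# {'metrics/eval/Accuracy': 0.9231, 'loss/train': 0.1842, 'epoch': 1, 'batch': 469, 'log_level': 'epoch'}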

Putting it all together

After adding some Optuna boilerplate to hook our objective function up to a Study, we’re left with the script below. Give it a try!

import json
import logging
import os
import re
import sys
import tempfile
from concurrent.futures import TimeoutError
from typing import Any, Callable, Dict, Generator, List, Optional

import boto3
import optuna
import requests
import yaml
from botocore.config import Config
from botocore.exceptions import ClientError
from mcli.api.runs import (Run, RunConfig, RunStatus, create_run, stop_runs,
                           wait_for_run_status)

# The name of your S3 bucket where you want to store results
BUCKET_NAME: str = '<your-bucket-name>'

# The cluster you wish to run on - If you only have access to 1, you can leave it as None
CLUSTER: Optional[str] = '<your-cluster>'

# The metric you want to optimize
TARGET_METRIC = 'metrics/eval/Accuracy'

# The name of your study
STUDY_NAME = 'mnist-classification'


def get_base_config() -> RunConfig:

    # Download the MNIST YAML from Composer
    req = requests.get(
        'https://raw.githubusercontent.com/mosaicml/composer/v0.9.0/composer/yamls/models/classify_mnist.yaml'
    )

    print(f'Downloaded Composer\'s MNIST configuration:\n\n{req.text}')

    # Load the YAML into a parameters dictionary
    parameters = yaml.safe_load(req.text)
    # Define our command
    command = """
    wget -O entrypoint.py https://raw.githubusercontent.com/mosaicml/composer/v0.9.0/examples/run_composer_trainer.py

    composer entrypoint.py -f /mnt/config/parameters.yaml
    """
    config = RunConfig(name='mnist-classification',
                       image='mosaicml/composer:0.9.0',
                       compute={'gpus': 1, 'cluster': CLUSTER},
                       command=command,
                       parameters=parameters)
    return config


class MetricsFile():
    """Parser of metrics files output from `composer`'s `FileLogger`

    Args:
        log_level: The `composer` log level. Can be either 'epoch' or 'batch'. Default: 'epoch'
    """

    def __init__(self, log_level: str = 'epoch'):
        self.log_level = log_level
        self._last_line_no = 0

    @property
    def pattern(self) -> str:
        """Regex pattern to match metrics lines in the `composer` `FileLogger`
        """

        matching_logic = r"=(\d+)\]\[batch=(\d+)[^:]*:\s{(.*)}"
        out = r"\[stderr\]: \[{}\]\[epoch".format(self.log_level.upper()) + matching_logic

        return out

    def contains_metrics_keys(self, line) -> bool:
        """Quickly check if common metrics keys are present in the line
        """
        for pattern in ["metrics/", "loss/"]:
            if pattern in line:
                return True

        return False

    def parse_line_metrics(self, line: str) -> Optional[Dict[str, Any]]:
        """Parse an individual line of metrics

        Args:
            line: A log line from a file

        Returns:
            Optional[Dict[str, Any]]: If the line is a metrics line, returns the metrics as a dictionary.
                Returns None otherwise.
        """

        match = re.match(self.pattern, line)
        if match:
            assert len(match.groups()) == 3, f'Invalid match: {line}'
            epoch, batch, json_metrics = match.groups()

            if self.contains_metrics_keys(json_metrics):
                json_metrics = '{' + json_metrics.strip().rstrip(',') + '}'
                metrics = json.loads(json_metrics)
                metrics['epoch'] = int(epoch)
                metrics['batch'] = int(batch)
                metrics['log_level'] = self.log_level
                return metrics

    def get_new_metrics(self, text: str) -> List[Dict[str, Any]]:
        """Get any new metrics dictionaries from the log text

        Args:
            text: Log text that may include metrics that have already been parsed

        Returns:
            List[Dict[str, Any]]: A list of new metrics dictionaries
        """

        lines = text.splitlines()[self._last_line_no:]
        metrics = []
        for line in lines:
            line_metrics = self.parse_line_metrics(line)
            if line_metrics:
                metrics.append(line_metrics)
        self._last_line_no += len(lines)
        return metrics


def check_file_size(bucket: str, artifact_path: str) -> int:
    """Check the file size for the file at the specified path

    Args:
        bucket: The bucket in which the file lives
        artifact_path: Path to the file in the bucket

    Returns:
        int: File size, in bytes, of the specified file. If the file does not exist, returns 0
    """
    s3 = boto3.client('s3')
    try:
        response = s3.head_object(Bucket=bucket, Key=artifact_path)
        return int(response['ContentLength'])
    except ClientError as e:
        if '(404)' in str(e):
            # File does not exist yet, so same as empty
            return 0
        raise


def download_s3_file(bucket: str, artifact_path: str) -> str:
    """Download the specified file from S3

    Args:
        bucket: The bucket in which the file lives
        artifact_path: Path to the file in the bucket

    Returns:
        str: Text content of the file
    """

    config = Config(connect_timeout=10,
                    read_timeout=10,
                    retries={'total_max_attempts': 2})
    s3 = boto3.client('s3', config=config)

    with tempfile.TemporaryDirectory() as dirname:
        saveas = os.path.join(dirname, artifact_path)
        os.makedirs(os.path.dirname(saveas), exist_ok=True)
        try:
            s3.download_file(bucket, artifact_path, saveas)
        except ClientError as e:
            if '(404)' in str(e):
                # File does not exist yet, so same as empty
                return ''
            raise
        else:
            with open(saveas, 'r') as f:
                return f.read()


def follow_run_in_s3(run: Run,
                     bucket: str,
                     results_path: str,
                     log_level: str = 'epoch') -> Generator[Dict[str, Any], None, None]:
    """Follow the metrics for a specific run in S3

    Args:
        run: The submitted run that should be followed
        bucket: The bucket in which the file lives
        results_path: Path to the metrics file in the bucket
        log_level: The `composer` log level. Can be either 'epoch' or 'batch'. Default: 'epoch'

    Yields:
        Generator[Dict[str, Any], None, None]: A generator yielding metrics from the run
    """

    metrics_file = MetricsFile(log_level=log_level)
    last_file_size = 0
    while True:
        try:
            run = wait_for_run_status(run, status='completed', timeout=10)
            print(f'Run {run.name} has completed with status: {run.status}')
        except TimeoutError:
            pass

        file_size = check_file_size(bucket, results_path)
        if file_size != last_file_size:
            last_file_size = file_size
            text = download_s3_file(bucket, results_path)
            metrics = metrics_file.get_new_metrics(text)
            for step_metrics in metrics:
                yield step_metrics

        if run.status.after(RunStatus.COMPLETED, inclusive=True):
            return


def mnist_submit(trial: optuna.trial.Trial, log_level: str = 'epoch') -> float:

    # Create run from default YAML
    print('Getting default run')
    config = get_base_config()

    # Get hyperparameters
    smoothing = trial.suggest_float('smoothing', 0, 1)
    lr = trial.suggest_float('lr', 0.0001, 1, log=True)
    print(f'Chose smoothing: {smoothing} and lr: {lr}')

    # Update parameters
    config.parameters['algorithms'] = {'label_smoothing': {'smoothing': smoothing}}
    config.parameters['optimizers']['sgd']['lr'] = lr

    # Also, we can update the run name to include these values
    # We also replace '.' with 'p', since '.' is not a valid run name character
    config.name = f'{config.name}-lr-{lr:1.2e}-sm-{smoothing:.2f}'.replace('.', 'p')

    # Set bucket and artifact names - this will let us access results stored in S3
    artifact = f'{trial.study.study_name}/{config.name}/metrics.txt'
    loggers = config.parameters.setdefault('loggers', {})
    loggers['object_store'] = {'object_store_hparams': {'s3': {'bucket': BUCKET_NAME}}}
    loggers['file'] = {'filename': artifact}

    # Submit the run
    run = create_run(config)
    print(f'Submitted run {run.name}')

    # Wait for the run to start running
    print('Waiting for run to start')
    run = wait_for_run_status(run.name, status='running')

    # Follow the run's outputs in S3
    print('Following results in s3')
    best_value = -9999
    for metrics in follow_run_in_s3(run,
                                    bucket=BUCKET_NAME,
                                    results_path=artifact,
                                    log_level=log_level):
        if metrics and TARGET_METRIC in metrics:
            value = metrics[TARGET_METRIC]
            step = metrics['epoch']
            trial.report(value=value, step=step)
            print(f'Step: {step}, Value: {value}')

            if value > best_value:
                best_value = value

            # Stop the run if the study says we should
            if trial.should_prune():
                stop_runs([run])
                print(f'Run {run.name} was pruned. Stopping...')
                raise optuna.TrialPruned()

    return best_value


def _print_study(study: optuna.study.Study):
    print("Best params: ", study.best_params)
    print("Best value: ", study.best_value)
    print("Best Trial: ", study.best_trial)


def start_dbms_study(study_name: str, storage_path: str,
                     objective: Callable[[optuna.trial.Trial], float]):
    # Add stream handler of stdout to show the messages
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
    study: optuna.study.Study = optuna.create_study(
        study_name=study_name,
        storage=storage_path,
        direction='maximize',
        pruner=optuna.pruners.PercentilePruner(25.0,
                                               n_startup_trials=5,
                                               n_warmup_steps=30,
                                               interval_steps=10))

    study.optimize(objective, n_trials=10, n_jobs=2)

    print("================================ Finished initial study")
    _print_study(study)


def resume_dbms_study(study_name: str, storage_path: str,
                      objective: Callable[[optuna.trial.Trial], float]):
    study = optuna.create_study(study_name=study_name,
                                storage=storage_path,
                                load_if_exists=True,
                                direction='maximize')
    study.optimize(objective, n_trials=3, n_jobs=1)

    print("================================ Resumed study and got")
    _print_study(study)


def main_dbms():
    # see https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/001_rdb.html#sphx-glr-tutorial-20-recipes-001-rdb-py

    objective = mnist_submit  # train a model with composer

    # Wrap the storage in a tempdir so that you can run the script multiple times
    # without create_study() complaining that the study already exists (otherwise
    # we would need load_if_exists=True in start_dbms_study as well).
    with tempfile.TemporaryDirectory() as tmpdir:
        storage_path = "sqlite:///{}/{}.db".format(tmpdir, STUDY_NAME)
        start_dbms_study(STUDY_NAME, storage_path, objective)
        # note that resumption doesn't require the original study object;
        # only needs path to the db where the study is stored
        resume_dbms_study(STUDY_NAME, storage_path, objective)


if __name__ == "__main__":
    main_dbms()
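
Before you run the script, fill in BUCKET_NAME and CLUSTER, make sure mcli is configured for your organization, and confirm that your AWS credentials and the S3 secret described above are in place so that both your local machine and your runs can reach the bucket. The study’s trials will then show up as individual runs on the platform, with their metrics streamed back through S3.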