Pretraining

Feature in Preview

Pretraining is currently in preview, and is liable to change significantly in the near future.

Pretraining gives you end-to-end ownership of your custom model.

Our pretraining API offers:

  1. A simple interface to our training stack for full model pretraining.

  2. Tuned default hyperparameters and model training setup.

  3. Pretrained model checkpoints saved to a remote object store of your choice.

  4. The ability to customize your tokenizer.

  5. The ability to train on a mix of datasets.

  6. Evaluation of your model while it pretrains.

We recommend trying pretraining if:

  • You have tried finetuning an existing model and want better results.

  • You have tried prompt engineering on an existing model and want better results.

  • You want full ownership over a custom model for data privacy.

  • You want to use your own tokenizer or vocabulary, especially to support other languages.

Setup

Before getting started with pretraining, make sure you have configured MosaicML access.

Data preparation and credentials

The API requires training and evaluation data as raw text converted to MDS (Mosaic Data Shard) format.

Reading input datasets from Unity Catalog is not supported yet; we are working on it.
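The sketch below shows one way to do the conversion. It assumes the open-source mosaicml-streaming package (`pip install mosaicml-streaming`) and JSON Lines input. The column name and type the pretraining API expects (here, a single `text` column of type `str`) are assumptions, so confirm them against the Pretraining Schema. `iter_text_samples` and `write_mds` are hypothetical helper names.

```python
import json
from typing import Iterable, Iterator


def iter_text_samples(lines: Iterable[str], field: str = "text") -> Iterator[dict]:
    """Yield {"text": ...} samples from JSON Lines input, skipping blank or empty records."""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if not isinstance(record, dict):
            continue
        text = record.get(field)
        if isinstance(text, str) and text:
            yield {"text": text}


def write_mds(samples: Iterable[dict], out: str) -> None:
    """Write samples as MDS shards to `out` (a local directory or a remote URI)."""
    # Imported lazily so iter_text_samples works without the package installed.
    from streaming import MDSWriter

    # Assumption: one raw-text column named "text", stored as a string.
    with MDSWriter(out=out, columns={"text": "str"}, compression="zstd") as writer:
        for sample in samples:
            writer.write(sample)
```

For example: `write_mds(iter_text_samples(open("corpus.jsonl")), "s3://<my-bucket>/data")`. Writing directly to a remote URI uses your local cloud credentials, not MCLI secrets. You can also write to a local directory and upload the shards afterwards.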

Supported data sources

If you are using a remote object store as the source of your training data, you must first create an MCLI secret with the credentials to access your data.

The folder where checkpoints are saved must also be in a remote object store, which likewise requires an MCLI secret. We support the following data sources:

Data Source | Example          | MCLI Secret
------------|------------------|------------
AWS S3      | s3://bucket/...  | AWS S3
OCI         | oci://bucket/... | OCI
GCP         | gs://bucket/...  | GCP
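The table amounts to a lookup from URI scheme to secret type. The hypothetical helper below encodes only that lookup, so you can check which secret a `train_data` or `save_folder` URI will need. It does not create or verify secrets.

```python
from urllib.parse import urlparse

# Scheme -> MCLI secret type, copied from the supported data sources table.
SCHEME_TO_SECRET = {
    "s3": "AWS S3",
    "oci": "OCI",
    "gs": "GCP",
}


def required_secret(uri: str) -> str:
    """Return the MCLI secret type needed to access `uri`."""
    scheme = urlparse(uri).scheme.lower()
    if scheme not in SCHEME_TO_SECRET:
        raise ValueError(
            f"unsupported object store URI: {uri!r} (expected s3://, oci://, or gs://)"
        )
    return SCHEME_TO_SECRET[scheme]
```

For example, `required_secret("s3://<my-bucket>/checkpoints")` returns `"AWS S3"`. A local path such as `/tmp/data` raises `ValueError`.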

Supported models

We currently support pretraining the following suite of models, with a maximum context length of 4096 tokens:

Model               | Parameters                | Suggested tokens | Time to train with suggested tokens
--------------------|---------------------------|------------------|------------------------------------
databricks/dbrx-9b  | 9.2B total, 2.6B active   | 200B             | 1 day (128 H100s)
databricks/dbrx-18b | 18.6B total, 5.2B active  | 400B             | 4 days (128 H100s)
databricks/dbrx-35b | 35.7B total, 9.9B active  | 700B             | 7 days (256 H100s)
databricks/dbrx-73b | 73.5B total, 20.1B active | 1.5T             | 16 days (512 H100s)
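To budget compute, you can turn the table into GPU-hours and an implied average throughput. This is plain arithmetic on the table's rounded figures, so the throughput numbers are rough estimates, not measured values.

```python
# (model, suggested tokens, wall-clock days, H100 GPUs), copied from the table.
SUGGESTED = [
    ("databricks/dbrx-9b", 200e9, 1, 128),
    ("databricks/dbrx-18b", 400e9, 4, 128),
    ("databricks/dbrx-35b", 700e9, 7, 256),
    ("databricks/dbrx-73b", 1.5e12, 16, 512),
]


def gpu_hours(days: float, gpus: int) -> float:
    """Total GPU-hours for `days` of wall-clock time on `gpus` GPUs."""
    return days * 24 * gpus


def implied_tokens_per_gpu_second(tokens: float, days: float, gpus: int) -> float:
    """Average tokens per second per GPU implied by the table (an estimate)."""
    return tokens / (gpu_hours(days, gpus) * 3600)


for name, tokens, days, gpus in SUGGESTED:
    print(
        f"{name}: {gpu_hours(days, gpus):,.0f} GPU-hours, "
        f"~{implied_tokens_per_gpu_second(tokens, days, gpus):,.0f} tokens/s per GPU"
    )
```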

A quick example

Here is a minimal example config, pretrain.yaml, for pretraining a model on a dataset:

model: databricks/dbrx-9b
train_data: s3://<my-bucket>/data
save_folder: s3://<my-bucket>/checkpoints
compute: 
  cluster: <cluster_name>
  gpus: 128

You can then launch this run and save checkpoints to your S3 bucket with the following command:

mcli train -f pretrain.yaml
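Before launching, you can run a quick local sanity check on the config. The sketch below assumes the keys used in the quick example (`model`, `train_data`, `save_folder`, `compute`) are required. It also checks the two rules stated on this page: the model must be one of the supported models, and `save_folder` must be a remote object store. `check_pretrain_config` is a hypothetical helper and not part of MCLI.

```python
SUPPORTED_MODELS = {
    "databricks/dbrx-9b",
    "databricks/dbrx-18b",
    "databricks/dbrx-35b",
    "databricks/dbrx-73b",
}
REMOTE_PREFIXES = ("s3://", "oci://", "gs://")
# Assumption: the fields used in the quick example are required for every run.
REQUIRED_KEYS = ("model", "train_data", "save_folder", "compute")


def check_pretrain_config(config: dict) -> list[str]:
    """Return a list of problems found in a pretraining config (empty if none)."""
    problems = [f"missing required field: {key}" for key in REQUIRED_KEYS if key not in config]
    model = config.get("model")
    if model is not None and model not in SUPPORTED_MODELS:
        problems.append(f"unsupported model: {model}")
    # Checkpoints must be saved to a remote object store.
    save_folder = config.get("save_folder")
    if save_folder is not None and not str(save_folder).startswith(REMOTE_PREFIXES):
        problems.append(f"save_folder must be a remote object store URI, got {save_folder!r}")
    return problems
```

You can load pretrain.yaml into a dict with a YAML parser such as PyYAML (`yaml.safe_load`) and pass the result to the helper.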

You can also override both mandatory and optional fields of the YAML from the command line:

mcli train -f pretrain.yaml \
  --model databricks/dbrx-9b \
  --train-data s3://<my-bucket>/data \
  --training-duration 10000tok
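If you launch runs from a script, you can build the same override command programmatically. The sketch assumes, based only on the example above, that each snake_case field maps to a `--kebab-case` flag (`train_data` to `--train-data`). `build_train_command` is a hypothetical helper.

```python
import shlex


def build_train_command(yaml_path: str, overrides: dict) -> list[str]:
    """Build an `mcli train` argument list with CLI overrides for YAML fields."""
    cmd = ["mcli", "train", "-f", yaml_path]
    for field, value in overrides.items():
        # Assumption: flag names are the field names with "_" replaced by "-".
        cmd += ["--" + field.replace("_", "-"), str(value)]
    return cmd


cmd = build_train_command(
    "pretrain.yaml",
    {
        "model": "databricks/dbrx-9b",
        "train_data": "s3://<my-bucket>/data",
        "training_duration": "10000tok",
    },
)
print(shlex.join(cmd))  # shell-quoted form of the command
```

Passing the list to `subprocess.run(cmd, check=True)` runs it without going through a shell, so placeholder characters such as `<` and `>` need no quoting.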

Experiment tracking

We support both MLflow and Weights & Biases (WandB) as experiment trackers to monitor and visualize the metrics of your pretraining run. Add an experiment_tracker section containing the configuration for the tracker you want to use.

MLflow

Provide the full path for the experiment, including the experiment name. In Databricks Managed MLflow, this is a workspace path such as /Users/example@domain.com/my_experiment. You can optionally provide a model_registry_path, in the form catalog.schema or catalog.schema.model_name, to register the model for deployment. Make sure to configure your Databricks secret.

experiment_tracker:
  mlflow:
    experiment_path: /Users/example@domain.com/my_experiment
    model_registry_path: catalog.schema.model_name # optional; catalog.schema also works
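Because model_registry_path accepts two shapes, a small check before launch can catch typos. The hypothetical helper below only validates the shape. It does not check that the catalog or schema exists, and it assumes names contain no dots.

```python
def parse_registry_path(path: str) -> dict:
    """Split a model_registry_path into catalog, schema, and optional model name."""
    parts = path.split(".")
    if len(parts) not in (2, 3) or not all(parts):
        raise ValueError(
            f"expected catalog.schema or catalog.schema.model_name, got {path!r}"
        )
    return {
        "catalog": parts[0],
        "schema": parts[1],
        "model_name": parts[2] if len(parts) == 3 else None,
    }
```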

Weights & Biases

Include both the project name and the entity name in your configuration, and make sure to set up your WandB secret.

experiment_tracker:
  wandb:
    project: my-project
    entity: my-entity
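If you generate configs from code, a hypothetical helper can produce either tracker section and check the fields this page calls for: `experiment_path` for MLflow, and `project` and `entity` for WandB. Any other fields are passed through unchecked. Consult the Pretraining Schema for the full set of fields.

```python
# Fields this page says each tracker needs.
REQUIRED_TRACKER_FIELDS = {
    "mlflow": ("experiment_path",),
    "wandb": ("project", "entity"),
}


def experiment_tracker_section(tracker: str, **fields: str) -> dict:
    """Return an experiment_tracker section for one tracker after checking required fields."""
    if tracker not in REQUIRED_TRACKER_FIELDS:
        raise ValueError(f"unknown tracker {tracker!r}; expected 'mlflow' or 'wandb'")
    missing = [name for name in REQUIRED_TRACKER_FIELDS[tracker] if not fields.get(name)]
    if missing:
        raise ValueError(f"{tracker} tracker is missing: {', '.join(missing)}")
    return {"experiment_tracker": {tracker: dict(fields)}}
```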

Launching a pretraining run

You can launch a pretraining run either from the SDK, by calling the pretrain API, or from the CLI, by passing your YAML to mcli train -f <your-yaml>. Refer to the example above, and see the Pretraining Schema for details on the parameters of the pretraining API.

The SDK call returns a Run object:

mcli.Run(run_uid, name, status, created_at, updated_at, created_by, priority, preemptible, retry_on_system_failure, cluster, gpus, gpu_type, cpus, node_count, latest_resumption, is_deleted, run_type, max_retries=None, reason=None, nodes=<factory>, submitted_config=None, metadata=None, last_resumption_id=None, resumptions=<factory>, events=<factory>, lifecycle=<factory>, image=None, max_duration=None, _required_properties=('id', 'name', 'status', 'createdAt', 'updatedAt', 'reason', 'createdByEmail', 'priority', 'preemptible', 'retryOnSystemFailure', 'resumptions', 'isDeleted', 'runType'))

A run that has been launched on the MosaicML platform

Parameters
  • run_uid (str) – Unique identifier for the run

  • name (str) – User-defined name of the run

  • status (RunStatus) – Status of the run at a moment in time

  • created_at (datetime) – Date and time when the run was created

  • updated_at (datetime) – Date and time when the run was last updated

  • created_by (str) – Email of the user who created the run

  • priority (str) – Priority of the run; defaults to auto but can be updated to low or lowest

  • preemptible (bool) – Whether the run can be stopped and re-queued by higher priority jobs

  • retry_on_system_failure (bool) – Whether the run should be retried on system failure

  • cluster (str) – Cluster the run is running on

  • gpus (int) – Number of GPUs the run is using

  • gpu_type (str) – Type of GPU the run is using

  • cpus (int) – Number of CPUs the run is using

  • node_count (int) – Number of nodes the run is using

  • latest_resumption (Resumption) – Latest resumption of the run

  • max_retries (Optional[int]) – Maximum number of times the run can be retried

  • reason (Optional[str]) – Reason the run was stopped

  • nodes (List[Node]) – Nodes the run is using

  • submitted_config (Optional[RunConfig]) – Submitted run configuration

  • metadata (Optional[Dict[str, Any]]) – Metadata associated with the run

  • last_resumption_id (Optional[str]) – ID of the last resumption of the run

  • resumptions (List[Resumption]) – Resumptions of the run

  • lifecycle (List[RunLifecycle]) – Lifecycle of the run

  • image (Optional[str]) – Image the run is using
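As an example of working with the returned object, the sketch below formats a one-line status summary. It uses only attributes documented above: name, status, cluster, gpus, gpu_type, node_count, created_at, and updated_at. It is duck-typed, so the usage runs against a made-up `SimpleNamespace` stand-in whose values are illustrative only. `summarize_run` is a hypothetical helper.

```python
from datetime import datetime
from types import SimpleNamespace


def summarize_run(run) -> str:
    """One-line summary built from attributes documented on mcli.Run."""
    elapsed = run.updated_at - run.created_at  # datetime difference -> timedelta
    return (
        f"{run.name} [{run.status}] on {run.cluster}: "
        f"{run.gpus}x {run.gpu_type} across {run.node_count} node(s), "
        f"last updated {elapsed} after creation"
    )


# Illustrative stand-in; a real mcli.Run exposes the same attribute names.
example = SimpleNamespace(
    name="dbrx-9b-pretrain",
    status="RUNNING",
    cluster="my-cluster",
    gpus=128,
    gpu_type="H100",
    node_count=16,
    created_at=datetime(2024, 5, 1, 9, 0),
    updated_at=datetime(2024, 5, 1, 12, 30),
)
summary = summarize_run(example)
print(summary)
```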

See the Pretraining CLI and Pretraining SDK for more information on how to interact with your pretraining runs.

Looking for more control over model training? Create a training run instead, and see the LLM Foundry pretraining documentation for details.


Evaluate your model

Our pretraining API provides lightweight evaluation during pretraining, configured with the eval.data_path and eval.prompts fields. See the configuration spec for details.

For a complete evaluation after pretraining, our open-source LLM evaluation framework supports in-context learning (ICL) tasks. Example DBRX configurations are available in scripts/eval/yamls/dbrx-gauntlet.

You can run these evaluations with MCLI using the following YAML:

name: eval-dbrx
image: mosaicml/llm-foundry:2.3.0_cu121_flash2-latest
integrations:
- integration_type: git_repo
  git_repo: mosaicml/llm-foundry
  git_branch: v0.8.0
  pip_install: .[gpu,openai]
  ssh_clone: false 

# # Uncomment to enable WandB Integration
# - integration_type: wandb
#   entity: my-entity
#   project: my-project

# # Uncomment to enable MLFlow Integration
# - integration_type: mlflow
#   experiment_name: /Users/example@domain.com/my-experiment

command: |-
  cd llm-foundry/scripts
  LOGGERS=""
  # LOGGERS="$LOGGERS loggers.mlflow={}" # Uncomment to enable mlflow
  # LOGGERS="$LOGGERS loggers.wandb={}" # Uncomment to enable wandb
  composer eval/eval.py eval/yamls/dbrx-gauntlet/dbrx-9b.yaml \
    models[0].load_path=TODO_CHECKPOINT $LOGGERS

To use:

  1. Uncomment the integration and logger lines to enable WandB and/or MLflow.

  2. If you are not using dbrx-9b, replace eval/yamls/dbrx-gauntlet/dbrx-9b.yaml with the config for your model.

  3. If you did not use the default tokenizer, update the tokenizer in the eval config.

  4. Replace TODO_CHECKPOINT with your checkpoint path.

  5. Launch the run with mcli run -f <your-eval-yaml> --gpus 8 --cluster <cluster_name>.
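Steps 2 and 4 are plain text substitutions, so they can be scripted. The hypothetical helper below assumes each supported size has a config named `<short model name>.yaml` in eval/yamls/dbrx-gauntlet. Check the repository before relying on that. It does not handle the tokenizer change from step 3.

```python
EVAL_CONFIG_DIR = "eval/yamls/dbrx-gauntlet"


def render_eval_yaml(template: str, model: str, checkpoint: str) -> str:
    """Fill the eval YAML template for a given model and checkpoint path."""
    short_name = model.split("/")[-1]  # e.g. databricks/dbrx-18b -> dbrx-18b
    # Assumption: the gauntlet config file is named after the model size.
    rendered = template.replace(
        f"{EVAL_CONFIG_DIR}/dbrx-9b.yaml", f"{EVAL_CONFIG_DIR}/{short_name}.yaml"
    )
    if "TODO_CHECKPOINT" not in rendered:
        raise ValueError("template has no TODO_CHECKPOINT placeholder")
    return rendered.replace("TODO_CHECKPOINT", checkpoint)
```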

This is meant as a starting template; feel free to modify it or add tasks!


Help us improve!

We’re eager to hear your feedback! If our Pretraining API doesn’t meet your needs, please let us know so we can prioritize future enhancements to better support you. Your input is invaluable in shaping our API’s growth and development!