Pretraining#

Feature in Preview

Pretraining is currently in preview, and is liable to change significantly in the near future.

Pretraining gives you end-to-end ownership of your custom model.

Our pretraining API offers:

  1. A simple interface to our training stack for full model pretraining.

  2. Optimal default hyperparameters and model training setup.

  3. Pretrained model checkpoints saved to a remote store of your choice.

  4. The ability to customize your tokenizer.

  5. Support for training on a mix of datasets.

  6. Evaluation of your model as it pretrains.

We recommend trying pretraining if:

  • You have tried finetuning an existing model and want better results.

  • You have tried prompt engineering on an existing model and want better results.

  • You want full ownership of a custom model for data privacy.

  • You want to use your own tokenizer or vocabulary, especially for support in other languages.

Setup#

Before getting started with pretraining, make sure you have configured MosaicML access.

Data preparation and credentials#

The training data and eval data must be raw text converted to MDS format.

Note that we don't yet support reading from Unity Catalog datasets as an input; we are working on it.

Supported data sources#

If you are using a remote object store as the source of your training data, you must first create an MCLI secret with the credentials to access your data.

Note that the folder where your checkpoints are saved must also be a remote object store, which likewise requires a secret configuration. We support the following data sources:

| Data Source | Example | MCLI Secret |
| --- | --- | --- |
| AWS S3 | s3://bucket/... | AWS S3 |
| OCI | oci://bucket/... | OCI |
| GCP | gs://bucket/... | GCP |
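For S3, for example, your credentials are registered through the MosaicML CLI before launching a run. A minimal sketch using the `s3` secret type; check `mcli create secret s3 --help` in your installed version, since the exact flags may differ:

```shell
# Register AWS credentials as an MCLI secret so the platform can read
# your training data and write checkpoints (flag names may vary by version).
mcli create secret s3 \
  --config-file ~/.aws/config \
  --credentials-file ~/.aws/credentials
```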

Supported models#

We currently support pretraining on the following suite of models with a maximum context length of 4096:

| Model | Parameters | Suggested tokens | Time to train with suggested tokens |
| --- | --- | --- | --- |
| databricks/dbrx-9b | 9.2B total, 2.6B active | 200B tok | 1 day (128 H100s) |
| databricks/dbrx-18b | 18.6B total, 5.2B active | 400B tok | 4 days (128 H100s) |
| databricks/dbrx-35b | 35.7B total, 9.9B active | 700B tok | 7 days (256 H100s) |
| databricks/dbrx-73b | 73.5B total, 20.1B active | 1.5T tok | 16 days (512 H100s) |
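For instance, a run targeting the smallest model and its suggested token budget might look like the following sketch. The `training_duration` key and its `tok` unit are assumptions inferred from the `--training-duration` CLI override shown elsewhere on this page; verify them against the Pretraining Schema:

```yaml
# Sketch: dbrx-9b trained for its suggested 200B-token budget.
# `training_duration` and the `tok` unit are assumptions based on the
# --training-duration CLI flag; verify against the Pretraining Schema.
model: databricks/dbrx-9b
train_data: s3://<my-bucket>/data
save_folder: s3://<my-bucket>/checkpoints
training_duration: 200000000000tok
compute:
  cluster: <cluster_name>
  gpus: 128
```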

A quick example#

Here is a minimal example of pretraining a model on a dataset. Save the following as pretrain.yaml:

model: databricks/dbrx-9b
train_data: s3://<my-bucket>/data
save_folder: s3://<my-bucket>/checkpoints
compute: 
  cluster: <cluster_name>
  gpus: 128

You can then launch this run and save checkpoints to your S3 bucket with the following command:

mcli train -f pretrain.yaml

You can also pass overrides for the mandatory and optional fields to the YAML via the CLI command:

mcli train -f pretrain.yaml \
  --model databricks/dbrx-9b \
  --train-data s3://<my-bucket>/data \
  --training-duration 10000tok

Experiment tracking#

We support both MLflow and WandB as experiment trackers to monitor and visualize the metrics for your pretraining run. Set experiment_tracker to contain the configuration for the tracker you want to use.

MLflow#

Provide the full path for the experiment, including the experiment name. In Databricks Managed MLflow, this will be a workspace path resembling /Users/example@domain.com/my_experiment. You can also provide a model_registry_path for model deployment. Make sure to configure your Databricks secret.

experiment_tracker:
  mlflow:
    experiment_path: /Users/example@domain.com/my_experiment
    model_registry_path: catalog.schema | catalog.schema.model_name # optional
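Your Databricks credentials are registered through the MosaicML CLI. A minimal sketch, assuming the `databricks` secret type prompts interactively for your workspace URL and personal access token:

```shell
# Store Databricks workspace credentials for MLflow logging
# (prompts for host and token).
mcli create secret databricks
```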

Weights & Biases#

Include both project name and entity name in your configuration, and make sure to set up your WandB secret.

experiment_tracker:
  wandb:
    project: my-project
    entity: my-entity
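The WandB API key is supplied to the run as an MCLI secret. A sketch assuming the generic environment-variable secret type; substitute your own key:

```shell
# Expose the WandB API key to the run as an environment secret.
mcli create secret env WANDB_API_KEY=<your-api-key>
```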

Launching a pretraining run#

You can launch a pretraining run either by calling the pretrain API through the SDK, or by submitting a YAML with mcli train -f <your-yaml>. Refer to the example above, and see the Pretraining Schema for more information about the parameters of the pretraining API.

The SDK result is a Run object.

mcli.Run(run_uid, name, status, created_at, updated_at, created_by, priority, preemptible, retry_on_system_failure, cluster, gpus, gpu_type, cpus, node_count, latest_resumption, is_deleted, run_type, max_retries=None, reason=None, nodes=<factory>, submitted_config=None, metadata=None, last_resumption_id=None, resumptions=<factory>, events=<factory>, lifecycle=<factory>, image=None, max_duration=None)

A run that has been launched on the MosaicML platform

Parameters
  • run_uid (str) – Unique identifier for the run

  • name (str) – User-defined name of the run

  • status (RunStatus) – Status of the run at a moment in time

  • created_at (datetime) – Date and time when the run was created

  • updated_at (datetime) – Date and time when the run was last updated

  • created_by (str) – Email of the user who created the run

  • priority (str) – Priority of the run; defaults to auto but can be updated to low or lowest

  • preemptible (bool) – Whether the run can be stopped and re-queued by higher priority jobs

  • retry_on_system_failure (bool) – Whether the run should be retried on system failure

  • cluster (str) – Cluster the run is running on

  • gpus (int) – Number of GPUs the run is using

  • gpu_type (str) – Type of GPU the run is using

  • cpus (int) – Number of CPUs the run is using

  • node_count (int) – Number of nodes the run is using

  • latest_resumption (Resumption) – Latest resumption of the run

  • max_retries (Optional[int]) – Maximum number of times the run can be retried

  • reason (Optional[str]) – Reason the run was stopped

  • nodes (List[Node]) – Nodes the run is using

  • submitted_config (Optional[RunConfig]) – Submitted run configuration

  • metadata (Optional[Dict[str, Any]]) – Metadata associated with the run

  • last_resumption_id (Optional[str]) – ID of the last resumption of the run

  • resumptions (List[Resumption]) – Resumptions of the run

  • lifecycle (List[RunLifecycle]) – Lifecycle of the run

  • image (Optional[str]) – Image the run is using

See the Pretraining CLI and Pretraining SDK for more information on how to interact with your pretraining runs.
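As a sketch of that workflow, the standard MCLI run commands apply to pretraining runs as well (the run name is a placeholder):

```shell
# List your runs and their statuses.
mcli get runs

# Stream logs from a specific run.
mcli logs <run-name>

# Stop a run that is no longer needed.
mcli stop run <run-name>
```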

Looking for more configurability over model training? Try creating a training run instead, and see the LLM Foundry pretraining documentation for more details.


Want to evaluate your model?#

Our pretraining API provides a lightweight solution that runs evaluation during pretraining. Set eval.data_path to the remote location of your evaluation data (e.g. s3://my-bucket/my-data.jsonl). This data should be in the same format as your training data; see the file format instructions above. We compute Cross Entropy and Perplexity on this evaluation data.
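In the run YAML, this can be sketched as follows (the field name comes from this page; the bucket path is a placeholder):

```yaml
# Run Cross Entropy / Perplexity evaluation during pretraining.
eval:
  data_path: s3://<my-bucket>/my-data.jsonl
```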

For complete evaluation after pretraining, see our LLM evaluation framework for open-source In-context learning (ICL) tasks.


Help us improve!#

We’re eager to hear your feedback! If our Pretraining API doesn’t meet your needs, please let us know so we can prioritize future enhancements to better support you. Your input is invaluable in shaping our API’s growth and development!