Pretraining#
Feature in Preview
Pretraining is currently in preview, and is liable to change significantly in the near future.
Pretraining gives you end-to-end ownership of your custom model.
Our pretraining API offers:
- A simple interface to our training stack for full model pretraining.
- Optimal default hyperparameters and model training setup.
- Pretrained model checkpoints saved to a remote store of your choice.
- The ability to customize your tokenizer.
- Support for training on a mix of datasets.
- Evaluation while your model pretrains.
We recommend trying pretraining if:
- You have tried finetuning an existing model and want better results.
- You have tried prompt engineering on an existing model and want better results.
- You want full ownership of a custom model for data privacy.
- You want to use your own tokenizer or vocabulary, especially for support in other languages.
Setup#
Before getting started with pretraining, make sure you have configured MosaicML access.
Data preparation and credentials#
The training and eval data required by the API is raw text, converted to MDS format.
Note that we do not yet support reading from Unity Catalog datasets as an input; we are working on it.
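As a rough illustration of the conversion step, here is a minimal sketch using the `MDSWriter` class from the `mosaicml-streaming` package; the output path and the single-column `text` layout are illustrative assumptions, not a prescribed schema:

```python
# Sketch: converting raw text to MDS format with the mosaicml-streaming
# package (pip install mosaicml-streaming). Paths and the column layout
# are illustrative assumptions.
try:
    from streaming import MDSWriter
except ImportError:  # package not installed; skip the write step
    MDSWriter = None

# Raw-text samples; in practice you would read these from your corpus files.
samples = [
    {"text": "First training document."},
    {"text": "Second training document."},
]

if MDSWriter is not None:
    # 'out' may be a local directory or a remote object-store path.
    with MDSWriter(out="./mds-train", columns={"text": "str"}) as writer:
        for sample in samples:
            writer.write(sample)
```

Point `train_data` at the resulting MDS directory (local or remote) when configuring your run.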
Supported data sources#
If you are using a remote object store as the source of your training data, you must first create an MCLI secret with the credentials to access your data.
Note that the folder where your checkpoints are saved must also be a remote object store, which likewise requires secret configuration. We support the following data sources:
Supported models#
We currently support pretraining on the following suite of models, with a maximum context length of 4096:

| Model | Parameters | Suggested tokens | Time to train with suggested tokens |
| --- | --- | --- | --- |
|  | 9.2B total, 2.6B active | 200B tok | 1 day (128 H100s) |
|  | 18.6B total, 5.2B active | 400B tok | 4 days (128 H100s) |
|  | 35.7B total, 9.9B active | 700B tok | 7 days (256 H100s) |
|  | 73.5B total, 20.1B active | 1.5T tok | 16 days (512 H100s) |
A quick example#
Here is a minimal example of pretraining a model on a dataset.
model: databricks/dbrx-9b
train_data: s3://<my-bucket>/data
save_folder: s3://<my-bucket>/checkpoints
compute:
  cluster: <cluster_name>
  gpus: 128
You can then launch this run and save checkpoints to your S3 bucket with the following command:
mcli train -f pretrain.yaml
You can also override the yaml's mandatory and optional fields via the CLI command:
mcli train -f pretrain.yaml \
--model databricks/dbrx-9b \
--train-data s3://<my-bucket>/data \
--training-duration 10000tok
Experiment tracking#
We support both MLflow and WandB as experiment trackers to monitor and visualize the metrics for your pretraining run. Set experiment_tracker to contain the configuration for the tracker you want to use.
MLflow#
Provide the full path for the experiment, including the experiment name. In Databricks Managed MLflow, this will be a workspace path resembling /Users/example@domain.com/my_experiment. You can also provide a model_registry_path for model deployment. Make sure to configure your Databricks secret.
experiment_tracker:
mlflow:
experiment_path: /Users/example@domain.com/my_experiment
model_registry_path: catalog.schema | catalog.schema.model_name # optional
Weights & Biases#
Include both project name and entity name in your configuration, and make sure to set up your WandB secret.
experiment_tracker:
wandb:
project: my-project
entity: my-entity
Launching a pretraining run#
Calling the pretrain API launches your run using the SDK, while the yaml needs to be launched with mcli train -f <your-yaml>. Refer to the example above and see the Pretraining Schema for more information about the parameters for the pretraining API.
The SDK result is a Run object.
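Once you have a Run object in hand, you can read its documented fields directly. The `describe` helper and the stand-in object below are purely illustrative; a real Run comes back from launching or fetching a run via the mcli SDK:

```python
# Sketch: summarizing a Run object returned by the SDK. The `describe`
# helper and the SimpleNamespace stand-in are illustrative assumptions;
# the attribute names (name, status, cluster, gpus) are documented Run fields.
from types import SimpleNamespace

def describe(run) -> str:
    # Format a one-line summary from a few documented Run fields.
    return f"{run.name}: {run.status} on {run.cluster} ({run.gpus} GPUs)"

# Stand-in with the same attribute names, for illustration only.
fake_run = SimpleNamespace(name="pretrain-dbrx", status="RUNNING",
                           cluster="my-cluster", gpus=128)
print(describe(fake_run))  # pretrain-dbrx: RUNNING on my-cluster (128 GPUs)
```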
- mcli.Run(run_uid, name, status, created_at, updated_at, created_by, priority, preemptible, retry_on_system_failure, cluster, gpus, gpu_type, cpus, node_count, latest_resumption, is_deleted, run_type, max_retries=None, reason=None, nodes=<factory>, submitted_config=None, metadata=None, last_resumption_id=None, resumptions=<factory>, events=<factory>, lifecycle=<factory>, image=None, max_duration=None, _required_properties=('id', 'name', 'status', 'createdAt', 'updatedAt', 'reason', 'createdByEmail', 'priority', 'preemptible', 'retryOnSystemFailure', 'resumptions', 'isDeleted', 'runType'))
A run that has been launched on the MosaicML platform
- Parameters
run_uid (str) – Unique identifier for the run
name (str) – User-defined name of the run
status (RunStatus) – Status of the run at a moment in time
created_at (datetime) – Date and time when the run was created
updated_at (datetime) – Date and time when the run was last updated
created_by (str) – Email of the user who created the run
priority (str) – Priority of the run; defaults to auto but can be updated to low or lowest
preemptible (bool) – Whether the run can be stopped and re-queued by higher priority jobs
retry_on_system_failure (bool) – Whether the run should be retried on system failure
cluster (str) – Cluster the run is running on
gpus (int) – Number of GPUs the run is using
gpu_type (str) – Type of GPU the run is using
cpus (int) – Number of CPUs the run is using
node_count (int) – Number of nodes the run is using
latest_resumption (Resumption) – Latest resumption of the run
max_retries (Optional[int]) – Maximum number of times the run can be retried
reason (Optional[str]) – Reason the run was stopped
nodes (List[Node]) – Nodes the run is using
submitted_config (Optional[RunConfig]) – Submitted run configuration
metadata (Optional[Dict[str, Any]]) – Metadata associated with the run
last_resumption_id (Optional[str]) – ID of the last resumption of the run
resumptions (List[Resumption]) – Resumptions of the run
lifecycle (List[RunLifecycle]) – Lifecycle of the run
image (Optional[str]) – Image the run is using
See the Pretraining CLI and Pretraining SDK for more information on how to interact with your pretraining runs.
Looking for more configurability over the model training? Try creating a training run instead and see the LLM foundry pretraining documentation for more details.
Want to evaluate your model?#
Our pretraining API provides a lightweight solution that runs evaluation during pretraining, configured under eval.data_path, which should point to the remote location of your evaluation data (e.g. s3://my-bucket/my-data.jsonl). This should be in the same format as your training data; see the file format instructions above. We will compute cross entropy and perplexity on this evaluation data.
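Based on the description above, the eval block in your pretraining yaml might look like the following sketch; the bucket path is a placeholder, and the exact key layout should be checked against the Pretraining Schema:

```yaml
eval:
  # Remote location of your evaluation data, in the same format as your
  # training data. Cross entropy and perplexity are computed on it.
  data_path: s3://<my-bucket>/eval-data
```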
For complete evaluation after pretraining, see our LLM evaluation framework for open-source In-context learning (ICL) tasks.
Help us improve!#
We're eager to hear your feedback! If our Pretraining API doesn't meet your needs, please let us know so we can prioritize future enhancements to better support you. Your input is invaluable in shaping our API's growth and development!