# Source code for composer.callbacks.generate

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations from a set of prompts."""

import logging
import time
from typing import Any, List, Optional, Union, cast

from composer.core import Callback, Event, State, Time, get_precision_context
from composer.loggers import Logger
from composer.models import HuggingFaceModel
from composer.utils import create_interval_scheduler, dist
from composer.utils.import_helpers import MissingConditionalImportError

# Module-level logger; the Generate callback uses it to report how long each generation pass took.
log = logging.getLogger(__name__)


class Generate(Callback):
    """Periodically log generations from a set of prompts.

    Args:
        prompts (List[str]): The list of prompts you would like to produce generations for
        interval (Union[str, int, :class:`.Time`]): The interval describing how often generations should be
            logged. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
            Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        batch_size (Optional[int]): Size of a prompt batch for generation. If None, defaults to the
            number of prompts.
        kwargs: All kwargs will be passed along to the call to generate. This is for things like
            `do_sample`, `top_p`, etc

    Raises:
        MissingConditionalImportError: If ``transformers`` cannot be imported.
        ValueError: If ``batch_size`` is provided and is not a positive integer.
    """

    def __init__(
        self,
        prompts: List[str],
        interval: Union[str, int, Time],
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ):
        # Fail early with an actionable message if the optional NLP dependency is missing.
        try:
            import transformers
        except ImportError as e:
            raise MissingConditionalImportError(
                extra_deps_group='nlp',
                conda_package='transformers',
                conda_channel='conda-forge',
            ) from e
        del transformers

        # A zero step would make ``range`` raise inside ``generate`` and a negative step would
        # silently produce no batches, so reject both at construction time.
        if batch_size is not None and batch_size <= 0:
            raise ValueError(f'batch_size must be a positive integer, but got {batch_size}')

        self.prompts = prompts
        self.generate_kwargs = kwargs
        self.batch_size = batch_size if batch_size is not None else len(prompts)
        self.check_interval = create_interval_scheduler(interval, include_end_of_training=True)
        # Batch timestamp of the most recent generation; used to avoid generating twice on the same batch.
        self.last_generate_batch: Optional[Time] = None

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        """Run generation when the interval scheduler fires and this batch has not been handled yet.

        Args:
            event (Event): The event currently being dispatched.
            state (State): The training state.
            logger (Logger): Logger that receives the generations table.
        """
        should_generate = (
            state.get_elapsed_duration() is not None and self.check_interval(state, event) and
            self.last_generate_batch != state.timestamp.batch
        )
        if should_generate:
            # perf_counter is monotonic, so the reported duration is immune to wall-clock adjustments.
            start = time.perf_counter()
            self.generate(state, logger)
            diff = time.perf_counter() - start
            log.info(f'Generate callback ran in {diff} seconds for {len(self.prompts)} prompts')

    def generate(self, state: State, logger: Logger):
        """Generate a continuation for every prompt and log them as a table on global rank 0.

        The model is temporarily put in eval mode and the tokenizer temporarily set to left padding;
        both are restored afterwards, even if generation raises.

        Args:
            state (State): The training state; provides the model, device and precision.
            logger (Logger): Logger that receives the ``generations`` table.

        Raises:
            ValueError: If the model is not a :class:`.HuggingFaceModel` or has no tokenizer.
        """
        self.last_generate_batch = state.timestamp.batch

        model = state.model.module if state.is_model_ddp else state.model
        if not isinstance(model, HuggingFaceModel):  # TODO: Extend to support any models that have a generate method.
            raise ValueError(f'Expected HuggingFaceModel, but got {model.__class__.__name__}')

        if not hasattr(model, 'tokenizer') or model.tokenizer is None:
            raise ValueError(
                f'Model {model.__class__.__name__} does not have a tokenizer which is required for generation.',
            )
        tokenizer = model.tokenizer

        from transformers import PreTrainedTokenizerBase
        tokenizer = cast(PreTrainedTokenizerBase, tokenizer)

        # Nothing to generate; also avoids ``range`` being called with a step of zero below.
        if not self.prompts:
            return

        device = state.device

        # Stash everything that is temporarily overridden so the ``finally`` block can restore it.
        original_mode = model.training
        original_padding_side = tokenizer.padding_side
        try:
            model.eval()

            # Generation requires left padding.
            tokenizer.padding_side = 'left'
            # NOTE(review): this pad token assignment is not reverted afterwards (same as before).
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenized_input = tokenizer(self.prompts, return_tensors='pt', padding=True)

            all_input_ids = tokenized_input['input_ids']
            all_attn_masks = tokenized_input['attention_mask']

            output_token_ids = []
            # dummy forward call needed for FSDP to work consistently
            model.dummy_forward_called = False

            n_prompts = len(self.prompts)
            # Loop-invariant, so evaluate it once instead of once per batch.
            synced_gpus = dist.get_world_size() > 1
            for start in range(0, n_prompts, self.batch_size):
                end = min(start + self.batch_size, n_prompts)
                input_ids = all_input_ids[start:end]  # pyright: ignore[reportGeneralTypeIssues]
                attn_mask = all_attn_masks[start:end]  # pyright: ignore[reportGeneralTypeIssues]

                # Move batch to device.
                input_ids = device.tensor_to_device(input_ids)
                attn_mask = device.tensor_to_device(attn_mask)
                with get_precision_context(state.precision):
                    output_token_ids.extend(
                        model.generate(  # type: ignore
                            input_ids=input_ids,
                            attention_mask=attn_mask,
                            synced_gpus=synced_gpus,
                            **self.generate_kwargs,
                        ),
                    )

            if dist.get_global_rank() == 0:
                # Process prompts and outputs into a table.
                # All prompts were padded together, so they share one input length.
                input_tokens_len = all_input_ids.shape[1]  # pyright: ignore[reportGeneralTypeIssues]
                rows = []
                for prompt, output_ids in zip(self.prompts, output_token_ids):
                    # Assumes each output row starts with the (padded) prompt tokens - TODO confirm
                    # for models whose generate output does not echo the input.
                    output_tokens = output_ids[input_tokens_len:]
                    output_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
                    rows.append([prompt, output_text])

                logger.log_table(columns=['prompt', 'generation'], rows=rows, name='generations')
        finally:
            # Always undo the temporary overrides so a failed generation cannot leave training
            # running with a left-padding tokenizer or a model stuck in eval mode.
            tokenizer.padding_side = original_padding_side
            model.train(mode=original_mode)