# Source code for composer.loggers.slack_logger

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log metrics to slack, using Slack postMessage api."""

from __future__ import annotations

import itertools
import logging
import os
import re
import time
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union

from composer.core.time import Time, TimeUnit
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import MissingConditionalImportError

if TYPE_CHECKING:
    from composer.core import State

# Public API of this module.
__all__ = ['SlackLogger']

# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)

class SlackLogger(LoggerDestination):
    """Log metrics to Slack, using Slack's ``chat.postMessage`` API.

    First, export 2 environment variables to use this logger (or pass the equivalent
    constructor arguments):

    1. ``SLACK_LOGGING_API_KEY``: The Slack app token used to post messages. To get app
       credentials, follow Slack's app quickstart tutorial.
    2. ``SLACK_LOGGING_CHANNEL_ID``: ID of the channel to send messages to (open the Slack
       channel in a web browser to look this up).

    Next, write a script to output metrics / hparams / traces to the Slack channel. See the
    example below.

    .. code-block:: python

        trainer = Trainer(
            model=mnist_model(num_classes=10),
            train_dataloader=train_dataloader,
            max_duration='2ep',
            algorithms=[
                LabelSmoothing(smoothing=0.1),
                CutMix(alpha=1.0),
                ChannelsLast(),
            ],
            loggers=[
                SlackLogger(
                    formatter_func=(lambda data, **kwargs: [{
                        'type': 'section',
                        'text': {
                            'type': 'mrkdwn',
                            'text': f'*{k}:* {v}'
                        }
                    } for k, v in data.items()]),
                    include_keys=['loss/train/total'],
                    log_interval='1ba',
                ),
            ],
        )

    Args:
        include_keys (Sequence[str]): Regular-expression patterns selecting which
            metric/log/trace keys to send. A key is kept if any pattern matches at the start of
            the key (``re.match`` semantics). Note that an empty sequence matches every key.
        formatter_func (((...) -> list[dict[str, Any]]) | None): A function that returns the
            list of Slack blocks to send. It is called as ``formatter_func(entries, **kwargs)``,
            where ``kwargs`` may contain a ``header`` key, so it should accept ``**kwargs``.
            If ``None``, each entry is rendered as a bold key followed by its value.
        log_interval (int | str | Time): How frequently to flush buffered logs to Slack. Must
            have units of epochs or batches; a bare ``int`` is interpreted as epochs.
            (default: ``'1ba'``)
        max_logs_per_message (int): Maximum number of logs to send in a single message. Must be
            at least 1. Values above 50 are clamped to 50, because Slack allows no more than 50
            blocks in a single message. When the buffer reaches this size, it is flushed
            without waiting for the next log interval. (default: ``50``)
        slack_logging_api_key (str, optional): Slack API token. If ``None``, it is read from the
            ``SLACK_LOGGING_API_KEY`` environment variable.
        channel_id (str, optional): Slack channel ID. If ``None``, it is read from the
            ``SLACK_LOGGING_CHANNEL_ID`` environment variable.

    Raises:
        MissingConditionalImportError: If ``slack_sdk`` is not installed.
        ValueError: If ``log_interval`` does not have units of epochs or batches or is not
            positive, or if ``max_logs_per_message`` is less than 1.
    """

    # Slack rejects messages containing more than this many blocks.
    _SLACK_MAX_BLOCKS_PER_MESSAGE = 50

    def __init__(
        self,
        include_keys: Sequence[str] = (),
        formatter_func: Optional[Callable[..., list[dict[str, Any]]]] = None,
        log_interval: Union[int, str, Time] = '1ba',
        max_logs_per_message: int = 50,
        slack_logging_api_key: Optional[str] = None,
        channel_id: Optional[str] = None,
    ) -> None:
        try:
            import slack_sdk
            self.client = slack_sdk.WebClient()
            del slack_sdk
        except ImportError as e:
            raise MissingConditionalImportError('slack_logger', 'slack_sdk', None) from e

        # Explicit constructor arguments take precedence over the environment variables.
        self.slack_logging_api_key = (
            os.environ.get('SLACK_LOGGING_API_KEY', None) if slack_logging_api_key is None else slack_logging_api_key
        )
        self.channel_id = os.environ.get('SLACK_LOGGING_CHANNEL_ID', None) if channel_id is None else channel_id
        # Missing credentials are not fatal here; each later send attempt logs an error instead.
        if self.slack_logging_api_key is None:
            log.warning('SLACK_LOGGING_API_KEY must be set as environment variable')
        if self.channel_id is None:
            log.warning('SLACK_LOGGING_CHANNEL_ID must be set as environment variable')

        self.formatter_func = formatter_func

        if len(include_keys) == 0:
            log.warning('The slack logger `include_keys` argument must be a non-empty list of strings.')
        # Alternation of all include patterns, e.g. ['a', 'b'] -> '(a)|(b)'. With no patterns this
        # is '()', which matches (at position 0) every key.
        self.regex_all_keys = '(' + ')|('.join(include_keys) + ')'

        self.log_interval = Time.from_input(log_interval, TimeUnit.EPOCH)
        if self.log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
            raise ValueError('The `slack logger log_interval` argument must have units of EPOCH or BATCH.')
        # `batch_end` / `epoch_end` take the counter modulo the interval; a zero interval would
        # raise ZeroDivisionError in the middle of training, so reject it up front.
        if int(self.log_interval) <= 0:
            raise ValueError('The `slack logger log_interval` argument must be positive.')

        # A non-positive chunk size would make `_flush_logs_to_slack` loop forever.
        if max_logs_per_message < 1:
            raise ValueError('The slack logger `max_logs_per_message` argument must be at least 1.')
        self.max_logs_per_message = min(max_logs_per_message, self._SLACK_MAX_BLOCKS_PER_MESSAGE)

        self.log_dict: dict[str, Any] = {}
        self.last_log_time = time.time()

    def _log_to_buffer(
        self,
        data: dict[str, Any],
        **kwargs,  # passed through to the formatter function (e.g. ``header``)
    ):
        """Buffer the entries of ``data`` selected by ``include_keys``, flushing if the buffer is full.

        Keys already in the buffer are overwritten with their new values; other keys are added.
        If the buffer then holds at least ``max_logs_per_message`` entries, it is flushed
        immediately (forwarding ``kwargs``). Otherwise it is flushed at the next log interval
        (batch end or epoch end) or on close.
        """
        filtered_data = {k: v for k, v in data.items() if re.match(self.regex_all_keys, k) is not None}
        self.log_dict.update(filtered_data)
        if len(self.log_dict) >= self.max_logs_per_message:
            self._flush_logs_to_slack(**kwargs)

    def _default_log_bold_key_normal_value_pair_with_header(
        self,
        data: dict[str, Any],
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Default formatter used when no ``formatter_func`` is specified.

        This function will:
        1. Render each key-value pair as a ``mrkdwn`` section with the key in bold.
        2. If a non-``None`` ``header`` is given (the step number when logging metrics), put a
           header block above those sections.

        Args:
            data (dict[str, Any]): Data to be logged.
            **kwargs: Additional formatter arguments. Only ``header`` is used.

        Returns:
            list[dict[str, Any]]: List of blocks to be sent to Slack. Empty if ``data`` is empty.
        """
        blocks: list[dict[str, Any]] = [{
            'type': 'section',
            'text': {
                'type': 'mrkdwn',
                'text': f'*{k}:* {v}'
            }
        } for k, v in data.items()]
        header = kwargs.get('header')
        # Blocks render in order, so the header must come first to title the sections below it.
        # `log_metrics` forwards `step`, which may be None; don't render a "None" header.
        if blocks and header is not None:
            blocks.insert(0, {'type': 'header', 'text': {'type': 'plain_text', 'text': f'{header}'}})
        return blocks

    def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
        """Buffer ``metrics``; ``step`` is forwarded to the formatter as ``header``."""
        self._log_to_buffer(data=metrics, header=step)

    def log_hyperparameters(self, hyperparameters: dict[str, Any]):
        """Buffer ``hyperparameters`` for sending to Slack."""
        self._log_to_buffer(data=hyperparameters)

    def log_traces(self, traces: dict[str, Any]):
        """Buffer ``traces`` for sending to Slack."""
        self._log_to_buffer(data=traces)

    def epoch_end(self, state: State, logger: Logger) -> None:
        """Flush the buffer if the log interval is in epochs and this epoch is due."""
        cur_epoch = int(state.timestamp.epoch)  # epoch gets incremented right before EPOCH_END
        unit = self.log_interval.unit
        # The first epoch always flushes (presumably for early feedback), then every interval.
        if unit == TimeUnit.EPOCH and (cur_epoch % int(self.log_interval) == 0 or cur_epoch == 1):
            self._flush_logs_to_slack()

    def batch_end(self, state: State, logger: Logger) -> None:
        """Flush the buffer if the log interval is in batches and this batch is due."""
        cur_batch = int(state.timestamp.batch)
        unit = self.log_interval.unit
        # The first batch always flushes (presumably for early feedback), then every interval.
        if unit == TimeUnit.BATCH and (cur_batch % int(self.log_interval) == 0 or cur_batch == 1):
            self._flush_logs_to_slack()

    def close(self, state: State, logger: Logger) -> None:
        """Send whatever is still buffered."""
        self._flush_logs_to_slack()

    def _flush_logs_to_slack(self, **kwargs) -> None:
        """Send all buffered logs to Slack, then clear the buffer.

        Entries are sent in chunks, one Slack message per chunk, each formatted into Slack Block
        Kit blocks by the formatter. ``kwargs`` are forwarded to the formatter for every chunk.
        """
        chunk_size = self.max_logs_per_message
        # The default formatter adds one header block when a header is given; reserve room for it
        # so a full chunk never exceeds Slack's per-message block limit.
        if kwargs.get('header') is not None:
            chunk_size = min(chunk_size, self._SLACK_MAX_BLOCKS_PER_MESSAGE - 1)

        # A single iterator over the buffer is consumed chunk by chunk (linear, no re-slicing).
        entries = iter(self.log_dict.items())
        while True:
            chunk = dict(itertools.islice(entries, chunk_size))
            if not chunk:
                break
            self._format_and_send_blocks_to_slack(chunk, **kwargs)
        self.log_dict = {}  # reset log_dict

    def _format_and_send_blocks_to_slack(
        self,
        log_entries: dict[str, Any],
        **kwargs,
    ):
        """Format ``log_entries`` into Slack blocks and post them as a single message.

        Sending is best-effort: missing credentials and any exception raised while posting are
        logged as errors rather than raised. Exceptions raised by the formatter propagate.
        """
        if self.formatter_func is not None:
            blocks = self.formatter_func(log_entries, **kwargs)
        else:
            blocks = self._default_log_bold_key_normal_value_pair_with_header(log_entries, **kwargs)

        if self.channel_id is None:
            log.error('Error logging to Slack: SLACK_LOGGING_CHANNEL_ID cannot be None.')
            return
        if self.slack_logging_api_key is None:
            log.error('Error logging to Slack: SLACK_LOGGING_API_KEY cannot be None')
            return

        try:
            self.client.chat_postMessage(
                token=self.slack_logging_api_key,
                channel=self.channel_id,
                blocks=blocks,
                # Plain-text fallback for notifications / clients that cannot render blocks.
                text=f'Logged {len(log_entries)} items to Slack',
            )
        except Exception as e:
            log.error('Error logging to Slack: %s', e)