# Copyright 2021 MosaicML. All Rights Reserved.
"""Logger Hyperparameter classes."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional
import yahp as hp
from composer.loggers.file_logger import FileLogger
from composer.loggers.in_memory_logger import InMemoryLogger
from composer.loggers.logger import LogLevel
from composer.loggers.logger_destination import LoggerDestination
from composer.loggers.object_store_logger import ObjectStoreLogger
from composer.loggers.progress_bar_logger import ProgressBarLogger
from composer.loggers.wandb_logger import WandBLogger
from composer.utils import ObjectStoreHparams, dist, import_object
# Public API of this module: every logger hparams class plus the name -> class registry.
__all__ = [
    'FileLoggerHparams',
    'InMemoryLoggerHparams',
    'LoggerDestinationHparams',
    'ProgressBarLoggerHparams',
    'WandBLoggerHparams',
    'ObjectStoreLoggerHparams',
    'logger_registry',
]
@dataclass
class LoggerDestinationHparams(hp.Hparams, ABC):
    """Abstract base for the hyperparameters of a logger destination.

    Subclasses that are listed on :class:`~.trainer_hparams.TrainerHparams` (for example via YAML
    or the CLI) are turned into concrete logger destinations inside the training loop.
    """

    @abstractmethod
    def initialize_object(self) -> LoggerDestination:
        """Build and return the :class:`.LoggerDestination` these hyperparameters describe."""
@dataclass
class FileLoggerHparams(LoggerDestinationHparams):
    """:class:`~composer.loggers.file_logger.FileLogger` hyperparameters.

    See :class:`~composer.loggers.file_logger.FileLogger` for documentation.

    Args:
        log_level (LogLevel, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        filename (str, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        artifact_name (str, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        capture_stdout (bool, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        capture_stderr (bool, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        buffer_size (int, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        flush_interval (int, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
        log_interval (int, optional): See :class:`~composer.loggers.file_logger.FileLogger`.
    """

    log_level: LogLevel = hp.optional("The maximum verbosity to log. Default: EPOCH", default=LogLevel.EPOCH)
    filename: str = hp.optional("Filename format string for the logfile.", default='{run_name}/logs-rank{rank}.txt')
    artifact_name: Optional[str] = hp.optional("Artifact name format string for the logfile.", default=None)
    capture_stdout: bool = hp.optional("Whether to capture writes to `stdout`", default=True)
    capture_stderr: bool = hp.optional("Whether to capture writes to `stderr`", default=True)
    buffer_size: int = hp.optional("Number of bytes to buffer. Defaults to 1 for line-buffering. "
                                   "See https://docs.python.org/3/library/functions.html#open",
                                   default=1)  # line buffering. Python's default is -1.
    flush_interval: int = hp.optional(
        "Frequency to flush the file, relative to the ``log_level``. "
        "Defaults to 100 of the unit of ``log_level``.",
        default=100)
    log_interval: int = hp.optional(
        # The trailing space matters: these two literals are concatenated into one help string.
        "Frequency to record log messages, relative to the ``log_level``. "
        "Defaults to 1 (record all messages).",
        default=1)

    def initialize_object(self) -> FileLogger:
        """Construct the :class:`.FileLogger` described by these hyperparameters.

        Arguments are passed explicitly (rather than via ``asdict(self)``) so that any extra
        field added to this class or a subclass cannot leak into the ``FileLogger`` call.
        """
        return FileLogger(
            log_level=self.log_level,
            filename=self.filename,
            artifact_name=self.artifact_name,
            capture_stdout=self.capture_stdout,
            capture_stderr=self.capture_stderr,
            buffer_size=self.buffer_size,
            flush_interval=self.flush_interval,
            log_interval=self.log_interval,
        )
@dataclass
class WandBLoggerHparams(LoggerDestinationHparams):
    """:class:`~composer.loggers.wandb_logger.WandBLogger` hyperparameters.

    Args:
        project (str, optional): WandB project name.
        group (str, optional): WandB group name.
        name (str, optional): WandB run name.
            If not specified, the :attr:`.Logger.run_name` will be used.
        entity (str, optional): WandB entity name.
        tags (str, optional): WandB tags, comma-separated.
        config (Dict[str, Any], optional): WandB run configuration.
        flatten_config (bool, optional): Whether to flatten the run config. (default: ``False``)
        log_artifacts (bool, optional): See :class:`~composer.loggers.wandb_logger.WandBLogger`.
        rank_zero_only (bool, optional): See :class:`~composer.loggers.wandb_logger.WandBLogger`.
        extra_init_params (dict, optional): See
            :class:`~composer.loggers.wandb_logger.WandBLogger`.
    """

    project: Optional[str] = hp.optional(doc="wandb project name", default=None)
    group: Optional[str] = hp.optional(doc="wandb group name", default=None)
    name: Optional[str] = hp.optional(doc="wandb run name", default=None)
    entity: Optional[str] = hp.optional(doc="wandb entity", default=None)
    tags: Optional[str] = hp.optional(doc="wandb tags comma separated", default=None)
    log_artifacts: bool = hp.optional(doc="Whether to log artifacts", default=False)
    rank_zero_only: bool = hp.optional("Whether to log on rank zero only", default=True)
    extra_init_params: Dict[str, Any] = hp.optional(doc="wandb parameters", default_factory=dict)
    config: Dict[str, Any] = hp.optional(doc="Wandb run configuration", default_factory=dict)
    flatten_config: bool = hp.optional(
        doc="Whether to flatten the config, which can make nested fields easier to visualize and query.", default=False)

    def initialize_object(self) -> WandBLogger:
        """Construct the :class:`.WandBLogger` described by these hyperparameters.

        Entries in :attr:`extra_init_params` override the individually-specified init params,
        except that ``config`` is always the (optionally flattened) configuration, so that
        ``flatten_config`` also applies to a ``config`` supplied via ``extra_init_params``.
        """
        tags = None
        if self.tags is not None:
            # Strip whitespace, drop empty entries, and de-duplicate while keeping the
            # user's order (a ``set`` would make the order arbitrary).
            stripped = (tag.strip() for tag in self.tags.split(","))
            tags = list(dict.fromkeys(tag for tag in stripped if tag))

        # A ``config`` given in ``extra_init_params`` takes precedence over ``self.config``.
        config_dict = self.extra_init_params.get("config", self.config)
        if self.flatten_config:
            config_dict = self._flatten_dict(config_dict)

        if self.rank_zero_only:
            name = self.name
            group = self.group
        else:
            # Every rank creates its own run: suffix the name with the global rank and group the
            # runs by ``group`` (falling back to the base name). When no name was given, leave it
            # as ``None`` instead of producing the literal string "None [RANK_n]".
            # NOTE(review): presumably WandBLogger then falls back to the logger run name — confirm.
            name = None if self.name is None else f"{self.name} [RANK_{dist.get_global_rank()}]"
            group = self.group if self.group else self.name

        init_params: Dict[str, Any] = {
            "project": self.project,
            "name": name,
            "group": group,
            "entity": self.entity,
            "tags": tags,
            "config": config_dict,
        }
        init_params.update(self.extra_init_params)
        # Re-apply ``config`` after the update; otherwise an unflattened
        # ``extra_init_params["config"]`` would silently replace the flattened dict.
        init_params["config"] = config_dict

        return WandBLogger(
            log_artifacts=self.log_artifacts,
            rank_zero_only=self.rank_zero_only,
            init_params=init_params,
        )

    @classmethod
    def _flatten_dict(cls, data: Dict[str, Any], _prefix: Optional[List[str]] = None) -> Dict[str, Any]:
        """Flatten nested dicts (including dicts inside lists) into dot-separated keys.

        Args:
            data (Dict[str, Any]): The (possibly nested) dictionary to flatten. Keys must be strings.
            _prefix (List[str], optional): Key path of ``data`` within the outer dictionary. Used
                for recursion; callers should leave it unset.

        Returns:
            Dict[str, Any]: A new, single-level dictionary. Dicts found inside a list are flattened
            under the list's own key path. A list with no dicts is kept as-is. In a mixed list,
            the non-dict items are kept, as a list, under the list's key.

        Example:

            .. code-block:: python

                config = {
                    "sub_dict": {
                        "sub_list": [
                            {"sub_sub_dict": {"foo": 0, "bar": "baz"}},
                        ],
                    },
                    "hello": "world",
                }
                WandBLoggerHparams._flatten_dict(config)
                # {
                #     "sub_dict.sub_list.sub_sub_dict.foo": 0,
                #     "sub_dict.sub_list.sub_sub_dict.bar": "baz",
                #     "hello": "world",
                # }
        """
        prefix = [] if _prefix is None else _prefix
        all_items: Dict[str, Any] = {}
        for key, val in data.items():
            key_items = prefix + [key]
            key_name = ".".join(key_items)
            if isinstance(val, dict):
                all_items.update(cls._flatten_dict(val, key_items))
            elif isinstance(val, list):
                sub_dicts = [item for item in val if isinstance(item, dict)]
                if not sub_dicts:
                    all_items[key_name] = val
                    continue
                # Recurse into each dict so every nested value keeps its full dotted path.
                for item in sub_dicts:
                    all_items.update(cls._flatten_dict(item, key_items))
                # Keep any non-dict items of a mixed list instead of silently dropping them.
                others = [item for item in val if not isinstance(item, dict)]
                if others:
                    all_items[key_name] = others
            else:
                all_items[key_name] = val
        return all_items
@dataclass
class ProgressBarLoggerHparams(LoggerDestinationHparams):
    """:class:`~composer.loggers.progress_bar_logger.ProgressBarLogger` hyperparameters.

    .. deprecated:: 0.6.0

        This class is deprecated. Instead, please specify the :class:`.ProgressBarLogger` arguments
        directly in the :class:`~composer.trainer.trainer_hparams.TrainerHparams`. This class will be removed
        in v0.7.0.

    Args:
        progress_bar (bool, optional): See :class:`.ProgressBarLogger`.
        log_to_console (bool, optional): See :class:`.ProgressBarLogger`.
        console_log_level (LogLevel, optional): See :class:`.ProgressBarLogger`.
        stream (str, optional): See :class:`.ProgressBarLogger`.
    """

    progress_bar: bool = hp.optional("Whether to show a progress bar.", default=True)
    log_to_console: Optional[bool] = hp.optional("Whether to print log statements to the console.", default=None)
    console_log_level: LogLevel = hp.optional("The maximum log level.", default=LogLevel.EPOCH)
    stream: str = hp.optional("The stream at which to write the progress bar and log statements.", default="stderr")

    def initialize_object(self) -> ProgressBarLogger:
        """Construct the :class:`.ProgressBarLogger` described by these hyperparameters."""
        logger_kwargs = {
            "progress_bar": self.progress_bar,
            "log_to_console": self.log_to_console,
            "console_log_level": self.console_log_level,
            "stream": self.stream,
        }
        return ProgressBarLogger(**logger_kwargs)
@dataclass
class InMemoryLoggerHparams(LoggerDestinationHparams):
    """:class:`~composer.loggers.in_memory_logger.InMemoryLogger` hyperparameters.

    Args:
        log_level (str or LogLevel, optional):
            See :class:`~composer.loggers.in_memory_logger.InMemoryLogger`.
    """

    log_level: LogLevel = hp.optional("The maximum verbosity to log. Default: BATCH", default=LogLevel.BATCH)

    def initialize_object(self) -> LoggerDestination:
        """Construct an :class:`.InMemoryLogger` that records at ``log_level``."""
        level = self.log_level
        return InMemoryLogger(log_level=level)
@dataclass
class ObjectStoreLoggerHparams(LoggerDestinationHparams):
    """:class:`~composer.loggers.object_store_logger.ObjectStoreLogger` hyperparameters.

    Args:
        object_store_hparams (ObjectStoreHparams): The object store provider hparams.
        should_log_artifact (str, optional): The path to a filter function which returns whether an artifact should be
            logged. The path should be of the format ``path.to.module:filter_function_name``.

            The function should take (:class:`~composer.core.state.State`, :class:`.LogLevel`, ``<artifact name>``).
            The artifact name will be a string. The function should return a boolean indicating whether the artifact
            should be logged.

            .. seealso:: :func:`composer.utils.import_helpers.import_object`

            Setting this parameter to ``None`` (the default) will log all artifacts.
        object_name (str, optional): See :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`.
        config_artifact_name (str, optional): Format string describing how to store the training configuration.

            .. note::

                :meth:`initialize_object` does not currently pass this value to
                :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`.

        num_concurrent_uploads (int, optional): See :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`.
        upload_staging_folder (str, optional): See :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`.
        use_procs (bool, optional): See :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`.
    """

    object_store_hparams: ObjectStoreHparams = hp.required("Object store provider hparams.")
    should_log_artifact: Optional[str] = hp.optional(
        "Path to a filter function which returns whether an artifact should be logged.", default=None)
    object_name: str = hp.optional("A format string for object names", default="{artifact_name}")
    config_artifact_name: Optional[str] = hp.optional(
        "Format string to describe how to store the training configuration.", default="{run_name}/config.yaml")
    num_concurrent_uploads: int = hp.optional("Maximum number of concurrent uploads.", default=4)
    use_procs: bool = hp.optional("Whether to perform file uploads in background processes (as opposed to threads).",
                                  default=True)
    upload_staging_folder: Optional[str] = hp.optional(
        "Staging folder for uploads. If not specified, will use a temporary directory.", default=None)

    def initialize_object(self) -> ObjectStoreLogger:
        """Construct the :class:`.ObjectStoreLogger` described by these hyperparameters.

        ``should_log_artifact``, when set, is resolved from its import path with
        :func:`~composer.utils.import_object`.
        """
        # NOTE(review): ``config_artifact_name`` is not forwarded here. Confirm whether
        # ObjectStoreLogger accepts it, or whether this field is consumed elsewhere.
        should_log_artifact = None
        if self.should_log_artifact is not None:
            should_log_artifact = import_object(self.should_log_artifact)

        return ObjectStoreLogger(
            provider=self.object_store_hparams.provider,
            container=self.object_store_hparams.container,
            provider_kwargs=self.object_store_hparams.get_provider_kwargs(),
            object_name=self.object_name,
            should_log_artifact=should_log_artifact,
            num_concurrent_uploads=self.num_concurrent_uploads,
            upload_staging_folder=self.upload_staging_folder,
            use_procs=self.use_procs,
        )
# Maps the logger key used in YAML/CLI configs to its hparams class (insertion order preserved).
logger_registry = dict(
    file=FileLoggerHparams,
    wandb=WandBLoggerHparams,
    progress_bar=ProgressBarLoggerHparams,
    in_memory=InMemoryLoggerHparams,
    object_store=ObjectStoreLoggerHparams,
)
"""The registry of all known :class:`.LoggerDestinationHparams`."""