# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Logs metrics to dictionary objects that persist in memory throughout training.

Useful for collecting and plotting data inside notebooks.
"""

from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
from torch import Tensor

from composer.core.time import Time
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination

if TYPE_CHECKING:
    from composer.core import State, Timestamp

__all__ = ['InMemoryLogger']
class InMemoryLogger(LoggerDestination):
    """Logs metrics to dictionary objects that persist in memory throughout training.

    Useful for collecting and plotting data inside notebooks.

    Example usage:
        .. testcode::

            from composer.loggers import InMemoryLogger
            from composer.trainer import Trainer

            logger = InMemoryLogger()
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                optimizers=[optimizer],
                loggers=[logger]
            )
            # Get data from logger. If you are using multiple loggers, be sure to confirm
            # which index in trainer.logger.destinations contains your desired logger.
            logged_data = trainer.logger.destinations[0].data

    Attributes:
        data (Dict[str, List[Tuple[Timestamp, Any]]]): Mapping of a logged key to a
            (:class:`~.time.Timestamp`, logged value) tuple. This dictionary contains all logged data.
        most_recent_values (Dict[str, Any]): Mapping of a key to the most recent value for that key.
        most_recent_timestamps (Dict[str, Timestamp]): Mapping of a key to the
            :class:`~.time.Timestamp` of the last logging call for that key.
        hyperparameters (Dict[str, Any]): Dictionary of all hyperparameters.
    """

    def __init__(self) -> None:
        self.data: Dict[str, List[Tuple[Timestamp, Any]]] = {}
        self.most_recent_values = {}
        self.most_recent_timestamps: Dict[str, Timestamp] = {}
        self.state: Optional[State] = None
        self.hyperparameters: Dict[str, Any] = {}

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
        self.hyperparameters.update(hyperparameters)

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        assert self.state is not None
        timestamp = self.state.timestamp
        copied_metrics = copy.deepcopy(metrics)
        for k, v in copied_metrics.items():
            if k not in self.data:
                self.data[k] = []
            self.data[k].append((timestamp, v))
        self.most_recent_values.update(copied_metrics.items())
        self.most_recent_timestamps.update({k: timestamp for k in copied_metrics})

    def init(self, state: State, logger: Logger) -> None:
        self.state = state
    def get_timeseries(self, metric: str) -> Dict[str, Any]:
        """Returns logged data as a dict containing values of a desired metric over time.

        Args:
            metric (str): Metric of interest. Must be present in ``self.data.keys()``.

        Returns:
            timeseries (Dict[str, Any]): Dictionary in which one key is ``metric``, and the
                associated value is a list of values of that metric. The remaining keys are
                each a unit of time, and the associated values are each a list of values of
                that time unit for the corresponding index of the metric. For example:

                >>> InMemoryLogger.get_timeseries(metric="accuracy/val")
                {"accuracy/val": [31.2, 45.6, 59.3, 64.7, ...],
                "epoch": [1, 2, 3, 4, ...],
                "batch": [49, 98, 147, 196, ...],
                ...}

        Example:
            .. testcode::

                import matplotlib.pyplot as plt

                from composer.core.time import Time, Timestamp
                from composer.loggers import InMemoryLogger
                from composer.trainer import Trainer

                in_mem_logger = InMemoryLogger()
                trainer = Trainer(
                    model=model,
                    train_dataloader=train_dataloader,
                    eval_dataloader=eval_dataloader,
                    max_duration="1ep",
                    optimizers=[optimizer],
                    loggers=[in_mem_logger]
                )

                # Populate the logger with data
                for b in range(0, 3):
                    datapoint = b * 3
                    in_mem_logger.log_metrics({"accuracy/val": datapoint})

                timeseries = in_mem_logger.get_timeseries("accuracy/val")
                plt.plot(timeseries["batch"], timeseries["accuracy/val"])
                plt.xlabel("Batch")
                plt.ylabel("Validation Accuracy")
        """
        # Check that the desired metric is present in the data
        if metric not in self.data.keys():
            raise ValueError(f'Invalid value for argument `metric`: {metric}. Requested '
                             'metric is not present in self.data.keys().')

        timeseries = {}
        # Iterate through datapoints
        for datapoint in self.data[metric]:
            timestamp, metric_value = datapoint
            timeseries.setdefault(metric, []).append(metric_value)
            # Iterate through time units and add them all!
            for field, time in timestamp.get_state().items():
                time_value = time.value if isinstance(time, Time) else time.total_seconds()
                timeseries.setdefault(field, []).append(time_value)

        # Convert to numpy arrays
        for k, v in timeseries.items():
            if isinstance(v[0], Tensor):
                v = Tensor(v).numpy()
            else:
                v = np.array(v)
            timeseries[k] = v
        return timeseries
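
# A minimal, self-contained usage sketch (not part of the library API): it shows how
# ``get_timeseries`` reshapes ``self.data`` into per-key arrays without running a full
# Trainer. Populating ``logger.data`` directly with ``Timestamp`` objects is an
# illustrative assumption made only for this sketch; during real training, ``Trainer``
# calls ``init`` and ``log_metrics`` to fill the dictionary. Exact ``Timestamp`` fields
# may vary across Composer versions.
if __name__ == '__main__':
    from composer.core.time import Timestamp

    logger = InMemoryLogger()
    for batch_idx, loss in enumerate([0.9, 0.7, 0.5]):
        # Record one (timestamp, value) pair per batch; unspecified Timestamp fields
        # default to zero counters / zero wall-clock time.
        logger.data.setdefault('loss/train', []).append((Timestamp(batch=batch_idx), loss))

    series = logger.get_timeseries('loss/train')
    print(series['batch'])       # batch indices at which the metric was logged
    print(series['loss/train'])  # logged values, converted to a numpy array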