InMemoryLogger
- class composer.loggers.InMemoryLogger
- Logs metrics to dictionary objects that persist in memory throughout training.
- Useful for collecting and plotting data inside notebooks.
- Example usage:
```python
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer

# `model`, `train_dataloader`, `eval_dataloader`, and `optimizer` are assumed to be defined elsewhere.
logger = InMemoryLogger()
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=[optimizer],
    loggers=[logger],
)

# Train so that the logger has metrics to record.
trainer.fit()

# Get data from logger. If you are using multiple loggers, be sure to confirm
# which index in trainer.logger.destinations contains your desired logger.
logged_data = trainer.logger.destinations[0].data
```
- data
- Mapping of each logged key to a list of (Timestamp, logged value) tuples. This dictionary contains all logged data.
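To make the shape of `data` concrete, here is a minimal pure-Python sketch that builds a dictionary of the same shape without Composer. `FakeTimestamp` and `record` are hypothetical stand-ins invented for illustration (a real `Timestamp` tracks more time units than epoch and batch), and the sketch assumes each key accumulates a list of (timestamp, value) tuples in logging order.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class FakeTimestamp:
    """Hypothetical stand-in for Composer's Timestamp, tracking only two time units."""
    epoch: int
    batch: int


# Each logged key maps to a list of (timestamp, value) tuples, appended in logging order.
data: Dict[str, List[Tuple[FakeTimestamp, Any]]] = {}


def record(metrics: Dict[str, Any], timestamp: FakeTimestamp) -> None:
    """Hypothetical helper: append every metric value together with the time it was logged."""
    for key, value in metrics.items():
        data.setdefault(key, []).append((timestamp, value))


record({"loss/train": 2.3}, FakeTimestamp(epoch=0, batch=1))
record({"loss/train": 1.9}, FakeTimestamp(epoch=0, batch=2))
print(data["loss/train"])
# [(FakeTimestamp(epoch=0, batch=1), 2.3), (FakeTimestamp(epoch=0, batch=2), 1.9)]
```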
- get_timeseries(metric)
- Returns logged data as a dict containing the values of a desired metric over time.
- Parameters
- metric (str): Metric of interest. Must be present in self.data.keys().
- Returns
- timeseries (Dict[str, Any]): Dictionary in which one key is the requested metric, and the associated value is a list of values of that metric. The remaining keys are each a unit of time, and each associated value is a list whose i-th entry is the value of that time unit when the i-th metric value was logged. For example:
```python
>>> in_mem_logger.get_timeseries(metric="accuracy/val")
{"accuracy/val": [31.2, 45.6, 59.3, 64.7, ...], "epoch": [1, 2, 3, 4, ...], "batch": [49, 98, 147, 196, ...], ...}
```
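The transformation that get_timeseries describes can be sketched in plain Python as follows. This is not Composer's implementation: `FakeTimestamp` and the standalone `get_timeseries` function are hypothetical, the timestamp carries only epoch and batch, the values are plain Python lists, and raising `ValueError` for a missing metric is an assumption of this sketch.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class FakeTimestamp:
    """Hypothetical stand-in for Composer's Timestamp with two time units."""
    epoch: int
    batch: int


def get_timeseries(
    data: Dict[str, List[Tuple[FakeTimestamp, Any]]], metric: str
) -> Dict[str, List[Any]]:
    """Sketch: turn (timestamp, value) pairs into parallel lists keyed by metric and time unit."""
    if metric not in data:
        # Assumed error behavior for this sketch only.
        raise ValueError(f"{metric!r} is not present in data.keys()")
    timeseries: Dict[str, List[Any]] = {metric: []}
    for timestamp, value in data[metric]:
        timeseries[metric].append(value)
        # One list per time unit; index i of every list describes the same datapoint.
        for f in fields(timestamp):
            timeseries.setdefault(f.name, []).append(getattr(timestamp, f.name))
    return timeseries


data = {
    "accuracy/val": [
        (FakeTimestamp(epoch=1, batch=49), 31.2),
        (FakeTimestamp(epoch=2, batch=98), 45.6),
    ]
}
print(get_timeseries(data, "accuracy/val"))
# {'accuracy/val': [31.2, 45.6], 'epoch': [1, 2], 'batch': [49, 98]}
```

Because every list is appended once per datapoint, index i of each list refers to the same logged value, which is what lets a time-unit list and the metric list be passed directly to a plotting function as x and y.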
- Example
```python
import matplotlib.pyplot as plt

from composer.loggers import InMemoryLogger
from composer.trainer import Trainer

# `model`, `train_dataloader`, `eval_dataloader`, and `optimizer` are assumed to be defined elsewhere.
in_mem_logger = InMemoryLogger()
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=[optimizer],
    loggers=[in_mem_logger],
)

# Populate the logger with data
for b in range(3):
    datapoint = b * 3
    in_mem_logger.log_metrics({"accuracy/val": datapoint})

timeseries = in_mem_logger.get_timeseries("accuracy/val")
plt.plot(timeseries["batch"], timeseries["accuracy/val"])
plt.xlabel("Batch")
plt.ylabel("Validation Accuracy")
plt.show()
```