InMemoryLogger
- class composer.loggers.InMemoryLogger
Logs metrics to dictionary objects that persist in memory throughout training.
Useful for collecting and plotting data inside notebooks.
- Example usage:
```python
from composer.loggers import InMemoryLogger
from composer.trainer import Trainer

# `model`, `train_dataloader`, `eval_dataloader`, and `optimizer` are assumed
# to be defined earlier.
logger = InMemoryLogger()
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=[optimizer],
    loggers=[logger],
)

# Train so that metrics are actually recorded.
trainer.fit()

# Get data from the logger. If you are using multiple loggers, confirm
# which index in trainer.logger.destinations holds your desired logger.
logged_data = trainer.logger.destinations[0].data
```
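Instead of relying on a fixed index into trainer.logger.destinations, you can select the destination by type. The sketch below uses a hypothetical helper, `find_destination`, together with stand-in classes (not Composer types) so that it runs without Composer installed.

```python
from typing import Iterable, Type, TypeVar

T = TypeVar("T")


def find_destination(destinations: Iterable[object], cls: Type[T]) -> T:
    """Return the first destination that is an instance of `cls` (hypothetical helper)."""
    for destination in destinations:
        if isinstance(destination, cls):
            return destination
    raise LookupError(f"No destination of type {cls.__name__} found")


# Stand-in classes for illustration only; they are not Composer types.
class FakeFileLogger:
    pass


class FakeInMemoryLogger:
    def __init__(self) -> None:
        self.data = {}


destinations = [FakeFileLogger(), FakeInMemoryLogger()]
in_mem = find_destination(destinations, FakeInMemoryLogger)
print(type(in_mem).__name__)  # FakeInMemoryLogger
```

With Composer, the analogous call would be `find_destination(trainer.logger.destinations, InMemoryLogger).data`, which keeps working regardless of the order in which loggers are passed to the Trainer.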
- data
Mapping of a logged key to a list of (Timestamp, logged value) tuples, one entry each time the key is logged. This dictionary contains all logged data.
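To make the shape of data concrete, here is a plain-Python sketch. It assumes each logging call appends one (timestamp, value) tuple per key; `FakeTimestamp` and `record` are hypothetical stand-ins, not Composer APIs, and the real Timestamp tracks more time units than the two shown.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class FakeTimestamp:
    """Hypothetical stand-in for Composer's Timestamp, reduced to two units."""
    epoch: int
    batch: int


# Same shape as described for InMemoryLogger.data:
# key -> list of (timestamp, value) tuples.
data: Dict[str, List[Tuple[FakeTimestamp, Any]]] = {}


def record(metrics: Dict[str, Any], timestamp: FakeTimestamp) -> None:
    """Append one (timestamp, value) tuple per metric (hypothetical helper)."""
    for key, value in metrics.items():
        data.setdefault(key, []).append((timestamp, value))


record({"accuracy/val": 31.2}, FakeTimestamp(epoch=1, batch=49))
record({"accuracy/val": 45.6}, FakeTimestamp(epoch=2, batch=98))

timestamp, value = data["accuracy/val"][-1]  # most recent entry
print(timestamp.batch, value)  # 98 45.6
```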
- get_timeseries(metric)
Returns logged data as a dict containing the values of a desired metric over time.
- Parameters
metric (str) – Metric of interest. Must be present in self.data.keys().
- Returns
timeseries (dict[str, Any]) – Dictionary in which one key is metric, and its value is the list of logged values of that metric. Each remaining key is a unit of time, and its value is the list of that unit's values, aligned index-by-index with the metric values. For example:
```python
>>> in_mem_logger.get_timeseries(metric="accuracy/val")
{"accuracy/val": [31.2, 45.6, 59.3, 64.7, ...], "epoch": [1, 2, 3, 4, ...], "batch": [49, 98, 147, 196, ...], ...}
```
Example
```python
import matplotlib.pyplot as plt

from composer.loggers import InMemoryLogger
from composer.trainer import Trainer

# `model`, `train_dataloader`, `eval_dataloader`, and `optimizer` are assumed
# to be defined earlier.
in_mem_logger = InMemoryLogger()
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    optimizers=[optimizer],
    loggers=[in_mem_logger],
)

# Populate the logger with data by hand. No training runs between these
# calls, so all three points share the trainer's current timestamp.
for b in range(3):
    datapoint = b * 3
    in_mem_logger.log_metrics({"accuracy/val": datapoint})

timeseries = in_mem_logger.get_timeseries("accuracy/val")
plt.plot(timeseries["batch"], timeseries["accuracy/val"])
plt.xlabel("Batch")
plt.ylabel("Validation Accuracy")
plt.show()
```
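The reshaping that get_timeseries describes, from a list of (Timestamp, value) tuples into parallel lists keyed by the metric name and by each time unit, can be sketched in plain Python. `FakeTimestamp` (with only epoch and batch) and `build_timeseries` are hypothetical stand-ins; the library's own implementation may differ in detail.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class FakeTimestamp:
    """Hypothetical stand-in for Composer's Timestamp with two time units."""
    epoch: int
    batch: int


def build_timeseries(
    data: Dict[str, List[Tuple[FakeTimestamp, Any]]], metric: str
) -> Dict[str, List[Any]]:
    """Split (timestamp, value) tuples into parallel lists (hypothetical helper)."""
    if metric not in data:
        raise ValueError(f"{metric!r} is not present in data.keys()")
    timeseries: Dict[str, List[Any]] = {metric: []}
    for timestamp, value in data[metric]:
        timeseries[metric].append(value)
        # One list per time unit, aligned index-by-index with the metric values.
        for unit, unit_value in asdict(timestamp).items():
            timeseries.setdefault(unit, []).append(unit_value)
    return timeseries


data = {
    "accuracy/val": [
        (FakeTimestamp(epoch=1, batch=49), 31.2),
        (FakeTimestamp(epoch=2, batch=98), 45.6),
        (FakeTimestamp(epoch=3, batch=147), 59.3),
    ]
}
print(build_timeseries(data, "accuracy/val"))
# {'accuracy/val': [31.2, 45.6, 59.3], 'epoch': [1, 2, 3], 'batch': [49, 98, 147]}
```

Keeping every list aligned by index is what lets a plot pair any time unit (for example timeseries["batch"]) directly with the metric values.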