Source code for workflow.ignite.handlers.metrics_logger

from tqdm import tqdm
from ignite.engine import Events
import pprint

from workflow.torch.is_float import is_float


class MetricsLogger:
    """Print selected ``engine.state.metrics`` entries when an epoch completes.

    Output is written with ``tqdm.write`` (which prints without overlapping
    active tqdm progress bars).
    """

    def __init__(self, name):
        """
        Args:
            name: Label written as a ``'<name>:'`` header line before the
                metrics.
        """
        self.name = name

    def attach(self, engine, metric_names):
        """Register this logger on ``Events.EPOCH_COMPLETED`` of ``engine``.

        Args:
            engine: ignite engine whose ``state.metrics`` is reported.
            metric_names: Iterable of metric names to print, in this order.
                It is copied into a tuple here so one-shot iterables
                (e.g. generators) still work on every epoch.

        Returns:
            Whatever ``engine.add_event_handler`` returns (recent ignite
            versions return a handle that can remove the handler; older
            ones return ``None``).
        """
        metric_names = tuple(metric_names)
        return engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            lambda engine: self._print(engine, metric_names),
        )

    def _print(self, engine, metric_names):
        """Write the report for ``engine.state.metrics`` via ``tqdm.write``."""
        for line in self._format_lines(engine.state.metrics, metric_names):
            tqdm.write(line)

    def _format_lines(self, metrics, metric_names):
        """Build the report lines for ``metrics``.

        Args:
            metrics: Mapping of metric name to value (``engine.state.metrics``).
            metric_names: Iterable of names to report. Names missing from
                ``metrics``, or whose value is ``None``, are skipped.

        Returns:
            List of strings: a ``'<self.name>:'`` header followed by the
            entries for each reported metric. Empty when ``metrics`` is empty
            or no names were requested.
        """
        # Materialize once: the names are iterated twice (width, then rows).
        names = tuple(metric_names)
        if not metrics or not names:
            return []

        lines = [f'{self.name}:']
        # Pad every name to the longest requested name so scalar values align.
        target_len = max(map(len, names))
        for metric_name in names:
            value = metrics.get(metric_name, None)
            if value is None:
                continue
            padding = ' ' * (target_len - len(metric_name))
            if isinstance(value, str):
                # Strings define __len__ but are scalars here: print inline.
                lines.append(f' {metric_name}:{padding} {value}')
            elif hasattr(value, '__len__'):
                # Containers are pretty-printed below the metric name.
                # NOTE(review): 0-d torch tensors also define __len__ and will
                # land here rather than in the float branch — confirm intended.
                lines.append(f' {metric_name}:')
                lines.append(pprint.pformat(value))
            elif is_float(value):
                # Fixed-point for ordinary magnitudes (and exact zero);
                # scientific notation for tiny values that '.4f' would
                # render as 0.0000.
                fmt = '.4f' if abs(value) > 1e-4 or value == 0 else '.4e'
                lines.append(f' {metric_name}:{padding} {value:{fmt}}')
            else:
                lines.append(f' {metric_name}:{padding} {value}')
        return lines