from tqdm import tqdm
from ignite.engine import Events
import pprint
from workflow.torch.is_float import is_float
class MetricsLogger:
    """Print selected engine metrics via ``tqdm.write`` after every epoch.

    Output goes through ``tqdm.write`` rather than ``print``; each report
    starts with a ``<name>:`` heading followed by one entry per metric.
    """

    def __init__(self, name):
        """Remember the heading to print above each metrics report.

        Args:
            name: Label written as ``<name>:`` before the metric lines.
        """
        self.name = name

    def attach(self, engine, metric_names):
        """Register the printer on ``engine`` for ``Events.EPOCH_COMPLETED``.

        Args:
            engine: Object providing ``add_event_handler`` and
                ``state.metrics`` (presumably an ignite ``Engine``; confirm
                against callers).
            metric_names: Sized, re-iterable collection of metric names to
                report, in display order. It is walked again on every
                epoch, so a one-shot iterator is not suitable.
        """
        engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            lambda engine: self._print(engine, metric_names),
        )

    def _print(self, engine, metric_names):
        """Write the current values of ``metric_names`` using ``tqdm.write``.

        Nothing is written when ``engine.state.metrics`` is empty. Metrics
        that are missing (or stored as ``None``) are skipped silently.

        Args:
            engine: Object whose ``state.metrics`` mapping holds the values.
            metric_names: Names of the metrics to report.
        """
        metrics = engine.state.metrics
        if len(metrics) < 1:
            return

        tqdm.write(f'{self.name}:')

        # ``default=0`` keeps an empty ``metric_names`` from raising
        # ValueError; in that case only the heading is written.
        target_len = max(map(len, metric_names), default=0)

        for metric_name in metric_names:
            value = metrics.get(metric_name, None)
            if value is None:
                continue

            # Pad after the colon so scalar values line up in one column.
            padding = ' ' * (target_len - len(metric_name))

            if hasattr(value, '__len__'):
                # Anything exposing ``__len__`` is pretty-printed on its own
                # lines below the metric name.
                tqdm.write(f' {metric_name}:')
                tqdm.write(pprint.pformat(value))
            elif is_float(value):
                # Fixed-point for ordinary magnitudes and exact zero;
                # scientific notation for tiny non-zero values that would
                # otherwise be shown as 0.0000.
                if abs(value) > 1e-4 or value == 0:
                    tqdm.write(
                        f' {metric_name}:{padding} {value:.4f}'
                    )
                else:
                    tqdm.write(
                        f' {metric_name}:{padding} {value:.4e}'
                    )
            else:
                tqdm.write(f' {metric_name}:{padding} {value}')