Source code for workflow.ignite.handlers.model_checkpoint

import os
import torch
import numpy as np
import ignite
from ignite.engine import Events


[docs]class ModelCheckpoint: dirname = 'checkpoints' filename_prefix = 'model' def __init__(self, model_score_function=None, n_saved=1): self.model_checkpoint = ignite.handlers.ModelCheckpoint( dirname=self.dirname, filename_prefix=self.filename_prefix, score_function=model_score_function, n_saved=n_saved, require_empty=False, )
[docs] def attach(self, engine, *args, **kwargs): engine.add_event_handler( Events.EPOCH_COMPLETED, self.model_checkpoint, *args, **kwargs, )
[docs] @staticmethod def load( to_load, dirname=None, device=None, suffix=None, ): if dirname is None: dirname = ModelCheckpoint.dirname models = os.listdir(dirname) if suffix is None: suffixes = [ '_'.join( os.path.splitext(name)[0] .lstrip(ModelCheckpoint.filename_prefix) .split('_')[2:] ) for name in models ] suffix = suffixes[np.argmax([float(s.split('_')[-1]) for s in suffixes])] saved_checkpoint_state = torch.load( f'{dirname}/{ModelCheckpoint.filename_prefix}_checkpoint_{suffix}.pt', map_location=device, ) for name, module_or_optimizer in to_load.items(): module_or_optimizer.load_state_dict( saved_checkpoint_state[name] ) return suffix