Source code for workflow.ignite.decorators.step

from functools import wraps


def step(optimizer, n_batches_per_step=1):
    """Decorate a backward function so that it also steps ``optimizer``.

    The returned decorator wraps ``backward_fn`` in ``batch_fn(engine, batch)``.
    Each call runs ``backward_fn(engine, batch)`` and returns its result. On
    every call where ``engine.state.iteration`` is a multiple of
    ``n_batches_per_step``, it also does the following:

    1. Divide every non-``None`` parameter ``.grad`` by ``n_batches_per_step``
       in place.
    2. Call ``optimizer.step()``.
    3. Call ``optimizer.zero_grad()``.

    On all other calls the optimizer is not touched. Presumably this lets the
    gradients of successive backward passes accumulate until the next step.

    NOTE(review): presumably ``engine`` is a pytorch-ignite ``Engine`` and
    ``backward_fn`` calls ``.backward()`` on a per-batch loss. Confirm this
    against the callers.

    Because the schedule is keyed purely on ``engine.state.iteration``, the
    following holds: if the last iterations of a run do not complete a window
    of ``n_batches_per_step`` batches, their gradients are never applied by
    this decorator.

    Args:
        optimizer: Object exposing ``param_groups`` (an iterable of mappings
            with a ``'params'`` entry), ``step()`` and ``zero_grad()``.
            This is presumably a ``torch.optim.Optimizer``.
        n_batches_per_step: Number of batches per optimizer step. Must be
            at least 1.

    Returns:
        A decorator that takes ``backward_fn`` and returns the wrapping
        ``batch_fn``. The wrapper keeps ``backward_fn``'s metadata via
        ``functools.wraps``.

    Raises:
        ValueError: If ``n_batches_per_step`` is less than 1.
    """
    # Fail fast at construction time. With 0, the modulo below would raise
    # ZeroDivisionError on every batch. With a negative value, the modulo
    # check still passes periodically and div_ would flip the gradient signs.
    if n_batches_per_step < 1:
        raise ValueError(
            'n_batches_per_step must be >= 1, got {!r}'.format(n_batches_per_step)
        )

    def decorator(backward_fn):
        @wraps(backward_fn)
        def batch_fn(engine, batch):
            result = backward_fn(engine, batch)

            if engine.state.iteration % n_batches_per_step == 0:
                # Rescale the window's gradients. Presumably they hold a sum
                # over n_batches_per_step backward passes, so this yields
                # their mean. Dividing by 1 changes nothing, so the
                # parameter walk is skipped in that case.
                if n_batches_per_step != 1:
                    for param_group in optimizer.param_groups:
                        for parameters in param_group['params']:
                            if parameters.grad is not None:
                                parameters.grad.div_(n_batches_per_step)

                optimizer.step()
                optimizer.zero_grad()

            return result

        return batch_fn

    return decorator