Source code for workflow.ignite.metrics.matthews_correlation_coefficient

import torch

from ignite.metrics import (
    ConfusionMatrix,
    MetricsLambda,
)


def _matthews_correlation_coefficient(confusion_matrix):
    [[tp, fp], [fn, tn]] = confusion_matrix.float()

    if all([
        0 < tp + fp,
        0 < tp + fn,
        0 < tn + fp,
        0 < tn + fn,
    ]):
        return (
            (tp * tn - fp * fn)
                / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        )
    else:
        return torch.tensor(0.)

[docs]def MatthewsCorrelationCoefficient(output_transform): return MetricsLambda( _matthews_correlation_coefficient, ConfusionMatrix(num_classes=2, output_transform=output_transform) )