DeepLearningExamples/TensorFlow/Classification/ConvNets/triton/metrics.py
2021-04-20 13:50:41 +02:00

18 lines
706 B
Python

from typing import Any, Dict, List, Optional
import numpy as np
from deployment_toolkit.core import BaseMetricsCalculator
class MetricsCalculator(BaseMetricsCalculator):
    """Compute accuracy by comparing one named output against ground truth.

    Predictions and ground truth are passed as dicts keyed by output name; only
    the entry named by ``output_used_for_metrics`` is compared.
    """

    def __init__(self, output_used_for_metrics: str = "classes"):
        """
        Args:
            output_used_for_metrics: Key looked up in both the prediction dict
                and the ground-truth dict passed to :meth:`calc`.
        """
        self._output_used_for_metrics = output_used_for_metrics

    def calc(self, *, y_pred: Dict[str, np.ndarray], y_real: Optional[Dict[str, np.ndarray]], **_) -> Dict[str, float]:
        """Return the fraction of elements where prediction equals ground truth.

        Args:
            y_pred: Model outputs keyed by output name.
            y_real: Ground-truth values keyed by output name. Required for this
                metric even though the signature allows ``None``.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``{"accuracy": value}``. The value is a Python float in [0, 1], or
            NaN when the compared arrays are empty.

        Raises:
            ValueError: If ``y_real`` is None, or if the squeezed ground-truth
                and prediction arrays do not have the same shape.
            KeyError: If the configured output name is missing from either dict.
        """
        if y_real is None:
            raise ValueError("Ground-truth data (y_real) is required to compute accuracy")

        name = self._output_used_for_metrics
        # Drop singleton axes so e.g. (N, 1) labels compare against (N,) predictions.
        expected = np.squeeze(np.asarray(y_real[name]))
        predicted = np.squeeze(np.asarray(y_pred[name]))

        # Explicit check instead of `assert`: asserts vanish under `python -O`, and
        # mismatched shapes would otherwise broadcast silently in `==` below.
        if expected.shape != predicted.shape:
            raise ValueError(
                f"Shape mismatch for output {name!r}: "
                f"y_real has shape {expected.shape}, y_pred has shape {predicted.shape}"
            )

        # The mean of an empty array is NaN (with a RuntimeWarning); return NaN directly.
        if expected.size == 0:
            return {"accuracy": float("nan")}

        return {"accuracy": float(np.mean(expected == predicted))}