''''''
"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import collections
# Reserved bookkeeping keys placed in the dict returned by
# AverageMeterCollection.summary().
BATCH_COUNT, NUM_SAMPLES = "batch_count", "num_samples"
# Not used by the classes in this file; presumably a reserved key consumed
# by callers (the leading "*" is part of the value) -- verify before changing.
BATCH_SIZE = "*batch_size"
class AverageMeter:
    """
    Computes and stores the average and current value.

    Attributes:
        val: The value passed to the most recent ``update`` call.
        avg: Weighted running average, ``sum / count`` (0 while ``count`` is 0).
        sum: Accumulated ``val * n`` over all ``update`` calls.
        count: Accumulated ``n`` over all ``update`` calls.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the current value, sum, count, and average to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Update current value, total sum, and average.

        Args:
            val: The new value; it is recorded as ``self.val``.
            n: Weight applied to ``val``. ``sum`` accumulates ``val * n`` and
                ``count`` accumulates ``n``. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the very first update with n=0) would
        # otherwise raise ZeroDivisionError; report 0 like the reset state.
        self.avg = self.sum / self.count if self.count else 0
class AverageMeterCollection:
    """
    Tracks a weighted running average for each of a set of named metrics.

    Each call to ``update`` counts as one batch. Per-metric statistics are
    kept in ``AverageMeter`` instances that are created on first use.
    """

    def __init__(self):
        self._batch_count = 0  # number of update() calls so far
        self.n = 0  # accumulated n over all update() calls
        # A fresh AverageMeter is created the first time a metric name is seen.
        self._meters = collections.defaultdict(AverageMeter)

    def update(self, metrics, n=1):
        """Does one batch of updates for the provided metrics.

        Args:
            metrics: Mapping of metric name to this batch's value.
            n: Weight applied to every value in ``metrics``; also added to
                the running sample count ``self.n``. Defaults to 1.
        """
        self._batch_count += 1
        self.n += n
        for metric, value in metrics.items():
            self._meters[metric].update(value, n=n)

    def summary(self):
        """Returns a dict of average and most recent values for each metric.

        Returns:
            dict: ``BATCH_COUNT`` maps to the number of ``update`` calls and
            ``NUM_SAMPLES`` to the accumulated ``n``. For each metric ``m``,
            ``str(m)`` maps to its running average and ``"last_" + str(m)``
            to its most recent value. A metric whose name collides with one
            of these keys overwrites the earlier entry.
        """
        stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
        for metric, meter in self._meters.items():
            stats[str(metric)] = meter.avg
            stats["last_" + str(metric)] = meter.val
        return stats