Source code for rl_equation_solver.utilities.history

"""History mixin class"""
import numpy as np
import logging


logger = logging.getLogger(__name__)


class HistoryMixin:
    """Collection of methods for recording and summarizing training history.

    History is stored as a dict keyed by episode. Each value is itself a
    dict mapping an entry's field names to the list of values recorded for
    that field during the episode, one element per call to
    :meth:`append_history`.
    """

    def __init__(self):
        self._history = {}
        # Counters that this mixin initializes and resets but never
        # advances itself; presumably the host class updates them -- confirm.
        self.current_episode = 0
        self.steps_done = 0

    @property
    def history(self):
        """Get training history of policy_network

        Returns
        -------
        dict
            ``{episode: {field: [value_per_step, ...]}}``.
        """
        return self._history

    @history.setter
    def history(self, value):
        """Set training history of policy_network"""
        self._history = value

    @property
    def avg_history(self):
        """Get history averaged over each episode

        Returns
        -------
        dict
            Maps every recorded field except ``'state'`` to a list holding,
            for each episode (in insertion order), the NaN-ignoring mean
            (``np.nanmean``) of that field's values. Empty when no history
            has been recorded yet.
        """
        if not self.history:
            return {}
        # Take the field names from the first recorded episode instead of
        # assuming an episode keyed 0 exists; history may start at another
        # episode (e.g. when assigned through the setter).
        first_episode = next(iter(self.history.values()))
        out = {k: [] for k in first_episode if k != 'state'}
        for series in self.history.values():
            for k in out:
                out[k].append(np.nanmean(series[k]))
        return out

    def append_history(self, entry):
        """Append latest step for training history of policy_network

        Parameters
        ----------
        entry : dict
            Record for one step. Must contain an ``'ep'`` key naming the
            episode it belongs to; every key/value pair (including ``'ep'``)
            is appended to that episode's per-field lists.

        Raises
        ------
        KeyError
            If ``entry`` has no ``'ep'`` key, or contains a field that was
            not present in the first entry recorded for the same episode.
        """
        episode = entry['ep']
        if episode not in self._history:
            self._history[episode] = {k: [] for k in entry}
        record = self._history[episode]
        for k, v in entry.items():
            record[k].append(v)

    def _latest_episode(self):
        """Return the key of the most recently inserted episode.

        Raises
        ------
        IndexError
            If no history has been recorded yet.
        """
        if not self.history:
            raise IndexError(
                'Training history is empty; no episode has been recorded')
        return list(self.history)[-1]

    def update_history(self, key, value):
        """Update latest step for training history of policy_network

        Overwrites the last recorded value of ``key`` in the most recently
        inserted episode.

        Raises
        ------
        IndexError
            If no history has been recorded yet.
        KeyError
            If ``key`` was never recorded for the latest episode.
        """
        episode = self._latest_episode()
        self._history[episode][key][-1] = value

    def log_info(self):
        """Write info to logger

        Logs a dict holding the most recent value of every field in the
        latest episode. ``reward`` and ``loss`` (when present) are formatted
        in scientific notation.

        Raises
        ------
        IndexError
            If no history has been recorded yet.
        """
        latest = self.history[self._latest_episode()]
        out = {k: v[-1] for k, v in latest.items()}
        for key in ('reward', 'loss'):
            if key in out:
                out[key] = '{:.3e}'.format(out[key])
        # Lazy %-style args: the dict is only rendered if INFO is enabled.
        logger.info('\n%s', out)

    def reset_history(self):
        """Clear history and reset the episode/step counters to zero."""
        self._history = {}
        self.current_episode = 0
        self.steps_done = 0