"""DQN module"""
import math
import random
from itertools import count
import torch
from torch import optim
import logging
from abc import abstractmethod
import numpy as np
from rl_equation_solver.config import DefaultConfig
from rl_equation_solver.utilities.loss import LossMixin
from rl_equation_solver.utilities.utilities import ReplayMemory
# Module-level logger used by BaseAgent for training / prediction messages.
logger = logging.getLogger(__name__)
# Network attributes (e.g. policy_network) start as None and are called later,
# which presumably triggers pylint's not-callable check; silence it file-wide.
# pylint: disable=not-callable
class BaseAgent(LossMixin):
    """Agent with DQN target and policy networks.

    Subclasses are expected to build ``policy_network`` / ``target_network``
    and implement :meth:`init_state`, :meth:`convert_state` and
    :meth:`batch_states`.
    """

    def __init__(self, env, config=None, device='cpu'):
        """
        Parameters
        ----------
        env : Object
            Environment instance.
            e.g. rl_equation_solver.env_linear_equation.Env()
        config : dict | None
            Model configuration. If None then the default model configuration
            in rl_equation_solver.config will be used.
        device : str
            Device to use for torch objects. e.g. 'cpu' or 'cuda:0'
        """
        self.env = env
        self.n_actions = env.n_actions
        self.n_observations = env.n_obs
        self.memory = None
        self.policy_network = None
        self.target_network = None
        self.optimizer = None
        self._history = {}
        self._device = device
        self.config = config

        # Configuration properties. Placeholders only: init_config() fills
        # them from the default configuration plus any user overrides.
        self.batch_size = None
        self.gamma = None
        self.eps_start = None
        self.eps_end = None
        self.hidden_size = None
        self.eps_decay = None
        self.eps_decay_steps = None
        self.tau = None
        self.learning_rate = None
        self.memory_cap = None
        self.reset_steps = None
        self.vec_dim = None
        self.feature_num = None
        self.grad_clip = None

        # Must run before anything reads the configuration attributes;
        # otherwise the replay memory is created with capacity None.
        self.init_config()

        self.memory = ReplayMemory(self.memory_cap)

    def init_config(self):
        """Initialize model configuration.

        Starts from a copy of ``DefaultConfig`` and overrides it with the
        entries of ``self.config`` (if any). Only keys that match an existing
        attribute of the agent are applied.
        """
        # Work on a copy so user overrides never leak into the shared
        # module-level default configuration (and thus into other agents).
        config = dict(DefaultConfig.items())
        if self.config is not None:
            config.update(self.config)

        for key, val in config.items():
            if hasattr(self, key):
                setattr(self, key, val)

    def update_config(self, config):
        """Update configuration and re-apply it to the agent attributes.

        Note that already-built objects (replay memory, optimizer, networks)
        are not rebuilt.
        """
        self.config = config
        self.init_config()

    @abstractmethod
    def init_state(self):
        """Initialize state from the environment. This can be a vector
        representation or graph representation"""

    @abstractmethod
    def convert_state(self, state_string):
        """Convert state string to appropriate representation. This can be a
        vector representation or graph representation"""

    @abstractmethod
    def batch_states(self, states, device=None):
        """Convert states into a batch.

        Parameters
        ----------
        states : list
            Transitions sampled from the replay memory.
        device : torch.device | None
            Device on which to place the batched tensors
            (``compute_batch_loss`` passes ``self.device``).
        """

    def init_optimizer(self):
        """Initialize the AdamW optimizer for the policy network."""
        self.optimizer = optim.AdamW(self.policy_network.parameters(),
                                     lr=self.learning_rate, amsgrad=True)

    def _get_eps_decay(self):
        """Get epsilon decay for current number of steps.

        When ``eps_decay`` is None the exploration rate decays exponentially
        from ``eps_start`` towards ``eps_end`` with time constant
        ``eps_decay_steps``; any other value disables the decay (returns 0).
        """
        decay = 0
        if self.eps_decay is None:
            decay = self.eps_start - self.eps_end
            decay *= math.exp(-1. * self.steps_done / self.eps_decay_steps)
        return decay

    @property
    def steps_done(self):
        """Get total number of steps done across all episodes"""
        return self.env.steps_done

    @property
    def current_episode(self):
        """Get current episode"""
        return self.env.current_episode

    @current_episode.setter
    def current_episode(self, value):
        """Set current episode"""
        self.env.current_episode = value

    @property
    def device(self):
        """Get device for training network.

        If no device was given, prefer CUDA, then Apple MPS, then CPU.
        A string device is converted to ``torch.device`` once and cached.
        """
        if self._device is None:
            mps_backend = getattr(torch.backends, 'mps', None)
            if torch.cuda.is_available():
                self._device = torch.device('cuda:0')
            elif mps_backend is not None and mps_backend.is_available():
                self._device = torch.device('mps:0')
            else:
                # Only fall back to CPU when no accelerator was found.
                self._device = torch.device('cpu')
        elif isinstance(self._device, str):
            self._device = torch.device(self._device)
        return self._device

    def step(self, state, training=False):
        """Take next step from current state

        Parameters
        ----------
        state : object
            Current state in the representation produced by ``init_state`` /
            ``convert_state`` (e.g. a pytorch tensor).
        training : bool
            Whether the step is part of training or inference. Determines
            whether the decaying exploration rate is used when choosing the
            action (see ``choose_action``).

        Returns
        -------
        action : Tensor
            Action taken. Represented as a pytorch tensor.
        next_state : object | None
            Next state after action, in the ``convert_state`` representation,
            or None if the episode is done.
        done : bool
            Whether solution has been found or if state size conditions have
            been exceeded.
        info : dict
            Dictionary with loss, reward, and state information
        """
        action = self.choose_action(state, training=training)
        _, _, done, info = self.env.step(action.item())

        if done:
            # Terminal transitions carry no next state (masked out in
            # compute_V via batch.non_final_mask).
            next_state = None
        else:
            next_state = self.convert_state(self.state_string)

        return action, next_state, done, info

    def choose_optimal_action(self, state):
        r"""Choose action with max expected reward :math:`:= max_a Q(s, a)`

        max(1) will return largest column value of each row. second column on
        max result is index of where max element was found so we pick action
        with the larger expected reward.
        """
        with torch.no_grad():
            return self.policy_network(state).max(1)[1].view(1, 1)

    def choose_action(self, state, training=False):
        """Choose action based on given state. Either choose optimal action or
        random action depending on training step.

        During training the exploration threshold is ``eps_end`` plus the
        current epsilon decay; otherwise it is just ``eps_end``.
        """
        random_float = random.random()
        epsilon_threshold = self.eps_end + self._get_eps_decay()

        if not training:
            epsilon_threshold = self.eps_end

        if random_float > epsilon_threshold:
            return self.choose_optimal_action(state)

        return self.choose_random_action()

    def compute_loss(self, state_action_values, expected_state_action_values):
        """Compute Huber (smooth L1) loss.

        ``state_action_values`` comes from ``compute_Q`` and has shape
        (batch, 1) because of ``gather``; the expected values are 1-D
        (batch,), so they are unsqueezed to match.
        """
        loss = self.smooth_l1_loss(state_action_values,
                                   expected_state_action_values.unsqueeze(1))
        return loss

    def compute_batch_loss(self):
        """Compute loss for batch using the stored memory.

        Returns None until the replay memory holds at least ``batch_size``
        transitions. Otherwise the loss value is also recorded in the info
        dict and the environment history under 'loss'.
        """
        if len(self.memory) < self.batch_size:
            return None
        transition = self.memory.sample(self.batch_size)
        batch = self.batch_states(transition, device=self.device)

        state_action_values = self.compute_Q(batch)
        next_state_values = self.compute_V(batch)
        expected_state_action_values = self.compute_expected_Q(
            batch, next_state_values)

        loss = self.compute_loss(state_action_values,
                                 expected_state_action_values)

        self.update_info('loss', loss.item())
        self.env.update_history('loss', loss.item())

        return loss

    def update_info(self, key, value):
        """Update history info with given value for the given key"""
        self.info[key] = value

    def choose_random_action(self):
        """Choose random action rather than the optimal action"""
        return torch.tensor([[self.env.action_space.sample()]],
                            device=self.device, dtype=torch.long)

    def optimize_model(self, loss=None):
        """Perform one step of the optimization (on the policy network).

        Parameters
        ----------
        loss : torch.Tensor | None
            Loss to back-propagate. If None it is computed from the replay
            memory; if that is still None (not enough samples) this is a
            no-op.
        """
        if loss is None:
            loss = self.compute_batch_loss()
        if loss is None:
            return

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # In-place gradient clipping (skipped when no clip value configured)
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(self.policy_network.parameters(),
                                            self.grad_clip)
        self.optimizer.step()

    def compute_expected_Q(self, batch, next_state_values):
        """Compute the expected Q values:
        ``reward + gamma * V(s_{t+1})``."""
        return batch.reward_batch + (self.gamma * next_state_values)

    def compute_V(self, batch):
        """Compute :math:`V(s_{t+1})` for all next states. Expected values of
        actions for non_final_next_states are computed based on the "older"
        target_net; selecting their best reward with max(1)[0]. Final states
        keep a value of 0.
        """
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        with torch.no_grad():
            # NOTE(review): assumes batch_states exposes
            # ``non_final_next_states`` aligned with ``non_final_mask``.
            next_state_values[batch.non_final_mask] = \
                self.target_network(batch.non_final_next_states).max(1)[0]
        return next_state_values

    def compute_Q(self, batch):
        """Compute :math:`Q(s_t, a)`. These are the actions which would've been
        taken for each batch state according to policy_net
        """
        return self.policy_network(batch.state_batch) \
            .gather(1, batch.action_batch)

    def train(self, num_episodes, eval=False):
        r"""Train the model for the given number of episodes.

        The agent will perform a soft update of the Target Network's weights,
        with the equation :math:`\tau \text{ policy_net_state_dict} +
        (1 - \tau) \text{ target_net_state_dict}`, this helps to make the
        Target Network's weights converge to the Policy Network's weights.

        Parameters
        ----------
        num_episodes : int
            Number of episodes to train for
        eval : bool
            Whether to run in eval mode - without updating model weights
        """
        logger.info(f'Running training routine for {num_episodes} episodes in '
                    f'eval={eval} mode.')

        training = not eval
        if eval:
            # Evaluation starts from a clean history and episode counter.
            # Exploration is already pinned to eps_end through
            # training=False, so the epsilon schedule is left untouched and
            # later training runs keep their decay.
            self.history = {}
            self.current_episode = 0

        start = self.current_episode
        end = start + num_episodes
        for _ in range(start, end):
            state = self.init_state()
            for _ in count():
                action, next_state, done, info = self.step(state,
                                                           training=training)
                reward = torch.tensor([info['reward']], device=self.device)

                self.memory.push(state, action, next_state, reward)

                # Computed in both modes so the loss is recorded in the
                # history; None until the memory holds a full batch.
                loss = self.compute_batch_loss()

                if training:
                    self.optimize_model(loss)
                    self.update_networks()

                if done:
                    self.terminate_msg()
                    break

                state = next_state

    def terminate_msg(self):
        """Log message about solver termination for the latest episode,
        including total reward and mean loss."""
        current_episode = list(self.history.keys())[-1]
        episode_history = self.history[current_episode]
        total_reward = np.nansum(episode_history.get('reward', []))
        # Early episodes may end before any loss has been recorded.
        losses = episode_history.get('loss', [])
        mean_loss = np.nanmean(losses) if len(losses) > 0 else np.nan
        msg = (f"\nSolver terminated after {self.env.loop_step + 1} steps: "
               f"total_reward = {total_reward:.3e}, "
               f"mean_loss = {mean_loss:.3e}, "
               f"state = {self.state_string}")
        logger.info(msg)

    def update_networks(self):
        r"""Soft update of the target network's weights :math:`\theta^{'}
        \leftarrow \tau \theta + (1 - \tau) \theta^{'}`

        policy_network.state_dict() returns the parameters of the policy
        network target_network.load_state_dict() loads these parameters into
        the target network.
        """
        target_net_state_dict = self.target_network.state_dict()
        policy_net_state_dict = self.policy_network.state_dict()
        for key in policy_net_state_dict:
            value = policy_net_state_dict[key] * self.tau
            value += target_net_state_dict[key] * (1 - self.tau)
            target_net_state_dict[key] = value
        # Rebinding dict entries does not touch the module; load explicitly.
        self.target_network.load_state_dict(target_net_state_dict)

    @property
    def history(self):
        """Get training history of policy_network"""
        return self.env.history

    @history.setter
    def history(self, value):
        """Set training history of policy_network"""
        self.env.history = value

    def predict(self, state_string):
        """Predict the solution from the given state_string.

        Repeatedly takes greedy (non-training) steps until the environment
        reports ``done``, then logs the number of steps and final complexity.
        """
        state = self.convert_state(state_string)
        done = False
        t = 0
        complexity = None
        while not done:
            # step returns (action, next_state, done, info)
            _, next_state, done, _ = self.step(state, training=False)
            complexity = self.env.expression_complexity(self.env.state_string)
            t += 1
            state = next_state

        # t already counts the steps taken.
        logger.info(f"Solver terminated after {t} steps. Final "
                    f"state = {self.env.state_string} with complexity = "
                    f"{complexity}.")

    # pylint: disable=invalid-unary-operand-type
    def is_constant_complexity(self):
        """Check for constant complexity over the last ``reset_steps`` steps"""
        current_episode = list(self.history.keys())[-1]
        complexities = self.history[current_episode]['complexity']
        check = (len(complexities) >= self.reset_steps
                 and len(set(complexities[-self.reset_steps:])) <= 1)
        if check:
            logger.info('Complexity has been constant '
                        f'({list(complexities)[-1]}) for {self.reset_steps} '
                        'steps. Reseting.')
        return check

    def save(self, output_file):
        """Save the policy_network"""
        torch.save(self.policy_network.state_dict(), output_file)
        logger.info(f'Saved policy_network to {output_file}')

    @classmethod
    def load(cls, env, model_file):
        """Load policy_network weights from model_file into a new agent."""
        agent = cls(env)
        state_dict = torch.load(model_file, map_location=agent.device)
        agent.policy_network.load_state_dict(state_dict)
        logger.info(f'Loaded policy_network from {model_file}')
        return agent

    @property
    def state_string(self):
        """Get state string representation"""
        return self.env.state_string

    @state_string.setter
    def state_string(self, value):
        """Set state string representation"""
        self.env.state_string = value

    @property
    def info(self):
        """Get environment info"""
        return self.env.info

    @info.setter
    def info(self, value):
        """Set environment info"""
        self.env.info = value

    def get_env(self):
        """Get environment"""
        return self.env

    def set_env(self, env):
        """Set the environment"""
        self.env = env