rl_equation_solver.agent.base.BaseAgent
- class BaseAgent(env, config=None, device='cpu')
Bases: LossMixin
Agent with DQN target and policy networks.
- Parameters:
env (Object) – Environment instance, e.g. rl_equation_solver.env_linear_equation.Env().
config (dict | None) – Model configuration. If None, the default model configuration in rl_equation_solver.config is used.
device (str) – Device to use for torch objects, e.g. ‘cpu’ or ‘cuda:0’.
Methods
batch_states(states) – Convert states into a batch.
choose_action(state[, training]) – Choose action based on given state.
choose_optimal_action(state) – Choose the action with the maximum expected reward, \(a^* = \arg\max_a Q(s, a)\).
Choose random action rather than the optimal action.
compute_Q(batch) – Compute \(Q(s_t, a)\).
compute_V(batch) – Compute \(V(s_{t+1})\) for all next states.
Compute loss for batch using the stored memory.
compute_expected_Q(batch, next_state_values) – Compute the expected Q values.
compute_loss(state_action_values, ...) – Compute Huber loss.
convert_state(state_string) – Convert state string to appropriate representation.
get_env() – Get environment.
huber_loss(x, y[, delta]) – Huber loss.
Initialize model configuration.
Initialize optimizer.
init_state() – Initialize state from the environment.
Check for constant loss over a large number of steps.
l2_loss(x, y) – L2 Loss.
load(env, model_file) – Load policy_network from model_file.
optimize_model([loss]) – Perform one optimization step on the policy network.
predict(state_string) – Predict the solution from the given state_string.
save(output_file) – Save the policy_network.
set_env(env) – Set the environment.
smooth_l1_loss(x, y) – Smooth L1 Loss.
step(state[, training]) – Take the next step from the current state.
terminate_msg() – Log message about solver termination.
train(num_episodes[, eval]) – Train the model for the given number of episodes.
update_config(config) – Update configuration.
update_info(key, value) – Update history info with the given value for the given key.
update_networks() – Soft update of the target network's weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\).
Attributes
current_episode – Get current episode.
device – Get device for training network.
history – Get training history of policy_network.
info – Get environment info.
state_string – Get state string representation.
steps_done – Get total number of steps done across all episodes.
- abstract init_state()
Initialize state from the environment. This can be a vector representation or a graph representation.
- abstract convert_state(state_string)
Convert state string to appropriate representation. This can be a vector representation or a graph representation.
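Subclasses supply these abstract methods. As an illustration, here is a minimal sketch of a vector representation, assuming a character-level vocabulary and fixed-length padding. AgentSketch, VectorAgentSketch, vocab, max_len and pad are hypothetical names, not part of rl_equation_solver, and the real subclasses may encode states quite differently (for example as graphs).

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class AgentSketch(ABC):
    """Hypothetical stand-in for the abstract part of BaseAgent."""

    @abstractmethod
    def convert_state(self, state_string: str) -> List[int]:
        ...


class VectorAgentSketch(AgentSketch):
    """Hypothetical subclass: encode a state string as a fixed-length id vector."""

    def __init__(self, vocab: Dict[str, int], max_len: int = 8, pad: int = 0):
        self.vocab = vocab
        self.max_len = max_len
        self.pad = pad

    def convert_state(self, state_string: str) -> List[int]:
        # One id per character; characters missing from the vocab map to the pad id.
        ids = [self.vocab.get(ch, self.pad) for ch in state_string]
        ids = ids[: self.max_len]  # truncate long states
        return ids + [self.pad] * (self.max_len - len(ids))  # right-pad short ones


agent = VectorAgentSketch({"a": 1, "x": 2, "+": 3, "b": 4}, max_len=6)
print(agent.convert_state("a*x+b"))  # [1, 0, 2, 3, 4, 0]
```

A fixed length keeps every state the same size, so a batch of states can be stacked directly into one input for a feed-forward network.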
- property steps_done
Get total number of steps done across all episodes
- property current_episode
Get current episode
- property device
Get device for training network
- step(state, training=False)
Take the next step from the current state.
- Parameters:
state (str) – State string representation.
training (bool) – Whether the step is part of training or inference. Determines whether to update the history.
- Returns:
action (Tensor) – Action taken, represented as a PyTorch tensor.
next_state (Tensor | GraphEmbedding) – Next state after the action, represented as a PyTorch tensor or GraphEmbedding.
done (bool) – Whether a solution has been found or the state size limits have been exceeded.
info (dict) – Dictionary with loss, reward, and state information.
- choose_optimal_action(state)
Choose the action with the maximum expected reward, \(a^* = \arg\max_a Q(s, a)\).
max(1) returns the largest value in each row together with its column index. The index (the second element of the result) is the action with the largest expected reward, so that is the action selected.
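The greedy selection can be mirrored in plain Python. This sketch uses a hypothetical helper, choose_optimal_actions, which takes a batch of Q-value rows and returns each row's argmax, the same thing the indices returned by PyTorch's q.max(1) give. It is an illustration, not the method's actual code.

```python
from typing import List, Sequence


def choose_optimal_actions(q_values: Sequence[Sequence[float]]) -> List[int]:
    """Return argmax_a Q(s, a) for each row (one row per state)."""
    actions = []
    for row in q_values:
        # Index of the largest Q-value in this row; ties resolve to the first index.
        actions.append(max(range(len(row)), key=lambda a: row[a]))
    return actions


# Two states with three actions each.
print(choose_optimal_actions([[0.1, 0.7, 0.2], [1.5, -0.3, 0.9]]))  # [1, 0]
```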
- choose_action(state, training=False)
Choose action based on given state. Chooses either the optimal action or a random action, depending on the training step.
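One common way to choose between the optimal action and a random one is an ε-greedy rule whose ε decays with the number of steps taken. The sketch below assumes an exponential schedule with hypothetical parameters eps_start, eps_end and eps_decay. These names and defaults are borrowed from the PyTorch DQN tutorial, not from rl_equation_solver.config, and the schedule actually used by choose_action may differ.

```python
import math
import random
from typing import Sequence


def epsilon(steps_done: int, eps_start: float = 0.9, eps_end: float = 0.05,
            eps_decay: float = 1000.0) -> float:
    """Exploration rate decaying exponentially from eps_start toward eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)


def choose_action(q_row: Sequence[float], steps_done: int, training: bool,
                  rng: random.Random) -> int:
    """Epsilon-greedy action for one state, given its Q-values."""
    if training and rng.random() < epsilon(steps_done):
        # Explore: pick any action uniformly at random.
        return rng.randrange(len(q_row))
    # Exploit: greedy action, argmax_a Q(s, a).
    return max(range(len(q_row)), key=lambda a: q_row[a])


rng = random.Random(0)
print(choose_action([0.1, 0.7, 0.2], steps_done=0, training=False, rng=rng))  # 1
```

Outside training the rule is purely greedy. During training, early steps explore often, and exploration drops toward eps_end as steps_done grows.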
- compute_V(batch)
Compute \(V(s_{t+1})\) for all next states. The values of the non-final next states (non_final_next_states) are computed with the “older” target network, selecting the best action value with max(1)[0].
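The target computation can be sketched in plain Python under the usual DQN conventions. A final next state (None here) has value 0, and compute_expected_Q forms the Bellman target \(r_t + \gamma V(s_{t+1})\) with a discount factor \(\gamma\). Both conventions are assumptions, since this page does not state them. compute_v, compute_expected_q and gamma are hypothetical names.

```python
from typing import List, Optional, Sequence


def compute_v(next_q_target: Sequence[Optional[Sequence[float]]]) -> List[float]:
    """V(s_{t+1}) = max_a Q_target(s_{t+1}, a); final states (None) get 0."""
    return [0.0 if q is None else max(q) for q in next_q_target]


def compute_expected_q(rewards: Sequence[float], next_values: Sequence[float],
                       gamma: float = 0.9) -> List[float]:
    """Bellman target r_t + gamma * V(s_{t+1}) for each transition."""
    return [r + gamma * v for r, v in zip(rewards, next_values)]


v = compute_v([[0.2, 1.0], None, [-0.5, -0.1]])
print(v)  # [1.0, 0.0, -0.1]
print(compute_expected_q([1.0, 2.0, 0.0], v, gamma=0.5))  # [1.5, 2.0, -0.05]
```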
- compute_Q(batch)
Compute \(Q(s_t, a)\). These are the values, according to the policy network, of the actions taken for each state in the batch.
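Selecting \(Q(s_t, a)\) for the stored actions amounts to indexing each row of Q-values by its action. In PyTorch this is typically done with Tensor.gather(1, actions). Here is a plain-Python sketch with the hypothetical helper compute_q:

```python
from typing import List, Sequence


def compute_q(q_values: Sequence[Sequence[float]], actions: Sequence[int]) -> List[float]:
    """Pick Q(s_t, a_t) from each row, like q_values.gather(1, actions) in PyTorch."""
    return [row[a] for row, a in zip(q_values, actions)]


print(compute_q([[0.1, 0.7, 0.2], [1.5, -0.3, 0.9]], [2, 0]))  # [0.2, 1.5]
```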
- train(num_episodes, eval=False)
Train the model for the given number of episodes.
During training, the agent performs a soft update of the target network's weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\), where \(\theta\) are the policy network's weights and \(\theta'\) are the target network's weights. This makes the target network's weights gradually converge toward the policy network's weights.
- Parameters:
num_episodes (int) – Number of episodes to train for.
eval (bool) – Whether to run in evaluation mode, without updating the model weights.
- terminate_msg()
Log a message about solver termination.
- Parameters:
total_reward (list) – List of rewards.
- update_networks()
Soft update of the target network's weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\), where \(\theta\) denotes the policy network's weights and \(\theta'\) the target network's weights. policy_network.state_dict() returns the parameters of the policy network, and target_network.load_state_dict() loads the resulting parameters into the target network.
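The soft update can be illustrated on plain dictionaries that stand in for state_dict() outputs. soft_update is a hypothetical helper. With real networks, the same blend would be applied to each tensor before passing the result to target_network.load_state_dict().

```python
from typing import Dict


def soft_update(policy_params: Dict[str, float], target_params: Dict[str, float],
                tau: float) -> Dict[str, float]:
    """Return theta' <- tau * theta + (1 - tau) * theta' for every parameter name."""
    return {
        name: tau * policy_params[name] + (1.0 - tau) * target_params[name]
        for name in target_params
    }


policy = {"w": 1.0, "b": 0.0}
target = {"w": 0.0, "b": 1.0}
print(soft_update(policy, target, tau=0.25))  # {'w': 0.25, 'b': 0.75}
```

A small \(\tau\) moves the target network only slightly toward the policy network at each update. \(\tau = 1\) would copy the policy weights outright.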
- property history
Get training history of policy_network
- huber_loss(x, y, delta=1.0)
Huber loss, also known as Smooth Mean Absolute Error, is a loss function used in many machine learning and optimization problems, particularly regression. It combines properties of Mean Squared Error (MSE) and Mean Absolute Error (MAE). It is quadratic for small residuals and linear for large ones, with \(\delta\) setting the transition point.
\[ L(y, f(x)) = \begin{cases} \frac{1}{2} (y - f(x))^2, & \text{if } |y - f(x)| \leq \delta \\ \delta |y - f(x)| - \frac{1}{2} \delta^2, & \text{otherwise} \end{cases} \]
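Here is a direct scalar transcription of the Huber formula, with x playing the role of the prediction f(x) and y the target. The library's version presumably operates on tensors and may average over a batch, which this sketch does not attempt. The two branches agree at \(|y - f(x)| = \delta\), where both give \(\delta^2/2\), so the loss is continuous there.

```python
def huber_loss(x: float, y: float, delta: float = 1.0) -> float:
    """Huber loss of prediction x against target y."""
    r = abs(y - x)
    if r <= delta:
        return 0.5 * r ** 2  # quadratic near zero (MSE-like)
    return delta * r - 0.5 * delta ** 2  # linear in the tails (MAE-like)


print(huber_loss(0.5, 0.0))  # 0.125
print(huber_loss(3.0, 0.0))  # 2.5
```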
- l2_loss(x, y)
L2 Loss
- smooth_l1_loss(x, y)
Smooth L1 Loss
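The one-line descriptions of l2_loss and smooth_l1_loss leave the exact definitions open. This sketch assumes l2_loss is the mean squared error and smooth_l1_loss follows PyTorch's convention with a threshold beta (a hypothetical parameter here). With beta = 1, smooth L1 is identical to the Huber loss with \(\delta = 1\). Both readings are assumptions.

```python
from typing import Sequence


def l2_loss(x: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared error (assumed reading of 'L2 Loss')."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def smooth_l1_loss(x: Sequence[float], y: Sequence[float], beta: float = 1.0) -> float:
    """Mean smooth L1 loss; quadratic below beta, linear above it."""
    total = 0.0
    for a, b in zip(x, y):
        r = abs(a - b)
        total += 0.5 * r ** 2 / beta if r < beta else r - 0.5 * beta
    return total / len(x)


print(l2_loss([0.5, 3.0], [0.0, 0.0]))         # 4.625
print(smooth_l1_loss([0.5, 3.0], [0.0, 0.0]))  # 1.3125
```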
- property state_string
Get state string representation
- property info
Get environment info