rl_equation_solver.agent.gcn.Agent
- class Agent(env, config=None, device='cpu')
Bases: BaseAgent
Agent with GCN (graph convolutional network) target and policy networks.
- Parameters:
env (Object) – Environment instance. e.g. rl_equation_solver.env_linear_equation.Env()
config (dict | None) – Model configuration. If None then the default model configuration in rl_equation_solver.config will be used.
device (str) – Device to use for torch objects. e.g. ‘cpu’ or ‘cuda:0’
Methods
- batch_states(states, device): Batch agent states.
- choose_action(state[, training]): Choose action based on given state.
- choose_optimal_action(state): Choose the action with the maximum expected reward, \(\arg\max_a Q(s, a)\).
- choose_random_action(): Choose a random action rather than the optimal action.
- compute_Q(batch): Compute \(Q(s_t, a)\).
- compute_V(batch): Compute \(V(s_{t+1})\) for all next states.
- compute_batch_loss(): Compute loss for batch using the stored memory.
- compute_expected_Q(batch, next_state_values): Compute the expected Q values.
- compute_loss(state_action_values, ...): Compute Huber loss.
- convert_state(state): Convert state string to graph representation.
- get_env(): Get environment.
- huber_loss(x, y[, delta]): Huber loss.
- init_config(): Initialize model configuration.
- init_optimizer(): Initialize optimizer.
- Initialize state as a graph.
- is_constant_complexity(): Check for constant loss over a large number of steps.
- l2_loss(x, y): L2 loss.
- load(env, model_file): Load policy_network from model_file.
- optimize_model([loss]): Perform one step of the optimization (on the policy network).
- predict(state_string): Predict the solution from the given state_string.
- save(output_file): Save the policy_network.
- set_env(env): Set the environment.
- smooth_l1_loss(x, y): Smooth L1 loss.
- step(state[, training]): Take next step from current state.
- terminate_msg(): Log message about solver termination.
- train(num_episodes[, eval]): Train the model for the given number of episodes.
- update_config(config): Update configuration.
- update_info(key, value): Update history info with given value for the given key.
- update_networks(): Soft update of the target network's weights \(\theta^{'} \leftarrow \tau \theta + (1 - \tau) \theta^{'}\).
Attributes
- current_episode: Get current episode.
- device: Get device for training network.
- history: Get training history of policy_network.
- info: Get environment info.
- state_string: Get state string representation.
- steps_done: Get total number of steps done across all episodes.
- choose_action(state, training=False)
Choose an action for the given state: either the optimal action or a random action, depending on the training step.
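The docs do not spell out how the choice between the optimal and a random action depends on the training step. A common scheme, sketched below purely as an assumption, is epsilon-greedy selection with an exploration probability that decays exponentially with steps_done. The names epsilon_threshold, eps_start, eps_end and eps_decay are hypothetical, and Q-values are plain Python lists rather than tensors.

```python
import math
import random


def epsilon_threshold(steps_done, eps_start=0.9, eps_end=0.05, eps_decay=1000):
    """Hypothetical exploration probability, decaying from eps_start toward eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)


def choose_action(q_values, steps_done, training=False, rng=random):
    """Epsilon-greedy sketch: q_values holds one Q-value per action for a single state."""
    if training and rng.random() < epsilon_threshold(steps_done):
        # Explore: uniformly random action index
        return rng.randrange(len(q_values))
    # Exploit: index of the largest Q-value (the "optimal" action)
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Outside training the greedy action is always returned (assumed behaviour)
print(choose_action([0.1, 0.7, 0.2], steps_done=0))  # 1
```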
- choose_optimal_action(state)
Choose the action with the maximum expected reward, \(\arg\max_a Q(s, a)\).
max(1) returns the largest column value of each row; the second element of its result is the index where that maximum was found, so the action with the largest expected reward is picked.
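In PyTorch, tensor.max(1) returns a (values, indices) pair taken along dimension 1. The sketch below mimics that on a batch of Q-value rows stored as plain lists, to show why the index part identifies the action with the largest expected reward; the function name max_over_actions is hypothetical.

```python
def max_over_actions(q_batch):
    """Mimic tensor.max(1): per-row maximum value and the column index where it occurs."""
    values, indices = [], []
    for row in q_batch:
        best = max(range(len(row)), key=lambda a: row[a])
        values.append(row[best])
        indices.append(best)
    return values, indices


values, actions = max_over_actions([[0.2, 1.5, -0.3], [0.9, 0.1, 0.4]])
print(values, actions)  # [1.5, 0.9] [1, 0]
```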
- choose_random_action()
Choose a random action rather than the optimal action.
- compute_Q(batch)
Compute \(Q(s_t, a)\). These are the values of the actions that would have been taken for each batch state according to policy_net.
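In a standard DQN setup, this step evaluates the policy network on each batch state and keeps only the entry for the action that was actually taken (in PyTorch this is typically done with Tensor.gather). That is an assumption about the implementation; the sketch uses plain lists and a hypothetical helper name, select_taken_q.

```python
def select_taken_q(q_batch, actions):
    """Pick Q(s_t, a_t) from each row of Q-values using the stored action index."""
    return [row[a] for row, a in zip(q_batch, actions)]


# Two states, three actions each; actions 2 and 0 were taken
print(select_taken_q([[0.1, 0.4, 0.9], [0.7, 0.2, 0.3]], [2, 0]))  # [0.9, 0.7]
```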
- compute_V(batch)
Compute \(V(s_{t+1})\) for all next states. Expected values of actions for non_final_next_states are computed with the “older” target_net, and the best value for each state is selected with max(1)[0].
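A sketch of the next-state values under the usual DQN convention, which the docs do not confirm: Q-values are available only for the non-final next states, V is their maximum target-network Q-value (the role of max(1)[0]), and final states are assumed to contribute 0. The helper name next_state_values and the list-based inputs are hypothetical.

```python
def next_state_values(non_final_target_q, non_final_mask):
    """V(s_{t+1}) = max_a Q_target(s_{t+1}, a) for non-final states, 0.0 for final ones (assumed)."""
    rows = iter(non_final_target_q)  # one Q-value row per non-final next state, in order
    values = []
    for non_final in non_final_mask:
        values.append(max(next(rows)) if non_final else 0.0)
    return values


print(next_state_values([[0.5, 2.0]], [True, False]))  # [2.0, 0.0]
```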
- compute_batch_loss()
Compute loss for batch using the stored memory.
- compute_expected_Q(batch, next_state_values)
Compute the expected Q values
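In a standard DQN update, the expected Q value is the Bellman target \(r + \gamma V(s_{t+1})\). The docs do not state which formula compute_expected_Q uses, so the sketch below assumes that target; the discount factor gamma and the function name expected_q are hypothetical.

```python
def expected_q(rewards, next_state_values, gamma=0.9):
    """Bellman target r + gamma * V(s_{t+1}) for each transition (assumed formula)."""
    return [r + gamma * v for r, v in zip(rewards, next_state_values)]


print(expected_q([1.0, -1.0], [2.0, 0.0], gamma=0.5))  # [2.0, -1.0]
```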
- compute_loss(state_action_values, expected_state_action_values)
Compute Huber loss
- property current_episode
Get current episode
- property device
Get device for training network
- get_env()
Get environment
- property history
Get training history of policy_network
- huber_loss(x, y, delta=1.0)
Huber loss. Huber loss, also known as Smooth Mean Absolute Error, is a loss function used in various machine learning and optimization problems, particularly in regression tasks. It combines the properties of both Mean Squared Error (MSE) and Mean Absolute Error (MAE) loss functions, providing a balance between the two.
\[ L(y, f(x)) = \begin{cases} \frac{1}{2} (y - f(x))^2, & \text{if } |y - f(x)| \leq \delta \\ \delta |y - f(x)| - \frac{1}{2} \delta^2, & \text{otherwise} \end{cases} \]
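A direct transcription of the piecewise formula above, using plain Python floats. Averaging over the batch is an assumption (a common default, e.g. reduction='mean' in torch.nn.functional.huber_loss); the agent's own huber_loss may reduce differently.

```python
def huber_loss(x, y, delta=1.0):
    """Mean Huber loss between predictions x and targets y (lists of floats)."""
    total = 0.0
    for pred, target in zip(x, y):
        err = abs(target - pred)
        if err <= delta:
            total += 0.5 * err ** 2  # quadratic region, behaves like MSE
        else:
            total += delta * err - 0.5 * delta ** 2  # linear region, behaves like MAE
    return total / len(x)


print(huber_loss([0.0, 0.0], [0.5, 3.0]))  # 1.3125
```

The two branches meet at \(|y - f(x)| = \delta\), where both equal \(\frac{1}{2}\delta^2\), so the loss is continuous and small errors are penalised quadratically while large ones grow only linearly.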
- property info
Get environment info
- init_config()
Initialize model configuration
- init_optimizer()
Initialize optimizer
- is_constant_complexity()
Check for constant loss over a large number of steps
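One way such a check could work, sketched under assumptions the docs do not confirm: keep a history of recorded values (the loss, or the state complexity that the method name suggests) and report True once the last window entries are all equal within a tolerance. The window size, the tolerance and the function name is_constant are hypothetical.

```python
def is_constant(history, window=100, tol=0.0):
    """True if the last `window` recorded values all lie within `tol` of each other."""
    if len(history) < window:
        return False  # not enough steps recorded yet
    recent = history[-window:]
    return max(recent) - min(recent) <= tol


print(is_constant([3.0, 2.0, 2.0, 2.0], window=3))  # True
```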
- l2_loss(x, y)
L2 Loss
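The docs do not define l2_loss further. A common reading, assumed here, is the mean squared error between x and y; other code bases use the plain sum of squares or its square root instead.

```python
def l2_loss(x, y):
    """Mean squared error between two equal-length lists (assumed definition)."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


print(l2_loss([1.0, 2.0], [3.0, 2.0]))  # 2.0
```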
- classmethod load(env, model_file)
Load policy_network from model_file
- optimize_model(loss=None)
Perform one step of the optimization (on the policy network).
- predict(state_string)
Predict the solution from the given state_string.
- save(output_file)
Save the policy_network
- set_env(env)
Set the environment
- smooth_l1_loss(x, y)
Smooth L1 Loss
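Smooth L1 loss with threshold beta is 0.5 * d**2 / beta when |d| < beta and |d| - 0.5 * beta otherwise, so with beta = 1 it coincides with the Huber loss at delta = 1 (PyTorch documents the same relation between smooth_l1_loss and huber_loss). Because smooth_l1_loss(x, y) exposes no threshold parameter, the sketch below assumes beta = 1 and mean reduction.

```python
def smooth_l1_loss(x, y, beta=1.0):
    """Mean smooth L1 loss; beta=1.0 is an assumed default."""
    total = 0.0
    for a, b in zip(x, y):
        d = abs(a - b)
        # Quadratic below the threshold, linear above it
        total += 0.5 * d ** 2 / beta if d < beta else d - 0.5 * beta
    return total / len(x)


print(smooth_l1_loss([0.0, 0.0], [0.5, 3.0]))  # 1.3125, same as Huber with delta=1
```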
- property state_string
Get state string representation
- step(state, training=False)
Take next step from current state
- Parameters:
state (str) – State string representation
episode (int) – Episode number
training (bool) – Whether the step is part of training or inference. Determines whether to update the history.
- Returns:
action (Tensor) – Action taken. Represented as a pytorch tensor.
next_state (Tensor) – Next state after action. Represented as a pytorch tensor or GraphEmbedding.
done (bool) – Whether solution has been found or if state size conditions have been exceeded.
info (dict) – Dictionary with loss, reward, and state information
- property steps_done
Get total number of steps done across all episodes
- terminate_msg()
Log message about solver termination
- Parameters:
total_reward (list) – List of rewards
- train(num_episodes, eval=False)
Train the model for the given number of episodes.
During training, the agent performs a soft update of the Target Network’s weights, \(\theta^{'} \leftarrow \tau \theta + (1 - \tau) \theta^{'}\), where \(\theta\) denotes the Policy Network’s weights (policy_net_state_dict) and \(\theta^{'}\) the Target Network’s weights (target_net_state_dict). This makes the Target Network’s weights gradually converge to the Policy Network’s weights.
- Parameters:
num_episodes (int) – Number of episodes to train for
eval (bool) – Whether to run in evaluation mode, without updating model weights
- update_config(config)
Update configuration
- update_info(key, value)
Update history info with given value for the given key
- update_networks()
Soft update of the target network’s weights \(\theta^{'} \leftarrow \tau \theta + (1 - \tau) \theta^{'}\). policy_network.state_dict() returns the parameters of the policy network, and target_network.load_state_dict() loads the updated parameters into the target network.
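The soft update can be sketched on dictionaries of parameters, mirroring the state_dict() / load_state_dict() pattern. Plain floats stand in for tensors, and tau=0.005 is an illustrative value, not one taken from the agent's configuration; with PyTorch the same per-key arithmetic applies to the tensors returned by state_dict().

```python
def soft_update(policy_state, target_state, tau=0.005):
    """Return new target params: theta' <- tau * theta + (1 - tau) * theta'."""
    return {
        key: tau * policy_state[key] + (1.0 - tau) * target_state[key]
        for key in policy_state
    }


target = soft_update({"w": 1.0, "b": 0.0}, {"w": 0.0, "b": 2.0}, tau=0.5)
print(target)  # {'w': 0.5, 'b': 1.0}
```

A small tau keeps the target network changing slowly, which stabilises the \(V(s_{t+1})\) estimates used in the expected Q values.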