rl_equation_solver.agent.base.BaseAgent

class BaseAgent(env, config=None, device='cpu')

Bases: LossMixin

Agent with DQN target and policy networks

Parameters:
  • env (Object) – Environment instance. e.g. rl_equation_solver.env_linear_equation.Env()

  • config (dict | None) – Model configuration. If None then the default model configuration in rl_equation_solver.config will be used.

  • device (str) – Device to use for torch objects. e.g. ‘cpu’ or ‘cuda:0’

Methods

batch_states(states)

Convert states into a batch

choose_action(state[, training])

Choose action based on given state.

choose_optimal_action(state)

Choose the action with the maximum expected reward, \(\max_a Q(s, a)\)

choose_random_action()

Choose random action rather than the optimal action

compute_Q(batch)

Compute \(Q(s_t, a)\).

compute_V(batch)

Compute \(V(s_{t+1})\) for all next states.

compute_batch_loss()

Compute loss for batch using the stored memory.

compute_expected_Q(batch, next_state_values)

Compute the expected Q values

compute_loss(state_action_values, ...)

Compute Huber loss

convert_state(state_string)

Convert state string to appropriate representation.

get_env()

Get environment

huber_loss(x, y[, delta])

Huber loss.

init_config()

Initialize model configuration

init_optimizer()

Initialize optimizer

init_state()

Initialize state from the environment.

is_constant_complexity()

Check for constant loss over a large number of steps

l2_loss(x, y)

L2 Loss

load(env, model_file)

Load policy_network from model_file

optimize_model([loss])

Perform one step of the optimization (on the policy network).

predict(state_string)

Predict the solution from the given state_string.

save(output_file)

Save the policy_network

set_env(env)

Set the environment

smooth_l1_loss(x, y)

Smooth L1 Loss

step(state[, training])

Take next step from current state

terminate_msg()

Log message about solver termination

train(num_episodes[, eval])

Train the model for the given number of episodes.

update_config(config)

Update configuration

update_info(key, value)

Update history info with given value for the given key

update_networks()

Soft update of the target network's weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\). policy_network.state_dict() returns the parameters of the policy network, and target_network.load_state_dict() loads these parameters into the target network.

Attributes

current_episode

Get current episode

device

Get device for training network

history

Get training history of policy_network

info

Get environment info

state_string

Get state string representation

steps_done

Get total number of steps done across all episodes

init_config()

Initialize model configuration

update_config(config)

Update configuration

abstract init_state()

Initialize state from the environment. This can be a vector or a graph representation.

abstract convert_state(state_string)

Convert state string to appropriate representation. This can be a vector or a graph representation.

abstract batch_states(states)

Convert states into a batch
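The three abstract methods above define what a concrete agent must supply. The following is a minimal sketch of that contract, assuming a hypothetical stand-in base class (ToyAgentBase) and a hypothetical character-code encoding; it is not the library's real vector or graph representation.

```python
from abc import ABC, abstractmethod


class ToyAgentBase(ABC):
    """Hypothetical stand-in for BaseAgent's abstract state interface."""

    def __init__(self, initial_state: str) -> None:
        self.state_string = initial_state

    @abstractmethod
    def init_state(self):
        """Return the representation of the initial state."""

    @abstractmethod
    def convert_state(self, state_string: str):
        """Convert a state string into the agent's representation."""

    @abstractmethod
    def batch_states(self, states):
        """Combine several converted states into one batch."""


class CharVectorAgent(ToyAgentBase):
    """Hypothetical subclass: encodes a state string as padded character codes."""

    def __init__(self, initial_state: str, length: int = 8) -> None:
        super().__init__(initial_state)
        self.length = length

    def init_state(self):
        return self.convert_state(self.state_string)

    def convert_state(self, state_string: str):
        codes = [ord(c) for c in state_string[: self.length]]
        # Right-pad with zeros so every state vector has the same length.
        return codes + [0] * (self.length - len(codes))

    def batch_states(self, states):
        # A batch here is simply a list of equal-length vectors.
        return [list(s) for s in states]


agent = CharVectorAgent("x+1=0")
print(agent.init_state())  # [120, 43, 49, 61, 48, 0, 0, 0]
```

Because the base class uses abc.abstractmethod, instantiating it directly raises TypeError until a subclass implements all three methods.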

init_optimizer()

Initialize optimizer

property steps_done

Get total number of steps done across all episodes

property current_episode

Get current episode

property device

Get device for training network

step(state, training=False)

Take next step from current state

Parameters:
  • state (str) – State string representation

  • training (bool) – Whether the step is part of training or inference. Determines whether to update the history.

Returns:

  • action (Tensor) – Action taken. Represented as a pytorch tensor.

  • next_state (Tensor) – Next state after action. Represented as a pytorch tensor or GraphEmbedding.

  • done (bool) – Whether a solution has been found or the state size limits have been exceeded.

  • info (dict) – Dictionary with loss, reward, and state information

choose_optimal_action(state)

Choose the action with the maximum expected reward, \(\max_a Q(s, a)\)

max(1) returns the largest value in each row together with the column index at which it was found. The index (the second element of the result) is therefore the action with the largest expected reward, and that action is chosen.
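A plain-Python sketch of what max(1) returns for a batch of Q values, where rows are states and columns are actions. It mirrors the PyTorch semantics described above using nested lists; row_max and choose_greedy_actions are illustrative helpers, not part of the library.

```python
def row_max(q_values):
    """Mimic tensor.max(1): per-row (max values, argmax indices)."""
    values, indices = [], []
    for row in q_values:
        best = max(range(len(row)), key=lambda i: row[i])
        values.append(row[best])
        indices.append(best)
    return values, indices


def choose_greedy_actions(q_values):
    """Greedy action per state: the index part (second element) of max(1)."""
    return row_max(q_values)[1]


# Two states, three actions each.
q = [[0.1, 0.7, 0.2],
     [0.9, 0.3, 0.4]]
print(choose_greedy_actions(q))  # [1, 0]
```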

choose_action(state, training=False)

Choose action based on given state. Depending on the training step, either the optimal action or a random action is chosen.
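The docstring does not say how the training step decides between the optimal and the random action. A common DQN scheme, assumed here rather than confirmed by the source, is epsilon-greedy exploration with an exponentially decaying epsilon. The names eps_start, eps_end and eps_decay are hypothetical.

```python
import math
import random


def epsilon(steps_done, eps_start=0.9, eps_end=0.05, eps_decay=1000):
    """Exploration rate that decays exponentially with the number of steps taken."""
    return eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)


def choose_action(q_row, steps_done, training=False, rng=random):
    """Epsilon-greedy choice for one state (hypothetical scheme)."""
    if training and rng.random() < epsilon(steps_done):
        # Explore: pick a random action index.
        return rng.randrange(len(q_row))
    # Exploit: pick the index of the largest Q value.
    return max(range(len(q_row)), key=lambda i: q_row[i])


print(choose_action([0.1, 0.7, 0.2], steps_done=0))  # 1 (always greedy outside training)
```

Early in training epsilon is close to eps_start, so most actions are random; as steps_done grows it approaches eps_end and the agent mostly exploits.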

compute_loss(state_action_values, expected_state_action_values)

Compute Huber loss

compute_batch_loss()

Compute loss for batch using the stored memory.

update_info(key, value)

Update history info with given value for the given key

choose_random_action()

Choose random action rather than the optimal action

optimize_model(loss=None)

Perform one step of the optimization (on the policy network).

compute_expected_Q(batch, next_state_values)

Compute the expected Q values
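The source does not state the formula. Assuming the standard DQN target, the expected Q value for a transition is the reward plus the discounted next-state value, \(r_t + \gamma V(s_{t+1})\), with \(V(s_{t+1}) = 0\) for final states. The discount factor name gamma and its default are hypothetical.

```python
def compute_expected_q(rewards, next_state_values, gamma=0.99):
    """Bellman targets r_t + gamma * V(s_{t+1}) for a batch of transitions."""
    return [r + gamma * v for r, v in zip(rewards, next_state_values)]


# The second transition is final, so its next-state value is 0.
print(compute_expected_q([1.0, -1.0], [2.0, 0.0], gamma=0.5))  # [2.0, -1.0]
```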

compute_V(batch)

Compute \(V(s_{t+1})\) for all next states. The expected action values for non_final_next_states are computed with the “older” target network, and the best value for each state is selected with max(1)[0].
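A plain-Python sketch of this step, assuming the usual pattern: final next states get a value of 0, and each non-final next state gets the maximum of its target-network Q values (the max(1)[0] part). compute_next_state_values and its arguments are hypothetical names.

```python
def compute_next_state_values(target_q_next, non_final_mask):
    """V(s_{t+1}) = max_a Q_target(s_{t+1}, a) for non-final states, 0 for final ones.

    target_q_next holds target-network Q values only for the non-final next
    states, in batch order; non_final_mask marks which entries have a next state.
    """
    values = []
    rows = iter(target_q_next)
    for is_non_final in non_final_mask:
        values.append(max(next(rows)) if is_non_final else 0.0)
    return values


mask = [True, False, True]
target_q = [[0.2, 0.5], [1.0, -0.3]]  # rows for the two non-final next states
print(compute_next_state_values(target_q, mask))  # [0.5, 0.0, 1.0]
```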

compute_Q(batch)

Compute \(Q(s_t, a)\). These are the policy network’s Q values for the action taken in each batch state.
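In a typical PyTorch DQN implementation this selection is done with Tensor.gather along the action dimension; whether BaseAgent uses gather is an assumption. The plain-Python sketch below performs the same selection on nested lists, with gather_q as a hypothetical helper.

```python
def gather_q(q_values, actions):
    """Pick Q(s_t, a_t) from each row, like q.gather(1, actions) in PyTorch."""
    return [row[a] for row, a in zip(q_values, actions)]


q = [[0.1, 0.7, 0.2],
     [0.9, 0.3, 0.4]]
print(gather_q(q, [2, 1]))  # [0.2, 0.3]
```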

train(num_episodes, eval=False)

Train the model for the given number of episodes.

The agent performs a soft update of the target network’s weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\), where \(\theta\) are the policy network’s weights and \(\theta'\) are the target network’s weights. This makes the target network’s weights gradually converge toward the policy network’s weights.

Parameters:
  • num_episodes (int) – Number of episodes to train for

  • eval (bool) – Whether to run in evaluation mode, without updating model weights

terminate_msg()

Log message about solver termination

update_networks()

Soft update of the target network’s weights, \(\theta' \leftarrow \tau \theta + (1 - \tau) \theta'\). policy_network.state_dict() returns the parameters of the policy network, and target_network.load_state_dict() loads these parameters into the target network.
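A minimal sketch of the soft-update rule, using plain dicts of floats as stand-ins for the state_dict() of each network. In PyTorch the blended dict would then be passed to target_network.load_state_dict(). soft_update and the default tau are hypothetical.

```python
def soft_update(policy_state, target_state, tau=0.005):
    """Return theta' <- tau * theta + (1 - tau) * theta' for each parameter."""
    return {
        name: tau * policy_state[name] + (1 - tau) * target_state[name]
        for name in target_state
    }


policy = {"w": 1.0, "b": 0.0}
target = {"w": 0.0, "b": 1.0}
print(soft_update(policy, target, tau=0.5))  # {'w': 0.5, 'b': 0.5}
```

A small tau makes the target network move slowly, which keeps the bootstrapped targets stable between updates.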

property history

Get training history of policy_network

predict(state_string)

Predict the solution from the given state_string.

is_constant_complexity()

Check for constant loss over a large number of steps
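The docstring does not define “constant” or how many steps count as “a large number”. One possible reading, sketched here with a hypothetical window parameter, is that the last window recorded values are all equal.

```python
def is_constant(values, window=100):
    """True if at least `window` values exist and the last `window` are all equal."""
    if len(values) < window:
        return False
    tail = values[-window:]
    return all(v == tail[0] for v in tail)


print(is_constant([3, 5, 4, 4, 4], window=3))  # True
print(is_constant([3, 5, 4, 4, 4], window=4))  # False
```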

save(output_file)

Save the policy_network

classmethod load(env, model_file)

Load policy_network from model_file

huber_loss(x, y, delta=1.0)

Huber loss. Also known as Smooth Mean Absolute Error, the Huber loss is widely used in regression and other optimization problems. It combines the properties of Mean Squared Error (MSE) and Mean Absolute Error (MAE): it is quadratic for errors no larger than \(\delta\) and linear beyond that, which makes it less sensitive to outliers than MSE.

\[ L(y, f(x)) = \begin{cases} \frac{1}{2} (y - f(x))^2, & \text{if } |y - f(x)| \leq \delta, \\ \delta |y - f(x)| - \frac{1}{2} \delta^2, & \text{otherwise.} \end{cases} \]

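A standalone sketch of the formula above, not the library method itself. It assumes x holds the predictions \(f(x)\), y holds the targets, and the per-element losses are averaged; the source does not say whether huber_loss averages or sums.

```python
def huber_loss(x, y, delta=1.0):
    """Mean Huber loss between predictions x and targets y (equal-length sequences)."""
    total = 0.0
    for pred, target in zip(x, y):
        err = abs(target - pred)
        if err <= delta:
            total += 0.5 * err ** 2  # quadratic region
        else:
            total += delta * err - 0.5 * delta ** 2  # linear region
    return total / len(x)


print(huber_loss([0.0, 0.0], [0.5, 3.0]))  # 1.3125
```

The first error (0.5) falls in the quadratic region and contributes 0.125; the second (3.0) falls in the linear region and contributes 2.5.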
l2_loss(x, y)

L2 Loss

smooth_l1_loss(x, y)

Smooth L1 Loss
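The source does not define l2_loss or smooth_l1_loss. A sketch under common conventions, both assumed here: L2 loss as mean squared error, and smooth L1 loss as defined for PyTorch's torch.nn.functional.smooth_l1_loss, which with beta=1 coincides with the Huber formula above at \(\delta = 1\). These are standalone functions, not the library's methods.

```python
def l2_loss(x, y):
    """Mean squared error (assumed definition of L2 loss)."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def smooth_l1_loss(x, y, beta=1.0):
    """Mean smooth L1 loss, following PyTorch's definition with parameter beta."""
    total = 0.0
    for a, b in zip(x, y):
        err = abs(a - b)
        # Quadratic below beta, linear (shifted by beta / 2) above it.
        total += 0.5 * err ** 2 / beta if err < beta else err - 0.5 * beta
    return total / len(x)


pred, target = [0.0, 0.0], [0.5, 3.0]
print(l2_loss(pred, target))         # 4.625
print(smooth_l1_loss(pred, target))  # 1.3125
```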

property state_string

Get state string representation

property info

Get environment info

get_env()

Get environment

set_env(env)

Set the environment