Source code for rl_equation_solver.agent.networks

"""Networks for agent policies"""
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


[docs]class DQN(nn.Module): """Simple MLP network.""" def __init__(self, n_observations, n_actions, hidden_size): """ Parameters ---------- n_observations: int observation/state size of the environment n_actions : int number of discrete actions available in the environment hidden_size : int size of hidden layers """ super().__init__() self.layer1 = nn.Linear(n_observations, hidden_size) self.layer2 = nn.Linear(hidden_size, hidden_size) self.layer3 = nn.Linear(hidden_size, n_actions)
[docs] def forward(self, x): """ Forward pass for given state x """ x = F.relu(self.layer1(x)) x = F.relu(self.layer2(x)) return self.layer3(x)
[docs]class GCN(nn.Module): """Graph Convolution Network""" def __init__(self, n_observations, n_actions, hidden_size, dropout=0.1): """ Parameters ---------- n_observations: int observation/state size of the environment n_actions : int number of discrete actions available in the environment hidden_size : int size of hidden layers dropout : float dropout rate """ super().__init__() self.n_observations = n_observations self.n_actions = n_actions self.hidden_size = hidden_size self.layer1 = GCNConv(n_observations, hidden_size, normalize=True, cached=True) self.layer2 = GCNConv(hidden_size, n_actions, normalize=True, cached=True) self.dropout = dropout def _forward(self, graph): """Forward pass for a given state graph""" x = graph.x.T edge_index = graph.adj x = F.dropout(x, p=self.dropout, training=self.training) x = F.relu(self.layer1(x, edge_index)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.layer2(x, edge_index) x = torch.matmul(graph.onehot_values, x) return x
[docs] def forward(self, graph): """Forward pass for a given state graph or tuple of graphs""" if isinstance(graph, (tuple, list)): return torch.cat([self._forward(G) for G in graph]) else: return self._forward(graph)
[docs]class LSTM(nn.Module): """LSTM network""" def __init__(self, n_observations, n_actions, hidden_size, n_features): super().__init__() self.n_actions = n_actions self.n_features = n_features self.n_observations = n_observations self.hidden_size = hidden_size self.lstm = nn.LSTM(input_size=n_observations, hidden_size=hidden_size, num_layers=n_features, batch_first=True) self.linear = nn.Linear(hidden_size, n_actions)
[docs] def forward(self, x): """Forward pass on state x""" x, _ = self.lstm(x) x = self.linear(x) return x