"""Collection of useful functions"""
import numpy as np
import random
import scipy.sparse as sp
from collections import namedtuple, deque
import torch
from torch_geometric.utils.convert import from_networkx
import networkx as nx
from networkx.readwrite import json_graph
from networkx.drawing.nx_pydot import graphviz_layout
from rl_equation_solver.utilities.operators import fraction
# Record type with the fields state, action, next_state and reward. It is
# what ReplayMemory.push stores and what Batch uses to transpose a sample.
Experience = namedtuple(
    'Experience', ['state', 'action', 'next_state', 'reward'])
class ReplayMemory:
    """Experience Replay buffer backed by a bounded ``deque``.

    Once ``capacity`` entries are stored, appending another one drops the
    oldest entry.
    """

    def __init__(self, capacity):
        """
        Parameters
        ----------
        capacity : int
            Maximum number of entries kept in the buffer (passed to
            ``deque`` as ``maxlen``)
        """
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build an Experience from the given fields and store it"""
        experience = Experience(*args)
        self.memory.append(experience)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored entries at random, without
        replacement, for training"""
        drawn = random.sample(self.memory, batch_size)
        return drawn

    def __len__(self):
        """Number of entries currently stored"""
        return len(self.memory)
class Batch:
    """Graph Embedding or state vector Batch"""

    def __init__(self):
        """Create an empty batch; every field starts as None and is filled
        in by ``__call__``"""
        self.experience = None
        self.non_final_mask = None
        self.non_final_next_states = None
        self.state_batch = None
        self.action_batch = None
        self.reward_batch = None

    @classmethod
    def __call__(cls, states, device):
        """Build a new batch from a collection of experiences.

        Parameters
        ----------
        states : iterable
            Experience-like 4-tuples (state, action, next_state, reward).
            States can be either instances of GraphEmbedding or
            np.ndarray; actions and rewards must be tensors accepted by
            ``torch.cat``
        device : torch.device | str
            Device on which the boolean mask is created

        Returns
        -------
        Batch
            Newly created batch with all fields populated
        """
        batch = cls()
        # Transpose: many experiences -> one Experience of per-field tuples
        batch.experience = Experience(*zip(*states))
        next_states = batch.experience.next_state

        # True wherever an experience has a next state
        batch.non_final_mask = torch.tensor(
            tuple(state is not None for state in next_states),
            device=device, dtype=torch.bool)
        batch.non_final_next_states = [state for state in next_states
                                       if state is not None]
        batch.state_batch = [state for state in batch.experience.state
                             if state is not None]
        batch.action_batch = torch.cat(batch.experience.action)
        batch.reward_batch = torch.cat(batch.experience.reward)
        return batch
class Id:
    """Helper that hands out consecutive node numbers."""

    # Last number handed out; -1 means nothing was issued since the reset
    counter = -1

    @classmethod
    def get(cls):
        """
        Advance the counter by one and return the new node number
        """
        next_id = cls.counter + 1
        cls.counter = next_id
        return next_id

    @classmethod
    def reset(cls):
        """Restart numbering so that the next ``get`` returns 0"""
        cls.counter = -1
class Node:
    """One vertex of the expression graph: an operation or an atomic
    argument."""

    def __init__(self, label, expr_id):
        """
        Parameters
        ----------
        label : str
            Text identifying the node; stored as ``name``
        expr_id : int
            Node number; stored as ``id``
        """
        self.id, self.name = expr_id, label

    def __repr__(self):
        """The node is represented by its name"""
        return self.name
class VectorEmbedding:
    """Vector embedding class for embedding a feature vector in a vector of
    fixed size"""

    def __init__(self, vector, n_observations, device):
        """
        Parameters
        ----------
        vector : array-like
            Feature vector to embed
        n_observations : int
            Length handed to ``pad_array`` when padding the vector
        device : torch.device | str
            Device on which the resulting tensor is created
        """
        padded = pad_array(vector, n_observations)
        # Stored as float32 regardless of the input dtype
        self.vector = torch.tensor(padded, dtype=torch.float32,
                                   device=device)
class GraphEmbedding:
    """Graph embedding class for embedding node features in matrices of
    fixed sizes.

    Attributes set by the constructor:

    - ``adj``: ``edge_index`` of the converted graph, on ``device``
    - ``_x``: one-hot array returned by ``encode_onehot``
    - ``x``: float32 tensor of shape (n_observations, n_features) holding
      the top-left part of ``_x``, zero elsewhere
    - ``onehot_values``: float32 tensor of the distinct node values with a
      leading dimension of size one
    """

    def __init__(self, graph, n_observations, n_features, device):
        """
        Parameters
        ----------
        graph : networkx.Graph
            Graph converted with ``from_networkx``; the converted object
            must expose ``x`` and ``edge_index``
        n_observations : int
            Number of rows of the embedded feature matrix
        n_features : int
            Number of columns of the embedded feature matrix
        device : torch.device | str
            Device on which the tensors are created
        """
        data = from_networkx(graph)
        node_values = np.array(data.x.to(device).cpu())
        self.adj = data.edge_index.to(device)

        # One-hot encode the node values and keep the distinct values
        self._x, classes = encode_onehot(node_values)
        distinct_values = np.array(list(classes.keys()))
        distinct_values = pad_array(distinct_values, n_features)
        self.onehot_values = torch.tensor(
            distinct_values, device=device, dtype=torch.float32).unsqueeze(0)

        # embed in larger constant size matrices
        n_rows = min(self._x.shape[0], n_observations)
        n_cols = min(self._x.shape[1], n_features)
        embedded = np.zeros((n_observations, n_features))
        embedded[:n_rows, :n_cols] = self._x[:n_rows, :n_cols]
        self.x = torch.tensor(embedded, device=device, dtype=torch.float32)
def graph_walk(parent, expr, node_list, link_list):
    """
    Walk over the expression tree recursively creating nodes and links.

    Node numbering is restarted whenever the parent is the node named
    'Root'. Both lists are extended in place; nothing is returned.

    Parameters
    ----------
    parent : Node
        Parent node
    expr : object
        Expression exposing ``is_Atom`` and ``args`` (presumably a SymPy
        expression)
    node_list : list
        List of node dictionaries with 'id' and 'name' keys
    link_list : list
        List of link dictionaries with 'source' and 'target' keys
    """
    if parent.name == 'Root':
        Id.reset()

    is_atom = expr.is_Atom
    # Atoms are labelled by their text, operations by their class name
    label = str(expr) if is_atom else str(type(expr).__name__)
    node = Node(label, Id.get())
    node_list.append({"id": node.id, "name": node.name})
    link_list.append({"source": parent.id, "target": node.id})

    if not is_atom:
        for arg in expr.args:
            graph_walk(node, arg, node_list, link_list)
def pad_array(arr, length):
    """
    Fit a one dimensional array to the given length.

    Shorter input is padded with zeros on the right; longer input is
    truncated, so the result always has exactly ``length`` entries.

    Parameters
    ----------
    arr : array-like
        One dimensional sequence of numeric values
    length : int
        Length of the returned array

    Returns
    -------
    np.ndarray
        New array with ``length`` entries (dtype of ``np.zeros``)
    """
    n_copied = min(length, len(arr))
    fitted = np.zeros(length)
    fitted[:n_copied] = arr[:n_copied]
    # Always hand back the fixed-size copy; returning ``arr`` itself would
    # leak over-long input to callers that rely on a constant size.
    return fitted
def to_vec(expr, feature_dict, state_dim=4096):
    """
    Get state vector for given expression.

    The vector is the node features followed by the flattened adjacency
    matrix of the expression graph. The two parts are passed through
    ``pad_array`` with a quarter and three quarters of ``state_dim``
    respectively.

    Parameters
    ----------
    expr : object
        Expression accepted by ``get_json_graph``
    feature_dict : dict
        Dictionary mapping feature names to values
    state_dim : int
        Max length of state vector

    Returns
    -------
    np.ndarray
        State vector array (float32)
    """
    n_node_entries = int(0.25 * state_dim)
    n_edge_entries = int(0.75 * state_dim)

    graph = get_json_graph(expr)
    features = pad_array(get_node_features(graph, feature_dict),
                         n_node_entries)
    adjacency = pad_array(nx.to_numpy_array(graph).flatten(),
                          n_edge_entries)
    return np.concatenate([features, adjacency], dtype=np.float32)
def get_json_graph(expr):
    """
    Build a directed graph of the internal representation of a SymPy
    expression. Nodes only carry a 'name' attribute; no other meta data is
    added yet.

    Parameters
    ----------
    expr : object
        Expression accepted by ``graph_walk``

    Returns
    -------
    networkx.Graph
        Directed graph without the artificial root node
    """
    nodes, links = [], []
    graph_walk(Node("Root", -1), expr, nodes, links)

    names = {entry['id']: entry['name'] for entry in nodes}
    # Names are attached after construction, so only ids are handed over
    payload = {"nodes": [{"id": entry['id']} for entry in nodes],
               "links": links}
    graph = json_graph.node_link_graph(payload, directed=True,
                                       multigraph=False)

    # The root (-1) is never listed as a node; it only enters the graph as
    # the source of a link and is dropped again here.
    graph.remove_node(-1)
    for node_id in graph.nodes:
        graph.nodes[node_id]['name'] = names.get(node_id, 'Root')
    return graph
def to_graph(expr, feature_dict):
    """
    Build the expression graph and attach a feature to every node.

    Each node receives its feature under the 'x' attribute, in addition to
    the 'name' attribute set by ``get_json_graph``.

    Parameters
    ----------
    expr : object
        Expression accepted by ``get_json_graph``
    feature_dict : dict
        Dictionary mapping feature names to values

    Returns
    -------
    networkx.Graph
    """
    graph = get_json_graph(expr)
    features = get_node_features(graph, feature_dict)
    # Features come back in node order, one per node
    for node, feature in zip(list(graph.nodes), features):
        graph.nodes[node]['x'] = feature
    return graph
def parse_node_features(node_features, feature_dict):
    """Translate node names into feature values.

    Names present in ``feature_dict`` are mapped to the integer value of
    their entry; every other name is handed to ``fraction`` for parsing.

    Parameters
    ----------
    node_features : iterable
        Node names
    feature_dict : dict
        Dictionary mapping feature names to values

    Returns
    -------
    list
        One parsed value per input name, in input order
    """
    return [int(feature_dict[name]) if name in feature_dict
            else fraction(name)
            for name in node_features]
def get_node_features(graph, feature_dict):
    """Get node features from the feature dictionary, e.g. we can map the
    operations and terms to integers: {add: 0, sub: 1, .. }

    Parameters
    ----------
    graph : networkx.Graph
        Graph whose nodes carry a 'name' attribute
    feature_dict : dict
        Dictionary mapping feature names to values

    Returns
    -------
    np.ndarray
        Parsed feature of every node, in node order
    """
    names = list(get_node_labels(graph).values())
    return np.array(parse_node_features(names, feature_dict))
def get_node_labels(graph):
    """Get node labels from graph. Labels must be stored as node attributes
    under graph.nodes[index]['name']; a node without that attribute raises
    KeyError.

    Parameters
    ----------
    graph : networkx.graph
        Networkx graph object with node['name'] attributes

    Returns
    -------
    dict
        Mapping of node index to node name, in node order
    """
    labels = {}
    for node in graph.nodes:
        labels[node] = graph.nodes[node]['name']
    return labels
def plot_state_as_graph(expr):
    """
    Make a graph plot of the internal representation of SymPy expression.

    Nodes are drawn as labelled boxes, positioned by ``graphviz_layout``
    using the ``dot`` program.

    Parameters
    ----------
    expr : object
        Expression accepted by ``get_json_graph``
    """
    # Only node names are drawn, so the plain graph is sufficient.
    # ``to_graph`` additionally requires a feature dictionary that this
    # function is never given.
    graph = get_json_graph(expr)
    labels = get_node_labels(graph)
    pos = graphviz_layout(graph, prog="dot")
    nx.draw(graph.to_directed(), pos, labels=labels,
            node_shape="s", node_color="none",
            bbox={'facecolor': 'skyblue', 'edgecolor': 'black',
                  'boxstyle': 'round,pad=0.2'})
def normalize(mx):
    """Row-normalize sparse matrix.

    Every row is divided by its sum. Rows summing to zero are left as all
    zeros.

    Parameters
    ----------
    mx : scipy.sparse matrix
        Matrix to normalize; integer and boolean matrices are supported

    Returns
    -------
    scipy.sparse matrix
        Row-normalized matrix
    """
    row_sums = np.asarray(mx.sum(1)).ravel()
    # Integer arrays cannot be raised to a negative power
    if row_sums.dtype.kind in 'biu':
        row_sums = row_sums.astype(np.float64)
    # Zero rows produce inf here, which is replaced below; silence the
    # divide-by-zero warning for that expected case.
    with np.errstate(divide='ignore'):
        inverse = np.power(row_sums, -1)
    inverse[np.isinf(inverse)] = 0.
    return sp.diags(inverse).dot(mx)
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    Parameters
    ----------
    sparse_mx : scipy.sparse matrix
        Matrix to convert; values are cast to float32

    Returns
    -------
    torch.Tensor
        Sparse COO tensor of dtype float32 with the shape of the input
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row and column indices stacked into the (2, nnz) int64 layout
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor is deprecated in favour of this factory
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
def build_adjacency_matrix_custom(graph):
    """Build a normalized adjacency matrix from the graph edges.

    The adjacency matrix is made symmetric, self loops are added and the
    rows are normalized before conversion to a torch sparse tensor.

    Parameters
    ----------
    graph : networkx.Graph
        Graph whose edge endpoints are used directly as row and column
        indices (assumes nodes are numbered 0..n-1 -- TODO confirm)

    Returns
    -------
    torch.Tensor
        Sparse tensor as returned by ``sparse_mx_to_torch_sparse_tensor``
    """
    # reshape keeps the (n_edges, 2) layout even for a graph without edges,
    # where a plain conversion would give a 1-D empty array
    edges = np.array(list(graph.edges), dtype=np.int64).reshape(-1, 2)
    n_nodes = len(graph.nodes)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(n_nodes, n_nodes),
                        dtype=np.float32)
    # Symmetrize: wherever the transposed entry is larger, take it instead
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    # Add self loops, then row-normalize
    adj = normalize(adj + sp.eye(adj.shape[0]))
    return sparse_mx_to_torch_sparse_tensor(adj)
def build_adjacency_matrix(graph):
    """Return the adjacency matrix of the graph as computed by
    ``networkx.adjacency_matrix``"""
    adjacency = nx.adjacency_matrix(graph)
    return adjacency
def encode_onehot(labels):
    """Onehot encoding

    Parameters
    ----------
    labels : iterable
        Hashable labels, one per sample

    Returns
    -------
    labels_onehot : np.ndarray
        int32 array with one one-hot row per input label
    classes_dict : dict
        Mapping of every distinct label to its one-hot row. The column
        assigned to a label follows the iteration order of a ``set``
    """
    distinct = set(labels)
    n_classes = len(distinct)
    classes_dict = {}
    for index, label in enumerate(distinct):
        classes_dict[label] = np.identity(n_classes)[index, :]
    encoded = [classes_dict.get(label) for label in labels]
    return np.array(encoded, dtype=np.int32), classes_dict