MessagePassing in PyTorch Geometric
Convolution on a graph is usually called neighborhood aggregation or message passing. Let $\mathbf x^{(k-1)}_i \in \mathbb{R}^{F}$ denote the features of node $i$ in layer $(k-1)$, and let $\mathbf e_{j,i}$ denote the edge features from node $j$ to node $i$.
In a GNN, message passing can be expressed as

$$\mathbf x_{i}^{(k)} = \gamma^{(k)} \left(\mathbf x_{i}^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)} \left(\mathbf x_{i}^{(k-1)}, \mathbf x_{j}^{(k-1)}, \mathbf e_{j,i} \right) \right)$$
where $\square$ denotes a differentiable, permutation-invariant function such as sum, mean, or max, and $\gamma$ and $\phi$ denote differentiable functions.
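To make the formula concrete, here is a minimal, dependency-free sketch of one message passing layer. It is not PyG code: plain Python lists stand in for tensors, the name `message_passing_step` is hypothetical, $\square$ is fixed to a sum, and the edge features $\mathbf e_{j,i}$ are left out for brevity.

```python
def message_passing_step(x, edges, phi, gamma):
    """One layer of message passing with sum aggregation.

    x     : list of node feature vectors (one list of floats per node)
    edges : list of (j, i) pairs, meaning a message flows from node j to node i
    phi   : message function phi(x_i, x_j) -> list of floats
    gamma : update function gamma(x_i, aggregated) -> list of floats
    """
    num_nodes, dim = len(x), len(x[0])
    # Assumption: phi returns vectors of the same width as the inputs.
    aggregated = [[0.0] * dim for _ in range(num_nodes)]
    for j, i in edges:
        msg = phi(x[i], x[j])
        # Sum aggregation: the order of the edges does not matter.
        aggregated[i] = [a + m for a, m in zip(aggregated[i], msg)]
    return [gamma(x[i], aggregated[i]) for i in range(num_nodes)]


x = [[1.0], [2.0], [4.0]]
edges = [(0, 1), (1, 2), (2, 1)]  # 0 -> 1, 1 -> 2, 2 -> 1

out = message_passing_step(
    x,
    edges,
    phi=lambda x_i, x_j: x_j,  # the message is the neighbor's feature
    gamma=lambda x_i, agg: [a + b for a, b in zip(x_i, agg)],  # residual update
)
print(out)  # [[1.0], [7.0], [6.0]]
```

Node 1 receives messages from nodes 0 and 2 (1 + 4 = 5) and adds them to its own feature (2), giving 7.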
In PyTorch Geometric (PyG), all convolution operators are derived from the MessagePassing class. Understanding MessagePassing helps us understand how message passing is computed in PyG and how to write custom convolutions. For a custom convolution, the user only needs to define the message function $\phi$ as message(), the node update function $\gamma$ as update(), and the aggregation scheme aggr='add', aggr='mean' or aggr='max'. The relevant functions are described as follows:
- MessagePassing(aggr='add', flow='source_to_target', node_dim=-2): defines the aggregation scheme ('add', 'mean' or 'max') and the direction of message flow ('source_to_target' or 'target_to_source'). In PyG, the central node is the target and the neighboring node is the source. node_dim is the dimension along which messages are aggregated.
- MessagePassing.propagate(edge_index, size=None, **kwargs): this function takes the edge indices edge_index and any additional data, performs message passing, and updates the node embeddings.
- MessagePassing.message(...): this function computes the node messages and corresponds to $\phi$ in the formula. If flow='source_to_target', messages are sent along each edge from edge_index[0] to edge_index[1]; if flow='target_to_source', the direction is reversed and messages are sent from edge_index[1] to edge_index[0]. In both cases $i$ is the central node that aggregates the messages and $j$ is the neighboring node that sends them. The node that an argument refers to is determined by the suffix of its variable name: for example, variables holding central-node embeddings end in _i (such as x_i), and variables holding neighboring-node embeddings end in _j (such as x_j).
- MessagePassing.update(aggr_out, ...): this function is the node embedding update function $\gamma$. Its first input is the result computed by the aggregation function MessagePassing.aggregate.
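The effect of flow on the `_i` and `_j` suffixes can be sketched in a few lines. This is an illustration rather than PyG code: `split_edge_index` is a hypothetical helper, and plain lists stand in for the rows of edge_index.

```python
def split_edge_index(edge_index, flow="source_to_target"):
    """Return (index_i, index_j): the aggregating nodes and the sending nodes."""
    if flow not in ("source_to_target", "target_to_source"):
        raise ValueError(f"Unknown flow: {flow!r}")
    # With source_to_target, row 1 holds the central nodes i and row 0 the
    # neighbors j; with target_to_source the two rows swap roles.
    i, j = (1, 0) if flow == "source_to_target" else (0, 1)
    return edge_index[i], edge_index[j]


edge_index = [[0, 0, 2], [1, 2, 3]]

index_i, index_j = split_edge_index(edge_index)
print(index_i, index_j)  # [1, 2, 3] [0, 0, 2]

index_i, index_j = split_edge_index(edge_index, flow="target_to_source")
print(index_i, index_j)  # [0, 0, 2] [1, 2, 3]
```

An argument named x_j would then be gathered with index_j, and the aggregation would group messages by index_i.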
In order to better understand the calculation process of MessagePassing in PyG, let's analyze the source code.
```python
class MessagePassing(torch.nn.Module):

    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):
        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')

        # Support for GNNExplainer.
        self.__explain__ = False
        self.__edge_mask__ = None
```
In the initializer, MessagePassing creates an Inspector. This class inspects the parameters of each function and stores them in the dictionary Inspector.params. If the parameters of message are x_i and x_j, then Inspector.params['message'] = {'x_i': Parameter, 'x_j': Parameter} (note: this is for illustration only; the actual type of Inspector.params['message'] is OrderedDict). Inspector.implements checks whether a function has been implemented.
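The idea behind the inspector can be sketched with the standard library alone. The helpers `inspect_params` and `distribute` below are hypothetical stand-ins written for this illustration; PyG's actual Inspector class may differ in its details.

```python
import inspect
from collections import OrderedDict


def inspect_params(func):
    """Record the parameters of func by name, in declaration order."""
    return OrderedDict(inspect.signature(func).parameters)


def distribute(params, collected):
    """Pick from `collected` exactly the arguments that the function asks for."""
    out = {}
    for name, param in params.items():
        if name in collected:
            out[name] = collected[name]
        elif param.default is not inspect.Parameter.empty:
            out[name] = param.default  # fall back to the declared default
        else:
            raise TypeError(f"Required parameter {name!r} is missing")
    return out


def message(x_j, norm, scale=1.0):  # a toy message function
    return None


params = inspect_params(message)
print(list(params))  # ['x_j', 'norm', 'scale']

# x_i is available but message does not ask for it, so it is not passed on.
kwargs = distribute(params, {"x_i": "X_I", "x_j": "X_J", "norm": "NORM"})
print(kwargs)  # {'x_j': 'X_J', 'norm': 'NORM', 'scale': 1.0}
```

Matching arguments by name is what lets a custom layer declare only the inputs it needs in message, aggregate and update.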
The core of MessagePassing is the propagate function. Assuming that edge_index is a torch.LongTensor and that messages are sent from edge_index[0] to edge_index[1], the implementation is as follows:
```python
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    # To keep things simple, the case where edge_index is a SparseTensor is
    # not discussed here; see the PyG source code if you are interested.
    size = self.__check_input__(edge_index, size)
    coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

    msg_kwargs = self.inspector.distribute('message', coll_dict)
    out = self.message(**msg_kwargs)

    aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
    out = self.aggregate(out, **aggr_kwargs)

    update_kwargs = self.inspector.distribute('update', coll_dict)
    return self.update(out, **update_kwargs)
```
This code first checks the number of nodes and collects the user-defined inputs, then executes the message, aggregate and update functions in turn. A custom graph convolution overrides message and update; this is explained later using GCN as an example. First, let's take a look at the implementation of aggregate:
```python
def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None) -> Tensor:
    if ptr is not None:
        ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
        return segment_csr(inputs, ptr, reduce=self.aggr)
    else:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggr)
```
The ptr variable is used when edge_index is a SparseTensor; it is ignored here. inputs holds the messages computed by message, and index holds the indices of the nodes to be updated, which is actually edge_index_i. The aggregation itself is carried out by the scatter function; see the torch_scatter documentation for its implementation.
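What scatter does can be illustrated without any tensor library. The function below is a simplified sketch, not the torch_scatter implementation: it works on lists of rows, always reduces along the first dimension, and assumes that nodes receiving no message get zeros.

```python
def scatter(inputs, index, dim_size, reduce="add"):
    """Reduce the rows of `inputs` that share the same target index."""
    if reduce not in ("add", "mean", "max"):
        raise ValueError(f"Unknown reduce: {reduce!r}")
    width = len(inputs[0])

    # Group the message rows by the node they are addressed to.
    groups = [[] for _ in range(dim_size)]
    for row, target in zip(inputs, index):
        groups[target].append(row)

    out = []
    for rows in groups:
        if not rows:
            out.append([0.0] * width)  # assumption: no messages -> zeros
        elif reduce == "add":
            out.append([sum(col) for col in zip(*rows)])
        elif reduce == "mean":
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        else:  # "max"
            out.append([max(col) for col in zip(*rows)])
    return out


messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one row per edge
index = [1, 1, 2]                                # target node of each edge

print(scatter(messages, index, dim_size=3, reduce="add"))   # [[0.0, 0.0], [4.0, 6.0], [5.0, 6.0]]
print(scatter(messages, index, dim_size=3, reduce="mean"))  # [[0.0, 0.0], [2.0, 3.0], [5.0, 6.0]]
print(scatter(messages, index, dim_size=3, reduce="max"))   # [[0.0, 0.0], [3.0, 4.0], [5.0, 6.0]]
```

Here index plays the role of edge_index_i: the first two messages are addressed to node 1 and the third to node 2.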
Taking GCN as an example, let's walk through the computation of MessagePassing. GCN is computed as follows:

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),$$
The actual computation can be divided into the following steps:
- Add self-loops to the adjacency matrix, that is, set the elements on the diagonal of the adjacency matrix to 1
- Linearly transform the node feature matrix
- Compute the normalization coefficient of each edge, that is, the inverse square root of the product of the node degrees, $1 / \sqrt{\deg(i) \cdot \deg(j)}$
- Normalize the node features
- Aggregate (sum) the node features to obtain the new node embeddings
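The steps above can also be summarized in matrix form. As a sketch, assume an undirected graph with adjacency matrix $\mathbf{A}$ and stack the node features as the rows of $\mathbf{X}^{(k-1)}$. The symbols $\hat{\mathbf{A}}$ and $\hat{\mathbf{D}}$ are introduced here for illustration: $\hat{\mathbf{A}} = \mathbf{A} + \mathbf{I}$ is the adjacency matrix with self-loops, and $\hat{\mathbf{D}}$ is its diagonal degree matrix, so $\deg(i) = \hat{D}_{ii}$ already counts the self-loop. Up to the transpose convention for $\mathbf{\Theta}$, the layer then reads

$$\mathbf{X}^{(k)} = \hat{\mathbf{D}}^{-1/2} \, \hat{\mathbf{A}} \, \hat{\mathbf{D}}^{-1/2} \, \mathbf{X}^{(k-1)} \mathbf{\Theta}, \qquad \left( \hat{\mathbf{D}}^{-1/2} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-1/2} \right)_{ij} = \frac{\hat{A}_{ij}}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}}$$

Row $i$ of the product sums the transformed features of every $j \in \mathcal{N}(i) \cup \{ i \}$, each weighted by $1 / (\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})$, which matches the node-wise formula.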
The code is as follows:

```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```
In the forward function, the first step is to add self-loops to the edges. Set up the inputs as follows:
```python
edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long)
x = torch.rand((4, 3))
conv = GCNConv(3, 8)
```
Note that the default message flow is source_to_target: edge_index[0] holds the source nodes (x_j) and edge_index[1] holds the target nodes (x_i). In GCN, the first step is to add self-loops to the nodes. The values of edge_index before and after add_self_loops, together with the resulting norm, are as follows:
```python
# before add_self_loops
# edge_index = tensor([[0, 0, 2],
#                      [1, 2, 3]])

# after add_self_loops
# edge_index = tensor([[0, 0, 2, 0, 1, 2, 3],
#                      [1, 2, 3, 0, 1, 2, 3]])
# norm = tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000])
```
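These numbers can be reproduced by hand. The sketch below mirrors steps 1 and 3 of the GCN layer with plain Python lists; `add_self_loops_to_lists` and `gcn_norm` are hypothetical helpers written for this check, not PyG functions.

```python
def add_self_loops_to_lists(row, col, num_nodes):
    """Append one edge (n, n) per node to the edge list."""
    loops = list(range(num_nodes))
    return row + loops, col + loops


def gcn_norm(row, col, num_nodes):
    """Return 1 / sqrt(deg(j) * deg(i)) for every edge j -> i."""
    deg = [0] * num_nodes
    for i in col:  # the degree is counted over the target nodes
        deg[i] += 1
    deg_inv_sqrt = [d ** -0.5 if d > 0 else 0.0 for d in deg]
    return [deg_inv_sqrt[j] * deg_inv_sqrt[i] for j, i in zip(row, col)]


row, col = add_self_loops_to_lists([0, 0, 2], [1, 2, 3], num_nodes=4)
print(row)  # [0, 0, 2, 0, 1, 2, 3]
print(col)  # [1, 2, 3, 0, 1, 2, 3]

norm = gcn_norm(row, col, num_nodes=4)
rounded = [round(v, 4) for v in norm]
print(rounded)  # [0.7071, 0.7071, 0.5, 1.0, 0.5, 0.5, 0.5]
```

Node 0 appears as a target only in its own self-loop, so its degree is 1, while nodes 1, 2 and 3 each have degree 2. This is why the edge 0 -> 1 gets 1 / sqrt(1 * 2) ≈ 0.7071 and the edge 2 -> 3 gets 1 / sqrt(2 * 2) = 0.5.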
The inputs passed to propagate here are edge_index, x and norm. edge_index is a required parameter of propagate, while x and norm are user-defined. In __collect__, the inputs required by message are collected according to their variable names. In GCN, norm is passed through unchanged, while x is mapped to x_j, and its value also changes after it passes through the __lift__ function. The __lift__ function is as follows:
```python
def __lift__(self, src, edge_index, dim):
    if isinstance(edge_index, Tensor):
        index = edge_index[dim]
        return src.index_select(self.node_dim, index)
```
In this example, the input features have shape [4, 8]; after __lift__, the node features have shape [7, 8]. Once message has been computed, aggregate and update can be executed.
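The lift, message and aggregate steps for this example can be traced with plain lists. This is a sketch under simplifying assumptions: the features are made-up values (row n is filled with the number n) rather than the output of the linear layer, and norm uses the rounded values printed earlier.

```python
# Transformed node features, shape [4, 8]; row n is filled with the value n.
x = [[float(n)] * 8 for n in range(4)]

row = [0, 0, 2, 0, 1, 2, 3]  # source nodes j (edge_index[0] with self-loops)
col = [1, 2, 3, 0, 1, 2, 3]  # target nodes i (edge_index[1] with self-loops)
norm = [0.7071, 0.7071, 0.5, 1.0, 0.5, 0.5, 0.5]

# Lift: gather one feature row per edge, so the shape goes from [4, 8] to [7, 8].
x_j = [x[j] for j in row]
print(len(x_j), len(x_j[0]))  # 7 8

# Message: scale every lifted row by the normalization coefficient of its edge.
messages = [[w * v for v in feat] for w, feat in zip(norm, x_j)]

# Aggregate: sum the messages per target node, back to shape [4, 8].
out = [[0.0] * 8 for _ in range(len(x))]
for target, msg in zip(col, messages):
    out[target] = [a + m for a, m in zip(out[target], msg)]

print([r[0] for r in out])  # [0.0, 0.5, 1.0, 2.5]
```

Node 3, for instance, receives 0.5 * 2 from node 2 and 0.5 * 3 from its own self-loop, giving 2.5 in every feature column.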