Road to ML Engineer #34 - MPNNs & Graph Transformers

Last Edited: 1/15/2025

The blog post discusses message passing neural networks, graph transformers, and more.

ML

In the past two articles, we have been increasing the complexity of the model from GCN to GAT. In this article, we will discuss even more complex models such as message passing neural networks and graph transformers.

Message Passing Neural Networks

The models we have covered so far use the node features $X$ and learn the weights $W$ (plus the attention vector $\overrightharpoon{a}$ for GAT). However, some graphs also come with edge features $e$, which we can use to make the model more expressive. Gilmer, J. et al. (2017) introduced message passing neural networks (MPNNs), which take the relevant node and edge features to produce latent node representations. This lets the model tackle tasks on graphs whose edges encode, for example, different chemical bonds and spatial distances.

$$m_i^{t+1} = \sum_{j \in \{i\} \cup N_i} M_t(h_i^t, h_j^t, e_{ij})$$
$$h_i^{t+1} = U_t(h_i^t, m_i^{t+1})$$

The above is the message passing phase of the MPNN. $M_t$ is the message function and $U_t$ is the update function, both of which can be tailored to the task at hand. For the message function, Gilmer, J. et al. (2017) suggest an edge network, $\sigma(We_{ij})h_j$, where the output of $\sigma(We_{ij})$ is reshaped into a matrix that multiplies $h_j$, and a simple feedforward layer, $\sigma(W[h_i || h_j || e_{ij}])$. The update function can be another feedforward layer, $\sigma(W[h_i || m_i^{t+1}])$. In addition to the message passing phase, the original MPNN includes a readout phase, which applies a permutation-invariant function $R$ to the latent node representations to produce a graph-level representation.
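To make the two phases concrete, here is a minimal pure-Python sketch of one message passing step followed by a sum readout on a toy graph with scalar node and edge features. The weights inside `message` and `update` are hypothetical hand-picked numbers standing in for learned parameters, and each "feedforward layer" is reduced to a weighted sum plus ReLU, so this illustrates the equations rather than reproducing the paper's implementation.

```python
def relu(z):
    return max(0.0, z)

def message(h_i, h_j, e_ij, w=(0.5, 1.0, 2.0)):
    # M_t as a tiny feedforward layer: relu(W [h_i || h_j || e_ij]).
    # The weights w are hypothetical placeholders for learned parameters.
    return relu(w[0] * h_i + w[1] * h_j + w[2] * e_ij)

def update(h_i, m_i, w=(1.0, 0.5)):
    # U_t as a tiny feedforward layer: relu(W [h_i || m_i]).
    return relu(w[0] * h_i + w[1] * m_i)

def mpnn_step(h, edges):
    # h: list of scalar node features.
    # edges: dict mapping (i, j) -> e_ij. Self-loops (i, i) are included so
    # that the sum for node i runs over j in N_i and over i itself.
    m = [0.0] * len(h)
    for (i, j), e_ij in edges.items():
        m[i] += message(h[i], h[j], e_ij)
    return [update(h[i], m[i]) for i in range(len(h))]

def readout(h):
    # Permutation-invariant readout R: sum over all nodes.
    return sum(h)

# Toy graph: path 0 - 1 - 2 with self-loops (self-loop edge feature 0.0)
h = [1.0, 2.0, 0.0]
edges = {(0, 0): 0.0, (1, 1): 0.0, (2, 2): 0.0,
         (0, 1): 1.0, (1, 0): 1.0, (1, 2): 0.5, (2, 1): 0.5}
h_next = mpnn_step(h, edges)
print(h_next, readout(h_next))  # [4.0, 6.5, 1.5] 12.0
```

Relabeling the nodes permutes `h_next` in the same way but leaves `readout(h_next)` unchanged, which is the permutation invariance required of $R$.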

Code Implementation

Below is an example TensorFlow implementation of an MPNN layer in which both $M$ and $U$ are simple feedforward layers operating on concatenated features.

import tensorflow as tf
from tensorflow.keras import layers

class MPNNLayer(layers.Layer):
    def __init__(self, d_in, d_edge, d_latent, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Message function M: feedforward layer on [h_i || h_j || e_ij]
        self.W_m = self.add_weight(shape=(2 * d_in + d_edge, d_latent),
                                   initializer='glorot_uniform',
                                   trainable=True)
        # Update function U: feedforward layer on [h_i || m_i]
        self.W_u = self.add_weight(shape=(d_in + d_latent, d_out),
                                   initializer='glorot_uniform',
                                   trainable=True)

    def call(self, x, A, adj):
        # x: (batch_size, n, d_in), A: (batch_size, n, n, d_edge) edge features
        # adj: (batch_size, n, n), 1.0 where j is in N_i or j == i, else 0.0
        n = tf.shape(x)[1]
        # Message Passing
        x_i = tf.tile(tf.expand_dims(x, 2), [1, 1, n, 1])  # x_i[:, i, j] = h_i
        x_j = tf.tile(tf.expand_dims(x, 1), [1, n, 1, 1])  # x_j[:, i, j] = h_j
        x_ij = tf.concat([x_i, x_j, A], -1)  # => (batch_size, n, n, 2*d_in + d_edge)
        m = tf.nn.relu(tf.matmul(x_ij, self.W_m))  # => (batch_size, n, n, d_latent)
        m = m * tf.expand_dims(adj, -1)  # keep only messages from N_i and i itself
        m = tf.reduce_sum(m, 2)  # sum over j => (batch_size, n, d_latent)

        # Updating
        x = tf.concat([x, m], -1)  # => (batch_size, n, d_in + d_latent)
        x = tf.nn.relu(tf.matmul(x, self.W_u))  # => (batch_size, n, d_out)
        return x

You can try implementing various message passing and update functions in PyTorch for practice.

Complexity of MPNNs

You might have realized, from the definition of the MPNN and the code implementation, that GCN and GAT are also specific instances of the MPNN. For example, GCN's message and update functions are as follows.

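$$m_i^{t+1} = \sum_{j \in N_i} M_t(h_i^t, h_j^t, e_{ij}) = \sum_{j \in N_i} \frac{1}{\sqrt{d_i d_j}} W h_j^t$$
$$U_t(h_i^t, m_i^{t+1}) = \sigma\left(\frac{1}{d_i} W h_i^t + m_i^{t+1}\right)$$

Here, $d_i$ is the degree of node $i$ counting its self-loop, so the self term $\frac{1}{d_i} W h_i^t$ in the update is the $j = i$ term of GCN's normalized sum.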

GCN and GAT do not use edge features. Instead, they scale neighbor features by scalar coefficients, which are fixed by node degrees in GCN and computed per node pair by attention in GAT. This makes them less expressive than the MPNN implemented above. However, more complex models do not necessarily yield better results, as they tend to face scalability and learnability issues. Hence, it is important to choose the degree of complexity and expressivity that suits the task at hand.
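As a sanity check on the GCN equations, the sketch below computes one GCN layer in message passing form (a neighbor sum plus a self term in the update) and compares it with the usual matrix form $\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}XW)$, where $\tilde{A} = A + I$. To stay dependency-free it assumes scalar node features, a single hypothetical scalar weight `w`, and ReLU as $\sigma$.

```python
import math

def relu(z):
    return max(0.0, z)

def gcn_message_passing(h, neighbors, w):
    # neighbors[i]: neighbors of node i, without self-loops.
    # Degrees count the self-loop, i.e. they are degrees in A + I.
    deg = [len(nbrs) + 1 for nbrs in neighbors]
    out = []
    for i, nbrs in enumerate(neighbors):
        # m_i = sum over j in N_i of W h_j / sqrt(d_i d_j)
        m_i = sum(w * h[j] / math.sqrt(deg[i] * deg[j]) for j in nbrs)
        # U: sigma(W h_i / d_i + m_i); the self term is the j = i entry
        out.append(relu(w * h[i] / deg[i] + m_i))
    return out

def gcn_matrix_form(h, neighbors, w):
    # Same layer via the dense normalized adjacency D^-1/2 (A + I) D^-1/2.
    n = len(h)
    a = [[1.0 if i == j or j in neighbors[i] else 0.0 for j in range(n)]
         for i in range(n)]
    deg = [sum(row) for row in a]
    return [relu(sum(a[i][j] / math.sqrt(deg[i] * deg[j]) * w * h[j]
                     for j in range(n)))
            for i in range(n)]

neighbors = [[1], [0, 2], [1]]  # path graph 0 - 1 - 2
h = [1.0, 2.0, 3.0]
print(gcn_message_passing(h, neighbors, w=0.5))
print(gcn_matrix_form(h, neighbors, w=0.5))
```

Both functions print the same values up to floating-point rounding, which is the sense in which GCN is a special case of the MPNN with a fixed, degree-based message.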

How Powerful Are MPNNs?

Xu, K. et al. (2019) proved that an MPNN on discrete node features (without edge features) is at most as powerful as the 1-WL test (which is unfortunately outside the scope of this article) at distinguishing graphs, and that it becomes exactly as powerful when its aggregation and update functions are injective. They proposed simple message and update functions that reach this bound by using sum, combined with a multi-layer perceptron, as an injective aggregator:

$$m_i^{t+1} = \sum_{j \in N_i} M_t(h_i^t, h_j^t, e_{ij}) = \sum_{j \in N_i} h_j^t$$
$$U_t(h_i^t, m_i^{t+1}) = \mathrm{MLP}_t\left((1+\epsilon)h_i^t + m_i^{t+1}\right)$$

The model with the above functions is called the Graph Isomorphism Network (GIN), where $\epsilon$ is a fixed or learnable scalar. For continuous node features, however, Corso, G. et al. (2020) proved that no single aggregator can distinguish all neighborhoods (multisets of neighbor features), so an MPNN with a single aggregator cannot match the WL test. They proposed Principal Neighborhood Aggregation (PNA), an empirically powerful combination of aggregators for general-purpose GNNs.
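A quick way to see why GIN sums instead of averaging: the sketch below applies sum, mean, and max to two neighborhoods of identical discrete features that differ only in size, and then shows why the $(1+\epsilon)$ factor matters. `gin_combine` is a hypothetical helper that computes only the argument GIN passes to its MLP, with scalar features for illustration; the paper's guarantee concerns countable features and a suitable MLP, which this toy does not model.

```python
from statistics import mean

def aggregate(neighbor_features, how):
    # Three common aggregators applied to a multiset of neighbor features.
    if how == "sum":
        return sum(neighbor_features)
    if how == "mean":
        return mean(neighbor_features)
    if how == "max":
        return max(neighbor_features)
    raise ValueError(f"unknown aggregator: {how}")

def gin_combine(h_i, neighbor_features, eps):
    # The input GIN feeds into its MLP: (1 + eps) * h_i + sum_j h_j.
    return (1 + eps) * h_i + sum(neighbor_features)

# Same features, different neighborhood sizes: only sum tells them apart.
small, large = [1, 1], [1, 1, 1]
for how in ("sum", "mean", "max"):
    print(how, aggregate(small, how), aggregate(large, how))

# Swapping a node's feature with its single neighbor's feature collides
# when eps = 0; a nonzero eps separates the center node from its neighbor.
print(gin_combine(2, [1], eps=0.0), gin_combine(1, [2], eps=0.0))
print(gin_combine(2, [1], eps=0.1), gin_combine(1, [2], eps=0.1))
```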

$$\bigoplus = \begin{bmatrix} I \\ S(D, \alpha=1) \\ S(D, \alpha=-1) \end{bmatrix} \otimes \begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}$$

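where $I$ leaves the aggregated values unscaled, $S$ is the degree-dependent scaler, $\mu$, $\sigma$, $\max$, and $\min$ are the mean, standard deviation, maximum, and minimum aggregators, and $\otimes$ is the tensor product, giving $3 \times 4 = 12$ combined aggregators. The scaler is $S(D, \alpha) = \left(\frac{\log(D+1)}{\delta}\right)^{\alpha}$, where $D$ is the node degree and $\delta$ is the average of $\log(D+1)$ over the training set, so $\alpha = 1$ amplifies and $\alpha = -1$ attenuates the aggregated values according to the degree. The aggregator can be used within the message function to aggregate the outputs of a neural network, which can then be passed to another neural network in the update function. I recommend checking the original paper listed at the bottom of the article for details of these operations.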
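The PNA combination can be sketched for a single node with scalar neighbor features as follows. This is a simplified illustration, not the reference implementation: it uses the population standard deviation from Python's `statistics` module, applies the scaler $S(D, \alpha) = (\log(D+1)/\delta)^{\alpha}$ with $\alpha \in \{0, 1, -1\}$ (where $\alpha = 0$ plays the role of $I$), and assumes `delta`, the training-set average of $\log(D+1)$, is precomputed and passed in.

```python
import math
from statistics import fmean, pstdev

def pna_aggregate(neighbor_features, delta):
    # neighbor_features: scalar features of one node's neighbors (at least one).
    # delta: average of log(D + 1) over training-set degrees (assumed given).
    degree = len(neighbor_features)
    aggregators = [fmean(neighbor_features), pstdev(neighbor_features),
                   max(neighbor_features), min(neighbor_features)]
    out = []
    # alpha = 0 gives S = 1 (the identity row I); 1 amplifies, -1 attenuates.
    for alpha in (0, 1, -1):
        scale = (math.log(degree + 1) / delta) ** alpha
        out.extend(scale * a for a in aggregators)
    return out  # 3 scalers x 4 aggregators = 12 values

# Same mean, different spread: the standard deviation separates them.
print(pna_aggregate([0.0, 2.0], delta=1.0)[:4])
print(pna_aggregate([1.0, 1.0], delta=1.0)[:4])
# Same statistics, different degree: only the degree-scaled rows differ.
print(pna_aggregate([1.0], delta=1.0)[4:8])
print(pna_aggregate([1.0, 1.0], delta=1.0)[4:8])
```

In PNA, the resulting vector would be fed to the update network; here it is only printed to show which neighborhoods each part of the combination can distinguish.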

Oversmoothing Problem

Even when a maximally powerful aggregator is used, the oversmoothing problem, where the latent node representations converge to the same value for all nodes after repeated aggregations, remains. It can be mitigated by several methods, such as keeping the network shallow, increasing the complexity of each layer, and adding skip connections. However, these methods require careful, educated guesses when choosing the model's hyperparameters and make the model hard to scale up, motivating alternative approaches.
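A simplified illustration of oversmoothing: the sketch below repeatedly applies a parameter-free mean aggregation (each node averages itself and its neighbors, i.e. multiplication by $\tilde{D}^{-1}\tilde{A}$ with no weights or nonlinearities) on a 4-node path graph. Real GNN layers add learned weights and activations, so this isolates only the averaging effect that drives the problem.

```python
def mean_aggregation_step(h, neighbors):
    # Each node replaces its feature with the mean over itself and its neighbors.
    return [(h[i] + sum(h[j] for j in nbrs)) / (len(nbrs) + 1)
            for i, nbrs in enumerate(neighbors)]

def spread(h):
    # Gap between the largest and smallest node feature.
    return max(h) - min(h)

neighbors = [[1], [0, 2], [1, 3], [2]]  # path graph 0 - 1 - 2 - 3
h = [4.0, 0.0, 0.0, 0.0]
for layer in range(1, 33):
    h = mean_aggregation_step(h, neighbors)
    if layer in (1, 2, 4, 8, 16, 32):
        print(layer, round(spread(h), 6))
```

The spread shrinks toward zero and every node approaches the same value (here 0.8, the degree-weighted average of the initial features), so after enough layers the nodes can no longer be told apart from their representations alone.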

Graph Transformers

As an alternative to MPNNs that addresses their limitations in oversmoothing, scalability, and expressivity, graph transformers are emerging as a new area of research. Node and edge features can be treated as tokens, appropriate positional and structural encodings can be attached to them, and these tokens can be passed to a transformer to produce latent node and edge representations.

TokenGT

Kim, J. et al. (2022) proposed the Tokenized Graph Transformer (TokenGT), which uses orthonormal vectors $P$ as node identifiers to encode the structural information of the tokens. For a node $v$ with features $X_v$, $P_v$ is appended twice to create the token $[X_v || P_v || P_v]$, and for an edge $(u, v)$ with features $X_{uv}$, $P_u$ and $P_v$ are appended to create $[X_{uv} || P_u || P_v]$. Additionally, TokenGT appends trainable type vectors, $E^V$ for nodes and $E^E$ for edges. The orthonormal vectors can either be randomly generated or obtained from the eigendecomposition of the normalized graph Laplacian matrix (a matrix reflecting the connectivity and clustering of the graph, which is unfortunately outside the scope of this article but may be covered in the future).
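Why orthonormal identifiers help: for a node token with identifier part $[P_k || P_k]$ and an edge token with $[P_u || P_v]$, the dot product is $P_k \cdot P_u + P_k \cdot P_v$, which is nonzero exactly when $k$ is an endpoint of the edge, so self-attention can in principle recover which edges touch which nodes. The sketch below checks this with one-hot vectors, the simplest orthonormal choice (an assumption for illustration; TokenGT itself uses random orthonormal vectors or Laplacian eigenvectors), and omits the features and type vectors.

```python
def one_hot(k, n):
    return [1.0 if i == k else 0.0 for i in range(n)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

n = 3
edges = [(0, 1), (1, 2)]               # path graph 0 - 1 - 2
P = [one_hot(k, n) for k in range(n)]  # orthonormal node identifiers

# Identifier parts of the tokens: [P_k || P_k] for nodes, [P_u || P_v] for edges
node_ids = [P[k] + P[k] for k in range(n)]  # list concatenation
edge_ids = [P[u] + P[v] for u, v in edges]

# incidence[k][e] = <[P_k || P_k], [P_u || P_v]>
incidence = [[dot(node_ids[k], edge_ids[e]) for e in range(len(edges))]
             for k in range(n)]
print(incidence)  # [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
```

The same dot product between two node tokens is 2 for a node with itself and 0 otherwise, so the identifiers also keep node tokens distinguishable from each other.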

import tensorflow as tf
from tensorflow.keras import layers

class GTTokenizer(layers.Layer):
    def __init__(self, n, d, latent):
        super().__init__()
        self.n = n  # (padded) number of nodes per graph
        self.d = d  # node/edge feature dimension
        self.latent = latent  # type identifier dimension
        # Learnable Type Identifiers (E^V for nodes, E^E for edges)
        self.e_v = self.add_weight(shape=(latent,),
                                   initializer='glorot_uniform',
                                   trainable=True)
        self.e_e = self.add_weight(shape=(latent,),
                                   initializer='glorot_uniform',
                                   trainable=True)
        # Learnable graph token
        self.graph_token = self.add_weight(
            shape=(1, 1, d + 2 * n + latent),
            initializer='glorot_uniform',
            name="graph_token",
        )

    def call(self, x, e, e_id, A):
        # x: (batch_size, n, d), e: (batch_size, m, d),
        # e_id: (batch_size, m, 2) integer endpoints (u, v), A: (batch_size, n, n) float
        # Type Identifiers
        e_v = tf.tile(tf.expand_dims(tf.expand_dims(self.e_v, 0), 0), [tf.shape(x)[0], tf.shape(x)[1], 1])
        e_e = tf.tile(tf.expand_dims(tf.expand_dims(self.e_e, 0), 0), [tf.shape(e)[0], tf.shape(e)[1], 1])

        # Node Identifiers
        I = tf.eye(self.n, dtype=A.dtype)
        deg = tf.reduce_sum(A, axis=-1)
        # D^{-1/2}; isolated nodes get 0 instead of making the inverse singular
        D_inv_sqrt = tf.linalg.diag(tf.math.divide_no_nan(tf.ones_like(deg), tf.sqrt(deg)))
        L = I - tf.matmul(tf.matmul(D_inv_sqrt, A), D_inv_sqrt)  # normalized Laplacian
        _, v = tf.linalg.eigh(L)  # eigenvectors are columns, so row i of v acts as P_i

        # For nodes: [X_v || P_v || P_v || E^V]
        x = tf.concat([x, v, v, e_v], -1)

        # For edges: [X_uv || P_u || P_v || E^E]
        e_id_1, e_id_2 = tf.unstack(e_id, axis=-1)
        e_id_1 = tf.expand_dims(e_id_1, -1)
        e_id_2 = tf.expand_dims(e_id_2, -1)
        v_1 = tf.gather_nd(v, e_id_1, batch_dims=1)  # P_u for each edge
        v_2 = tf.gather_nd(v, e_id_2, batch_dims=1)  # P_v for each edge
        e = tf.concat([e, v_1, v_2, e_e], -1)

        # Token Generation
        graph_tokens = tf.repeat(self.graph_token, tf.shape(x)[0], axis=0)
        tokens = tf.concat([x, e, graph_tokens], 1)  # => (batch_size, n+m+1, d+2*n+latent)
        return tokens

The above is a TensorFlow implementation of a tokenizer for TokenGT that uses Laplacian eigenvectors as node identifiers. The tokens can be passed to a standard transformer encoder, whose output can then be used for downstream tasks. Rampasek, L. et al. (2022) proposed GraphGPS, where GPS stands for General, Powerful, and Scalable. It combines various positional and structural encodings (including the ones described above), MPNNs, and graph transformers in a modular, scalable, and expressive way. You should have most of the prerequisite knowledge, so I recommend checking the original paper cited at the bottom of the article if you are interested.

Conclusion

In this article, we covered the definition of MPNNs, the code implementation of more complex MPNNs, maximally powerful MPNNs for discrete and continuous node features, and the oversmoothing problem, which motivated research into graph transformers. I might discuss spectral graph theory (where the graph Laplacian matrix is introduced), the WL test, and more about GIN, PNA, and graph transformers in the future, but I hope you got the gist of the field with this article.

Resources