This blog post discusses message passing neural networks, graph transformers, and more.

In the past two articles, we gradually increased the complexity of the models, moving from GCN to GAT. In this article, we will discuss even more complex models: message passing neural networks and graph transformers.
Message Passing Neural Networks
The models we have covered so far utilize the node features $h_v$ and learn the weight matrix $W$ and, in the case of GAT, the attention coefficients $\alpha_{ij}$. However, depending on the graph, we may also have edge features $e_{vw}$, which we can utilize to make the model more expressive. Gilmer, J. et al. (2017) introduced message passing neural networks (MPNN), which take the relevant node features and edge features to produce latent node representations, enabling the model to tackle tasks involving graphs whose edges encode different chemical bonds and spatial distances.
$$m_v^{t+1} = \sum_{w \in N(v)} M_t(h_v^t, h_w^t, e_{vw}), \qquad h_v^{t+1} = U_t(h_v^t, m_v^{t+1})$$
The above is the message passing phase of the MPNN, where $h_v^t$ is the latent representation of node $v$ at step $t$ and $N(v)$ is its neighborhood. $M_t$ is the message passing function, and $U_t$ is the update function, both of which can be tailored to the specific task at hand. For the message passing function, Gilmer, J. et al. (2017) suggest an edge network, $M_t(h_v^t, h_w^t, e_{vw}) = A(e_{vw})\,h_w^t$ (where $A$ maps the edge features to a $d \times d$ matrix), and a simple feedforward layer on the concatenated features, $M_t(h_v^t, h_w^t, e_{vw}) = \mathrm{MLP}([h_v^t \,\|\, h_w^t \,\|\, e_{vw}])$. The update function can be another feedforward layer, $U_t(h_v^t, m_v^{t+1}) = \mathrm{MLP}([h_v^t \,\|\, m_v^{t+1}])$. In addition to the message passing phase, the original MPNN incorporates a readout phase, which applies a permutation-invariant function $R$ to the latent node representations to generate a graph-wise latent representation.
Code Implementation
Below is an example TensorFlow implementation of an MPNN layer with both $M_t$ and $U_t$ being simple feedforward layers that operate on the concatenated features.
import tensorflow as tf
from tensorflow.keras import layers

class MPNNLayer(layers.Layer):
    def __init__(self, d_in, d_edge, d_latent, d_out):
        super(MPNNLayer, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Weights of the message passing function M
        self.W_m = self.add_weight(shape=(2 * d_in + d_edge, d_latent),
                                   initializer='glorot_uniform',
                                   trainable=True)
        # Weights of the update function U
        self.W_u = self.add_weight(shape=(d_in + d_latent, d_out),
                                   initializer='glorot_uniform',
                                   trainable=True)

    def call(self, x, A):
        # x: (batch_size, n, d_in), A: (batch_size, n, n, d_edge)
        # Message Passing: build all pairwise feature combinations
        x_i = tf.tile(tf.expand_dims(x, 1), [1, tf.shape(x)[1], 1, 1])
        x_j = tf.transpose(x_i, [0, 2, 1, 3])
        x_ij = tf.concat([x_i, x_j, A], -1)  # => (batch_size, n, n, 2 * d_in + d_edge)
        m = tf.matmul(x_ij, self.W_m)        # => (batch_size, n, n, d_latent)
        m = tf.nn.relu(m)
        m = tf.reduce_sum(m, 2)              # aggregate messages => (batch_size, n, d_latent)
        # Updating: combine each node's features with its aggregated message
        x = tf.concat([x, m], -1)            # => (batch_size, n, d_in + d_latent)
        x = tf.matmul(x, self.W_u)           # => (batch_size, n, d_out)
        x = tf.nn.relu(x)
        return x
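As a quick sanity check, the layer can be run on random tensors; the shapes below are illustrative only and not tied to any particular dataset.

# A minimal usage sketch: 4 graphs, 5 nodes each, 3 node features, 2 edge features
x = tf.random.normal((4, 5, 3))
A = tf.random.normal((4, 5, 5, 2))
layer = MPNNLayer(d_in=3, d_edge=2, d_latent=8, d_out=16)
out = layer(x, A)
print(out.shape)  # (4, 5, 16)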
You can try implementing various message passing and update functions in PyTorch for practice.
Complexity of MPNNs
You might have realized, by looking at the definition of MPNN and the code implementation, that GCN and GAT are also specific instances of MPNN. For example, the below shows the message passing function and update function of GCN.
$$M(h_v^t, h_w^t, e_{vw}) = \frac{h_w^t}{\sqrt{\deg(v)\deg(w)}}, \qquad U(h_v^t, m_v^{t+1}) = \sigma(W\, m_v^{t+1})$$
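To make the correspondence concrete, here is a minimal sketch of my own (not code from the previous articles) of GCN written in the same dense-tensor style as MPNNLayer, where the message is the neighbor's features scaled by a precomputed normalized adjacency coefficient and the update is a single shared linear layer.

class GCNAsMPNN(layers.Layer):
    # Hypothetical illustration of GCN as an MPNN: the message function scales
    # neighbor features by precomputed normalized adjacency coefficients,
    # and the update function is a shared linear transform with ReLU.
    def __init__(self, d_in, d_out):
        super(GCNAsMPNN, self).__init__()
        self.W = self.add_weight(shape=(d_in, d_out),
                                 initializer='glorot_uniform',
                                 trainable=True)

    def call(self, x, A_norm):
        # x: (batch_size, n, d_in)
        # A_norm: (batch_size, n, n), e.g. D^(-1/2) (A + I) D^(-1/2)
        m = tf.matmul(A_norm, x)   # message passing + sum aggregation
        x = tf.matmul(m, self.W)   # update
        return tf.nn.relu(x)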
Neither GCN nor GAT uses edge features; instead, they multiply the neighbors' node features by scalar coefficients, which are fixed by the graph structure in GCN and learned per node pair in GAT, making them less expressive than the MPNN implemented above. However, more complex models do not necessarily yield better results, as they tend to face scalability and learnability issues. Hence, it is important to choose the right degree of complexity and expressivity for the model depending on the task at hand.
How Powerful Are MPNNs?
Xu, K. et al. (2019) provided proof that an MPNN on discrete node features (without edge features) can be at most as powerful as the 1-WL test (which is unfortunately outside the scope of this article) in discriminating graphs, and that this bound is attainable. They defined the simplest and most powerful message passing and update functions, which utilize an injective aggregator, the sum, as follows:
$$h_v^{t+1} = \mathrm{MLP}^{(t)}\Big((1 + \epsilon^{(t)})\, h_v^t + \sum_{w \in N(v)} h_w^t\Big)$$
The model with the above functions is called the Graph Isomorphism Network (GIN), where $\epsilon^{(t)}$ is a constant or learnable scalar.
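Below is a minimal TensorFlow sketch of a GIN-style layer following the formula above, assuming a binary adjacency matrix without self-loops; the class name, the two-layer MLP, and the hyperparameters are illustrative choices rather than the reference implementation.

class GINLayer(layers.Layer):
    def __init__(self, d_hidden, d_out):
        super(GINLayer, self).__init__()
        # Learnable epsilon weighting the node's own features
        self.eps = self.add_weight(shape=(), initializer='zeros', trainable=True)
        # MLP applied after the injective sum aggregation
        self.mlp = tf.keras.Sequential([
            layers.Dense(d_hidden, activation='relu'),
            layers.Dense(d_out),
        ])

    def call(self, x, A):
        # x: (batch_size, n, d_in), A: (batch_size, n, n) binary adjacency (no self-loops)
        neighbor_sum = tf.matmul(A, x)             # sum over neighbors (injective aggregator)
        out = (1.0 + self.eps) * x + neighbor_sum  # (1 + eps) * h_v + sum of neighbors
        return self.mlp(out)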
In the case of continuous node features, however, Corso, G. et al. (2020) proved that no single aggregator can match the WL test and proposed Principal Neighbourhood Aggregation (PNA), an empirically powerful combination of aggregators and degree scalers for general-purpose GNNs:
$$\bigoplus = \underbrace{\begin{bmatrix} I \\ S(D, \alpha = 1) \\ S(D, \alpha = -1) \end{bmatrix}}_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}}$$
where $S$ is the degree-dependent scaler and $\otimes$ is the tensor product. The $\bigoplus$ aggregator can be used within the message passing function to aggregate the results of a neural network, which can then be passed to another neural network for the update function. I recommend checking the original paper listed at the bottom of the article for the details of these operations.
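As a rough illustration of how the mean, standard deviation, max, and min aggregators can be combined with degree scalers on a dense binary adjacency matrix, here is a sketch of my own (not the authors' reference implementation); in the paper, $\delta$ is the average log-degree over the training set, while here it is simply passed in as an argument.

def pna_aggregate(x, A, delta=1.0):
    # x: (batch_size, n, d), A: (batch_size, n, n) binary adjacency
    deg = tf.reduce_sum(A, axis=-1, keepdims=True)        # (batch_size, n, 1)
    deg_safe = tf.maximum(deg, 1.0)
    # Mean and standard deviation aggregators
    mean = tf.matmul(A, x) / deg_safe
    sq_mean = tf.matmul(A, x ** 2) / deg_safe
    std = tf.sqrt(tf.nn.relu(sq_mean - mean ** 2) + 1e-6)
    # Max and min aggregators (non-neighbors are masked out with +/- 1e9)
    mask = tf.expand_dims(A, -1)                          # (batch_size, n, n, 1)
    x_b = tf.expand_dims(x, 1)                            # (batch_size, 1, n, d)
    max_agg = tf.reduce_max(x_b * mask - 1e9 * (1.0 - mask), axis=2)
    min_agg = tf.reduce_min(x_b * mask + 1e9 * (1.0 - mask), axis=2)
    aggs = tf.concat([mean, std, max_agg, min_agg], -1)   # (batch_size, n, 4d)
    # Degree scalers: identity, amplification, attenuation
    s = tf.math.log(deg + 1.0) / delta                    # (batch_size, n, 1)
    return tf.concat([aggs, aggs * s, aggs / tf.maximum(s, 1e-6)], -1)  # (batch_size, n, 12d)

The resulting tensor can then be fed to the update network, just like the aggregated messages in MPNNLayer.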
Oversmoothing Problem
Even when a maximally powerful aggregator is utilized, the oversmoothing problem—where the latent node representations converge to the same value for all nodes after a series of aggregations—remains relevant. This can be addressed by several methods such as making the network shallow, increasing the complexity of the layers, and adding skip connections. However, these methods require researchers to make careful, educated guesses when deciding the hyperparameters of the models and prevent the model from easily scaling up, motivating the use of alternative approaches.
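For instance, a skip connection can be added around a message passing layer in a couple of lines; the sketch below reuses the MPNNLayer defined earlier and assumes the input and output dimensions match so that the residual addition is well defined.

class ResidualMPNN(layers.Layer):
    def __init__(self, d, d_edge):
        super(ResidualMPNN, self).__init__()
        # d_out == d_in so that x and self.mpnn(x, A) can be added
        self.mpnn = MPNNLayer(d_in=d, d_edge=d_edge, d_latent=d, d_out=d)

    def call(self, x, A):
        # The skip connection keeps earlier representations from being smoothed away
        return x + self.mpnn(x, A)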
Graph Transformers
As an alternative to MPNNs that addresses the oversmoothing, scalability, and expressivity issues, graph transformers are emerging as a new area of research. Node and edge features can be treated as tokens, appropriate positional and structural encodings can be generated, and these tokens can be passed to transformers to produce latent node and edge representations.

Kim, J. et al. (2022) proposed the Tokenized Graph Transformer (TokenGT), which utilizes orthonormal vectors $P_v$ for each node $v$ as node identifiers to encode the structural information of the tokens. For a node $v$, $P_v$ can be appended twice to the node embedding to create $[x_v, P_v, P_v]$, and for an edge $(u, v)$, $P_u$ and $P_v$ can be appended to the edge embedding to create $[x_{(u,v)}, P_u, P_v]$. Additionally, TokenGT appends trainable type vectors $E^{\mathcal{V}}$ for vertices and $E^{\mathcal{E}}$ for edges. The orthonormal vectors can either be randomly generated or obtained by eigendecomposition of the normalized graph Laplacian matrix (a matrix reflecting the connectivity and clustering of the graph, which is unfortunately outside the scope of this article but may be covered in the future).
class GTTokenizer(layers.Layer):
    def __init__(self, n, d, latent):
        super(GTTokenizer, self).__init__()
        self.n = n
        self.d = d
        self.latent = latent
        # Learnable Type Identifiers
        self.e_v = self.add_weight(shape=(latent,),
                                   initializer='glorot_uniform',
                                   trainable=True)
        self.e_e = self.add_weight(shape=(latent,),
                                   initializer='glorot_uniform',
                                   trainable=True)
        # Graph token (a learnable summary token appended to the sequence)
        self.graph_token = self.add_weight(
            shape=(1, 1, d + 2 * n + latent),
            initializer='glorot_uniform',
            trainable=True,
            name="graph_token",
        )

    def call(self, x, e, e_id, A):
        # x: (batch_size, n, d), e: (batch_size, m, d), e_id: (batch_size, m, 2), A: (batch_size, n, n)
        # Type Identifiers, broadcast to every node / edge token
        e_v = tf.tile(tf.expand_dims(tf.expand_dims(self.e_v, 0), 0),
                      [tf.shape(x)[0], tf.shape(x)[1], 1])
        e_e = tf.tile(tf.expand_dims(tf.expand_dims(self.e_e, 0), 0),
                      [tf.shape(e)[0], tf.shape(e)[1], 1])
        # Node Identifiers: eigenvectors of the normalized graph Laplacian
        # (assumes no isolated nodes, so that D is invertible)
        I = tf.eye(self.n)
        D = tf.linalg.diag(tf.reduce_sum(A, axis=-1))
        D_inv_sqrt = tf.linalg.inv(tf.sqrt(D))
        L = I - tf.matmul(tf.matmul(D_inv_sqrt, A), D_inv_sqrt)  # normalized Laplacian
        _, v = tf.linalg.eigh(L)  # Laplacian eigenvectors; row i serves as node i's identifier
        # For nodes: [x_v, P_v, P_v, E_V]
        x = tf.concat([x, v, v, e_v], -1)
        # For edges: [x_(u,v), P_u, P_v, E_E]
        e_id_1, e_id_2 = tf.unstack(e_id, axis=-1)
        e_id_1 = tf.expand_dims(e_id_1, -1)
        e_id_2 = tf.expand_dims(e_id_2, -1)
        v_1 = tf.gather_nd(v, e_id_1, batch_dims=1)
        v_2 = tf.gather_nd(v, e_id_2, batch_dims=1)
        e = tf.concat([e, v_1, v_2, e_e], -1)
        # Token Generation: node tokens, edge tokens, and one graph token per graph
        graph_tokens = tf.repeat(self.graph_token, tf.shape(x)[0], axis=0)
        tokens = tf.concat([x, e, graph_tokens], 1)  # => (batch_size, n + m + 1, d + 2 * n + latent)
        return tokens
The above is a TensorFlow implementation of a tokenizer for TokenGT, using Laplacian eigenvectors as node identifiers. The tokens can be passed to a transformer encoder (a minimal sketch follows below), whose output can then be used for downstream tasks. Rampasek, L. et al. (2022) proposed GraphGPS, which stands for a General, Powerful, and Scalable architecture for graph transformers. It combines various positional and structural encodings (including the ones described above), MPNNs, and graph transformers in a modular, scalable, and expressive way. You should have most of the prerequisite knowledge by now, so I recommend checking the original paper cited at the bottom of the article if you are interested.
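To sketch that last step, the tokens produced by GTTokenizer can be projected and passed through a standard Keras self-attention encoder block; the class below is an illustrative minimal encoder of my own, not the configuration from the TokenGT paper, and the layer sizes are arbitrary.

class SimpleGraphTransformer(layers.Layer):
    def __init__(self, d_model, num_heads):
        super(SimpleGraphTransformer, self).__init__()
        self.proj = layers.Dense(d_model)   # project raw tokens to the model width
        self.attn = layers.MultiHeadAttention(num_heads=num_heads,
                                              key_dim=d_model // num_heads)
        self.norm1 = layers.LayerNormalization()
        self.ffn = tf.keras.Sequential([
            layers.Dense(4 * d_model, activation='relu'),
            layers.Dense(d_model),
        ])
        self.norm2 = layers.LayerNormalization()

    def call(self, tokens):
        # tokens: (batch_size, n + m + 1, d + 2n + latent) from GTTokenizer
        h = self.proj(tokens)
        h = self.norm1(h + self.attn(h, h))  # self-attention with a residual connection
        h = self.norm2(h + self.ffn(h))      # position-wise feedforward with a residual connection
        return h

For a graph-level task, the last token of the output (the graph token) can be fed to a prediction head, while the node and edge tokens can be used for node- or edge-level tasks.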
Conclusion
In this article, we covered the definition of MPNNs, the code implementation of more complex MPNNs, maximally powerful MPNNs for discrete and continuous node features, and the oversmoothing problem, which motivated research into graph transformers. I might discuss spectral graph theory (where the graph Laplacian matrix is introduced), the WL test, and more about GIN, PNA, and graph transformers in the future, but I hope you got the gist of the field with this article.
Resources
- Corso, G. et al. 2020. Principal Neighbourhood Aggregation for Graph Nets. ArXiv.
- Fourrier, C. 2023. Introduction to Graph Machine Learning. HuggingFace.
- Gilmer, J. et al. 2017. Neural Message Passing for Quantum Chemistry. ArXiv.
- Kim, J. et al. 2022. Pure Transformers are Powerful Graph Learners. ArXiv.
- Rampasek, L. et al. 2022. Recipe for a General, Powerful, Scalable Graph Transformer. ArXiv.
- Velickovic, P. 2021. Theoretical Foundations of Graph Neural Networks. YouTube.