Road to ML Engineer #33 - Graph Attention Networks

Last Edited: 1/10/2025

The blog post discusses graph attention networks.

ML

In the last article, we discussed how graph convolution works and how it is simple and effective for homophilous graphs. In this article, we will discuss more complex models that can handle non-homophilous graphs as well: Graph Attention Networks.

Graph Attention

The simplicity of graph convolution is rooted in applying the same weight to every pair of neighboring nodes, which is necessary to maintain permutation equivariance. However, one factor in graph convolution can still be changed: the normalization factor. GCN normalizes with $D^{-1}$ (or with $D^{-\frac{1}{2}}$ applied on both sides, as in $D^{-\frac{1}{2}}\tilde{A}D^{-\frac{1}{2}}$), whose values are fixed by the node degrees. Instead, we can use learnable weights, or attention, assigned to each node pair, making the model more expressive and capable of learning intricate relationships.

$$
e_{ij} = \text{LeakyReLU}(\overrightharpoon{a}^T[Wh_i \Vert Wh_j]) \\
\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \{i\} \cup N_i}\exp(e_{ik})}
$$

The attention coefficient $e_{ij}$ for a pair is computed by concatenating ($\Vert$) $Wh_i$ and $Wh_j$, then passing the result through a single-layer feedforward network with weight vector $\overrightharpoon{a}$ and a LeakyReLU activation. The coefficients of all the neighbors of $i$ (including $i$ itself) are passed to a softmax, which rescales them so that each attention weight $\alpha_{ij}$ lies between 0 and 1 and the weights over the neighborhood sum to 1. Below is a brute-force (dense) representation of the corresponding tensor operation.

$$
E = \text{LeakyReLU}\big([[HW]_{\times n} \Vert ([HW]_{\times n})^T]\,\overrightharpoon{a}\big) \\
\Alpha = \text{softmax}(M), \quad M_{ij} = \begin{cases} E_{ij} & \tilde{A}_{ij} = 1 \\ -\infty & \tilde{A}_{ij} = 0 \end{cases}
$$

$[HW]_{\times n}$ denotes a 3D tensor that repeats $HW$ $n$ times along a new second axis, so that entry $(i, j, :)$ is the $i$-th row of $HW$, giving a size of $(n, n, d_{out})$. This tensor and its transpose (with the first and second axes swapped, so that entry $(i, j, :)$ is the $j$-th row of $HW$) are concatenated along the third axis to form a tensor of size $(n, n, 2d_{out})$ whose entry $(i, j, :)$ is $[Wh_i \Vert Wh_j]$. This tensor is then multiplied by $\overrightharpoon{a}$ of size $(2d_{out}, 1)$ and passed through LeakyReLU to compute the attention coefficient matrix $E$ of size $(n, n)$ (after squeezing the trailing axis of size 1).
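To check that the tiled tensor really produces the same scores as the per-pair definition, here is a minimal NumPy sketch that computes $E$ both ways. The random features, the sizes, and the helper names (`leaky_relu`, `scores_pairwise`, `scores_dense`) are made up for illustration; features are stored as rows, so row $i$ of `H @ W` plays the role of $Wh_i$.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    # LeakyReLU with negative slope 0.2
    return np.where(x > 0, x, slope * x)

def scores_pairwise(H, W, a):
    """E[i, j] = LeakyReLU(a^T [W h_i || W h_j]), evaluated one pair at a time."""
    HW = H @ W                                    # row i is (W h_i)^T, shape (n, d_out)
    n = HW.shape[0]
    return np.array([[leaky_relu(a @ np.concatenate([HW[i], HW[j]]))
                      for j in range(n)] for i in range(n)])

def scores_dense(H, W, a):
    """Same matrix via the tiled (n, n, 2 * d_out) tensor."""
    HW = H @ W
    n = HW.shape[0]
    T = np.repeat(HW[:, None, :], n, axis=1)      # T[i, j] = (HW)_i, i.e. [HW]_{x n}
    T_t = np.transpose(T, (1, 0, 2))              # T_t[i, j] = (HW)_j, its transpose
    pairs = np.concatenate([T, T_t], axis=-1)     # pairs[i, j] = [W h_i || W h_j]
    E = (pairs @ a[:, None])[..., 0]              # (n, n, 1) -> squeeze -> (n, n)
    return leaky_relu(E)

rng = np.random.default_rng(0)
n, d_in, d_out = 5, 4, 3
H = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
a = rng.normal(size=2 * d_out)

E = scores_dense(H, W, a)
print(E.shape)                                    # (5, 5)
print(np.allclose(E, scores_pairwise(H, W, a)))   # True
```

The dense version trades memory for vectorization: the intermediate tensor grows as $n^2 \cdot 2d_{out}$, which is fine for small graphs but is the reason practical implementations avoid it.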

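Finally, we mask $E$ with $\tilde{A}$ $(= A + I)$ and pass it through a row-wise softmax to get the masked attention matrix $\Alpha$ (alpha) of size $(n, n)$. The mask sets every entry where $\tilde{A}_{ij} = 0$ to $-\infty$ (in practice, a large negative number), so non-neighbors receive exactly zero attention. Simply multiplying $E$ by $\tilde{A}$ element-wise would not be enough, because the softmax maps the resulting zeros to positive weights. Since permuting the nodes permutes the rows and columns of $E$, $\tilde{A}$, and $\Alpha$ in the same way, the computation is permutation equivariant.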
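Here is a small NumPy sketch of the masked softmax, followed by a numerical check of permutation equivariance: relabeling the nodes with a permutation matrix $P$ should give $P\Alpha P^T$. The 5-node cycle graph, the random weights, and the helper names are assumptions chosen for illustration.

```python
import numpy as np

def scores(H, W, a, slope=0.2):
    """Unmasked E[i, j] = LeakyReLU(a^T [W h_i || W h_j])."""
    HW = H @ W
    E = np.array([[a @ np.concatenate([p, q]) for q in HW] for p in HW])
    return np.where(E > 0, E, slope * E)

def masked_softmax(E, A_tilde):
    """Row-wise softmax of E over the entries where A_tilde is non-zero."""
    masked = np.where(A_tilde > 0, E, -np.inf)             # non-neighbors -> -inf
    # Subtract the row max for stability; every row has a self-loop, so the max is finite
    masked = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(masked)                               # exp(-inf) = 0 exactly
    return weights / weights.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
n, d_in, d_out = 5, 3, 2
H = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
a = rng.normal(size=2 * d_out)
A = np.roll(np.eye(n), 1, axis=1) + np.roll(np.eye(n), -1, axis=1)   # 5-node cycle graph
A_tilde = A + np.eye(n)

alpha = masked_softmax(scores(H, W, a), A_tilde)
print(alpha.sum(axis=1))          # every row sums to 1
print(alpha[A_tilde == 0])        # all exactly 0.0

# Permutation equivariance: relabel the nodes with a permutation matrix P
P = np.eye(n)[rng.permutation(n)]
alpha_perm = masked_softmax(scores(P @ H, W, a), P @ A_tilde @ P.T)
print(np.allclose(alpha_perm, P @ alpha @ P.T))   # True
```

Using $-\infty$ works here because the self-loops guarantee at least one allowed entry per row; with a fully masked row the softmax would produce NaNs, which is one reason implementations often prefer a large finite negative number.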

Graph Attention Networks

In the previous section, we set up a permutation-equivariant way of computing learnable graph attention, whose values can differ between pairs of nodes. Graph Attention Networks (GATs) replace the fixed normalized adjacency $D^{-1}\tilde{A}$ in GCN with the resulting $\Alpha$, allowing the model to generate appropriate embeddings for downstream tasks, even for non-homophilous graphs.

$$
h_i = g(x_i, X_{N_i}) = \sigma\Big(\sum_{j \in \{i\} \cup N_i} \alpha_{ij} W x_j\Big) \\
H = F(X, A) = \begin{bmatrix} - \; g(x_1, X_{N_1}) \; - \\ - \; g(x_2, X_{N_2}) \; - \\ \vdots \\ - \; g(x_n, X_{N_n}) \; - \end{bmatrix} = \sigma(\Alpha X W)
$$
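As a sanity check on the layer definition, the following NumPy sketch implements one single-head GAT layer in matrix form and compares one row against the per-node sum $\sigma(\sum_j \alpha_{ij} W x_j)$. ReLU stands in for $\sigma$, and the 4-node graph, the random weights, and the name `gat_layer` are illustrative assumptions. Because $W$ is stored with shape $(d_{in}, d_{out})$ here, the formula's $Wx_j$ corresponds to `W.T @ X[j]`.

```python
import numpy as np

def gat_layer(X, A, W, a, slope=0.2):
    """Single-head GAT layer returning (H, Alpha); ReLU stands in for sigma."""
    n = X.shape[0]
    A_tilde = A + np.eye(n)                          # j ranges over {i} ∪ N_i
    XW = X @ W                                       # row j is (W x_j)^T
    E = np.array([[a @ np.concatenate([p, q]) for q in XW] for p in XW])
    E = np.where(E > 0, E, slope * E)                # LeakyReLU
    E = np.where(A_tilde > 0, E, -np.inf)            # mask non-neighbors
    Alpha = np.exp(E - E.max(axis=1, keepdims=True))
    Alpha /= Alpha.sum(axis=1, keepdims=True)        # row-wise softmax
    return np.maximum(Alpha @ XW, 0.0), Alpha        # sigma(Alpha X W)

rng = np.random.default_rng(2)
X = rng.normal(size=(4, 3))
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
W = rng.normal(size=(3, 2))
a = rng.normal(size=4)
H, Alpha = gat_layer(X, A, W, a)

# Row i of H matches the per-node formula sigma(sum_j alpha_ij W x_j)
i = 2
h_i = np.maximum(sum(Alpha[i, j] * (W.T @ X[j]) for j in range(4)), 0.0)
print(np.allclose(H[i], h_i))   # True
```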

Additionally, GATs use multi-head attention, as in transformers, to handle graphs with complex relationships. Each attention head $k$ has its own $\overrightharpoon{a}^k$ and $W^k$. The output embeddings from all heads are concatenated, except in the last layer, where they are averaged.

$$
h_i = \Big\Vert_{k=1}^{K} \sigma\Big(\sum_{j \in \{i\} \cup N_i} \alpha_{ij}^k W^k x_j\Big) \\
h_i = \sigma\Big(\frac{1}{K}\sum_{k=1}^{K} \sum_{j \in \{i\} \cup N_i} \alpha_{ij}^k W^k x_j\Big)
$$
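The two combination rules mainly differ in output width: concatenation yields $K \cdot d_{out}$ features per node, while averaging keeps $d_{out}$. A minimal NumPy sketch, assuming ReLU for $\sigma$, a complete graph with self-loops for brevity, and hypothetical helper names `head_output` and `multi_head`:

```python
import numpy as np

def head_output(X, A_tilde, W, a, slope=0.2):
    """sum_j alpha_ij W x_j for every node i (one head, before sigma)."""
    XW = X @ W
    E = np.array([[a @ np.concatenate([p, q]) for q in XW] for p in XW])
    E = np.where(E > 0, E, slope * E)
    E = np.where(A_tilde > 0, E, -np.inf)
    Alpha = np.exp(E - E.max(axis=1, keepdims=True))
    Alpha /= Alpha.sum(axis=1, keepdims=True)
    return Alpha @ XW

def multi_head(X, A_tilde, heads, last_layer=False):
    """heads: list of (W_k, a_k) pairs; ReLU stands in for sigma."""
    outs = [head_output(X, A_tilde, W_k, a_k) for W_k, a_k in heads]
    if last_layer:
        return np.maximum(np.mean(outs, axis=0), 0.0)                    # average, then sigma
    return np.concatenate([np.maximum(o, 0.0) for o in outs], axis=-1)  # sigma per head, then concat

rng = np.random.default_rng(3)
n, d_in, d_out, K = 4, 3, 2, 3
X = rng.normal(size=(n, d_in))
A_tilde = np.ones((n, n))          # complete graph with self-loops, for brevity
heads = [(rng.normal(size=(d_in, d_out)), rng.normal(size=2 * d_out)) for _ in range(K)]

hidden = multi_head(X, A_tilde, heads)
final = multi_head(X, A_tilde, heads, last_layer=True)
print(hidden.shape)   # (4, 6): K * d_out features
print(final.shape)    # (4, 2): d_out features
```

When such layers are stacked, the next layer's input width is therefore $K \cdot d_{out}$, not $d_{out}$.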

You may notice that this computation is very similar to transformer attention. The primary difference is that graph attention scores a pair by concatenating the two projected embeddings and applying a feedforward layer, whereas transformer attention uses the scaled dot product between query and key projections. (The adjacency mask is unique to graph attention, but transformers can be interpreted as operating on a complete graph of tokens.)
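To make the comparison concrete, this NumPy sketch puts the two scoring functions side by side and feeds both through the same masked softmax: GAT with a path-graph mask, and scaled dot-product attention with an all-ones mask (the complete graph). The weights are random placeholders and the function names are hypothetical.

```python
import numpy as np

def gat_scores(X, W, a, slope=0.2):
    """GAT: concatenate the projected pair, project with a, then LeakyReLU."""
    XW = X @ W
    E = np.array([[a @ np.concatenate([p, q]) for q in XW] for p in XW])
    return np.where(E > 0, E, slope * E)

def dot_product_scores(X, W_q, W_k):
    """Transformer: scaled dot product between query and key projections."""
    Q, K = X @ W_q, X @ W_k
    return Q @ K.T / np.sqrt(K.shape[1])

def masked_softmax(E, mask):
    E = np.where(mask > 0, E, -np.inf)
    w = np.exp(E - E.max(axis=1, keepdims=True))
    return w / w.sum(axis=1, keepdims=True)

rng = np.random.default_rng(4)
n, d_in, d = 4, 3, 2
X = rng.normal(size=(n, d_in))
A_tilde = np.eye(n) + np.diag(np.ones(n - 1), 1) + np.diag(np.ones(n - 1), -1)  # path graph + self-loops
complete = np.ones((n, n))         # the "graph" a transformer implicitly uses

gat_alpha = masked_softmax(gat_scores(X, rng.normal(size=(d_in, d)), rng.normal(size=2 * d)), A_tilde)
attn = masked_softmax(dot_product_scores(X, rng.normal(size=(d_in, d)), rng.normal(size=(d_in, d))), complete)
print((gat_alpha > 0).astype(int))   # non-zero only on path edges and self-loops
print((attn > 0).astype(int))        # every token attends to every token
```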

Code Implementation

We can first set up single-head attention and combine it later to implement multi-head attention. Below is the TensorFlow implementation of single-head attention.

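import tensorflow as tf
from tensorflow.keras import layers

class GraphAttention(layers.Layer):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Shared linear projection W: (d_in, d_out)
        self.W = self.add_weight(shape=(d_in, d_out),
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Attention vector a: (2 * d_out, 1)
        self.a = self.add_weight(shape=(2 * d_out, 1),
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.leaky_relu = layers.LeakyReLU(0.2)

    def call(self, x, A):
        # x: (batch_size, n, d_in)
        # A: (batch_size, n, n) adjacency matrix that already includes self-loops (A + I)
        n = tf.shape(x)[1]
        x = tf.matmul(x, self.W)                                  # (b, n, d_out)
        x_i = tf.tile(tf.expand_dims(x, 2), [1, 1, n, 1])         # x_i[:, i, j] = Wh_i
        x_j = tf.transpose(x_i, [0, 2, 1, 3])                     # x_j[:, i, j] = Wh_j
        x_ij = tf.concat([x_i, x_j], -1)                          # (b, n, n, 2 * d_out)

        e = tf.squeeze(tf.matmul(x_ij, self.a), [-1])             # (b, n, n)
        e = self.leaky_relu(e)

        # Masked softmax: non-neighbors get a large negative score, hence ~0 attention
        mask = tf.cast(A > 0, e.dtype)
        attn = tf.nn.softmax(e + (1.0 - mask) * -1e9, axis=-1)   # (b, n, n)
        return tf.matmul(attn, x)                                 # (b, n, d_out)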

Instead of using a weight matrix and vector for the feedforward layer, you can apply Conv1D layers with a kernel size of 1 to achieve the same effect (as done by the authors of the original paper). Specifically, you can split $\overrightharpoon{a}$ into $[a_1 \Vert a_2]$ and rewrite $\overrightharpoon{a}^T[Wh_i \Vert Wh_j]$ as $a_1^TWh_i + a_2^TWh_j$. Each term is a per-node score (a kernel-size-1 Conv1D computes exactly such a per-node linear map), so $E$ can be formed by broadcast-adding a column of query scores and a row of key scores, without materializing the $(n, n, 2d_{out})$ tensor.
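The decomposition can be checked numerically. This NumPy sketch builds $E$ (before LeakyReLU, which is applied identically in both cases) once from the explicit concatenation tensor and once from two per-node score vectors combined by broadcasting. Plain matrix products stand in for the kernel-size-1 Conv1D layers, and the sizes and random values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 6, 4, 3
H = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
a = rng.normal(size=2 * d_out)
a1, a2 = a[:d_out], a[d_out:]          # a = [a_1 || a_2]

HW = H @ W
# Concatenation form: build the (n, n, 2 * d_out) tensor explicitly
pairs = np.concatenate([np.repeat(HW[:, None, :], n, axis=1),    # [i, j] -> (HW)_i
                        np.repeat(HW[None, :, :], n, axis=0)],   # [i, j] -> (HW)_j
                       axis=-1)
E_concat = pairs @ a                   # (n, n)

# Decomposed form: one score per node for each half, then broadcast-add
s1 = HW @ a1                           # (n,) query-side term a_1^T W h_i
s2 = HW @ a2                           # (n,) key-side term  a_2^T W h_j
E_decomp = s1[:, None] + s2[None, :]   # (n, n)

print(np.allclose(E_concat, E_decomp))   # True
```

The decomposed form needs only $O(n)$ extra memory for the scores plus the $(n, n)$ result, instead of the $(n, n, 2d_{out})$ intermediate.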

You can also incorporate dropout layers for the input and attention coefficients, as well as residual connections, to make the model more robust. The nonlinearity $\sigma$ is not applied to single-head attention because it is applied in the multi-head attention layer.

class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, num_heads, d_in, d_out, predict=False):
        super().__init__()
        self.num_heads = num_heads
        self.predict = predict  # True for the final (prediction) layer
        self.attention_heads = [GraphAttention(d_in, d_out) for _ in range(num_heads)]

    def call(self, x, A):
        # Each head returns (batch_size, n, d_out)
        head_outputs = [head(x, A) for head in self.attention_heads]
        if self.predict:
            # Final layer: average the heads -> (batch_size, n, d_out)
            out = tf.reduce_mean(tf.stack(head_outputs, axis=0), axis=0)
        else:
            # Hidden layers: concatenate the heads -> (batch_size, n, num_heads * d_out)
            out = tf.concat(head_outputs, axis=-1)
        return tf.nn.relu(out)

The TensorFlow implementation of multi-head graph attention above uses the ReLU activation function for $\sigma$. You can experiment with other activation functions or hyperparameters, and implement it in PyTorch for practice. The resulting multi-head attention layers can be stacked to create a GAT, which produces latent embeddings for downstream tasks.

Static vs Dynamic Attention

If you looked at the attention computation in the above section and noticed something off, you might want to stop reading this article and become an ML researcher already. The GAT above computes static attention: the ranking of neighbors from high to low attention does not depend on the query node. This is because $\overrightharpoon{a}$ and $W$ collapse into one linear operation. Splitting $\overrightharpoon{a}$ into $[a_1 \Vert a_2]$ gives $\overrightharpoon{a}^T[Wh_i \Vert Wh_j] = a_1^TWh_i + a_2^TWh_j$, a query-only term plus a key-only term, and since both LeakyReLU and softmax are monotonic, the attention $\alpha_{ij}$ is monotonic in the key-only term $a_2^TWh_j$. The query term can change the attention values (for example, how sharp the distribution is), but the ranking of keys stays the same for every query node.
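The argument can be reproduced numerically. This NumPy sketch scores every (query, key) pair of a complete bipartite graph with the GAT formula, using random features and weights chosen purely for illustration, and checks that every query ranks the keys in exactly the order given by the key-only term $a_2^TWh_j$. Scores are compared before the softmax, which does not change the order.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

rng = np.random.default_rng(42)
n_q, n_k, d_in, d_out = 5, 8, 4, 3
queries = rng.normal(size=(n_q, d_in))
keys = rng.normal(size=(n_k, d_in))
W = rng.normal(size=(d_in, d_out))
a = rng.normal(size=2 * d_out)

# GAT scores on a complete bipartite graph: every query sees every key
E = np.array([[leaky_relu(a @ np.concatenate([q @ W, k @ W])) for k in keys]
              for q in queries])

# The key-only term a_2^T W h_j decides the ranking for every query
key_term = (keys @ W) @ a[d_out:]
print(E.argmax(axis=1))          # the same key index for all 5 queries
print(key_term.argmax())         # ... and it is this key
rankings = np.argsort(-E, axis=1)
print(np.all(rankings == np.argsort(-key_term)))   # True: identical full ranking
```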

[Figure: Static Attention (Brody, S. et al., 2022)]

Brody, S. et al. (2022) highlighted this issue with the illustration above, which explains static attention well. The example shows attention computed by GAT on a complete bipartite graph of 9 query and key nodes. The lines for all the query nodes have the same shape, with k8 receiving the highest attention from every query. Ideally, we would want dynamic attention, where the ranking of attention can vary depending on the query node, to capture more complex relationships (although static attention might still work better than dynamic attention on simpler graphs in practice).

Graph Attention Network v2

The biggest problem lies in the collapsible $\overrightharpoon{a}$ and $W$, which prevent the model from learning non-linear relationships. Brody, S. et al. (2022) addressed this issue by reordering the operations as follows:

$$
e_{ij} = \overrightharpoon{a}^T\text{LeakyReLU}(W[h_i \Vert h_j]) \\
\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \{i\} \cup N_i}\exp(e_{ik})}
$$

The revised approach first concatenates $h_i$ and $h_j$, applies $W$, and then applies the nonlinearity before multiplying by $\overrightharpoon{a}^T$. This change ensures that the weights are not collapsible and are capable of learning non-linear relationships. Brody, S. et al. (2022) confirmed that this modification results in dynamic attention.
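A tiny hand-built example shows why the reordering matters. The weights below are hand-picked for illustration, not learned: with 1-D features and a 2-unit hidden layer, they make the GATv2 score equal to $-0.8\,|h_i - h_j|$ (because $\text{LeakyReLU}(x) + \text{LeakyReLU}(-x) = 0.8|x|$ for slope 0.2), so each query prefers the key closest to itself, which is a query-dependent ranking that the original GAT cannot express. Features are row vectors, so $W$ has shape $(2d_{in}, d_{hidden})$.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gatv2_score(h_i, h_j, W, a):
    # e_ij = a^T LeakyReLU(W [h_i || h_j]), with the nonlinearity BEFORE a
    return a @ leaky_relu(np.concatenate([h_i, h_j]) @ W)

# Hand-picked weights: [h_i, h_j] @ W = [h_i - h_j, h_j - h_i], so e_ij = -0.8 * |h_i - h_j|
W = np.array([[1.0, -1.0],
              [-1.0, 1.0]])
a = np.array([-1.0, -1.0])

queries = np.array([[0.0], [1.0], [2.0]])
keys = np.array([[0.0], [1.0], [2.0]])
E = np.array([[gatv2_score(q, k, W, a) for k in keys] for q in queries])
print(E.argmax(axis=1))   # [0 1 2]: each query prefers a different key
```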

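class GraphAttentionV2(layers.Layer):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Shared linear projection W: (d_in, d_out)
        self.W = self.add_weight(shape=(d_in, d_out),
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Attention vector a: (d_out, 1), applied after the nonlinearity
        self.a = self.add_weight(shape=(d_out, 1),
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.leaky_relu = layers.LeakyReLU(0.2)

    def call(self, x, A):
        # x: (batch_size, n, d_in)
        # A: (batch_size, n, n) adjacency matrix that already includes self-loops (A + I)
        n = tf.shape(x)[1]
        x_i = tf.tile(tf.expand_dims(x, 2), [1, 1, n, 1])         # x_i[:, i, j] = h_i
        x_j = tf.transpose(x_i, [0, 2, 1, 3])                     # x_j[:, i, j] = h_j
        x_ij = tf.concat([x_i, x_j], -1)                          # (b, n, n, 2 * d_in)
        # Stacking W ties the halves: [h_i || h_j] @ [W; W] = h_i W + h_j W
        w = tf.concat([self.W, self.W], 0)                        # (2 * d_in, d_out)
        x_ij = self.leaky_relu(tf.matmul(x_ij, w))                # nonlinearity before a

        e = tf.squeeze(tf.matmul(x_ij, self.a), [-1])             # (b, n, n)

        # Masked softmax: non-neighbors get a large negative score, hence ~0 attention
        mask = tf.cast(A > 0, e.dtype)
        attn = tf.nn.softmax(e + (1.0 - mask) * -1e9, axis=-1)   # (b, n, n)
        return tf.matmul(attn, tf.matmul(x, self.W))              # (b, n, d_out)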

The above is the code implementation of single-head graph attention used in GATv2. As you can see, the only difference from the original GAT is the order of operations, stacking of W, and the subsequent change in the dimensions of a. The code for multi-head attention remains unchanged except for substituting the single-head attention.
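Stacking $W$ is one particular choice: it ties the two halves of the matrix that multiplies $[h_i \Vert h_j]$. The GATv2 equation applies $W$ to the concatenation, so in general $W$ can have separate query and key blocks. The NumPy sketch below (random values, illustrative names such as `W_l` and `W_r`) checks both identities.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3
h_i, h_j = rng.normal(size=d_in), rng.normal(size=d_in)
pair = np.concatenate([h_i, h_j])                  # [h_i || h_j], shape (2 * d_in,)

# Tied weights, as in GraphAttentionV2: [W; W] applied to the pair equals h_i W + h_j W
W = rng.normal(size=(d_in, d_out))
tied = pair @ np.concatenate([W, W], axis=0)
print(np.allclose(tied, h_i @ W + h_j @ W))        # True

# General weights: any (2 * d_in, d_out) matrix splits into a query block and a key block
W_general = rng.normal(size=(2 * d_in, d_out))
W_l, W_r = W_general[:d_in], W_general[d_in:]
untied = pair @ W_general
print(np.allclose(untied, h_i @ W_l + h_j @ W_r))  # True
```

Either way, the product splits into a per-node query part and a per-node key part, so an implementation can project every node once and broadcast-add the two parts into an $(n, n, d_{out})$ tensor before the LeakyReLU, instead of building the $(n, n, 2d_{in})$ concatenation.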

Conclusion

In this article, we replaced the normalization factor of GCN with attention to create GAT, which can work well even for non-homophilous graphs while remaining simple and scalable. We also discovered that the original GAT uses static attention and introduced GATv2, which achieves dynamic attention and can handle even more complex relationships. In the next article, we will explore an even more complex and expressive model for graphs.

Resources