Deep Learning/Graph Neural Network

Spatial Graph Convolution Network based on MPNN

언킴 2022. 6. 16. 01:04
반응형

Contents

     

    BackGround

    Convolution Graph Neural Network (ConvGNN)은 GCN이라고도 부르며, 이를 알기 위해서는 GNN이 무엇인지, 더 나아가 왜 기존 머신러닝을 사용하지 않고 Graph Mining을 사용하게 되었는지 짚고 넘어가야 한다. 기존의 머신러닝은 Euclidean data에서는 매우 우수한 성능을 발휘하고 있었으나, non-Euclidean data에는 적용하기 어렵다는 문제점이 존재했다. non-Euclidean data의 대표적인 예로는 NLP 구조와 Tree, Graph 구조가 있다. 

    좌: Tree 구조, 우: Graph 구조

     

    위의 그림 처럼 Tree, Graph와 같은 구조는 각 feature 간의 relation이 존재하기에 기존의 머신러닝 방법으로는 접근하는 것이 어렵다. 이와 같은 문제를 해결하기 위해 Recurrent Graph Neural Network (RecGNN), ConvGNN 등이 등장하였다. RecGNN은 기본적인 GNN 구조와 동일하며 바나흐 고정점 정리(Banach fixed-point Theorem)를 기반으로 hidden state를 반복 계산하는 방식이며, 최초의 Graph Neural Network 역시 RecGNN에 포함된다. 해당 논문에서는 노드의 hidden state를 다음과 같은 재귀적인 방법으로 업데이트 한다.

    \[ h^{(t)}_v = \sum_{u \in N(v)} f(x_v, x^e(v, u), x_u, h^{(t-1)}_u ) \]

    $f$는 Neural Network를 의미하고, $h6^{(0)}_v$는 랜덤하게 초기화된다. 이를 aggregate하여 모든 노드에 GNN을 적용하는 방식으로 진행된다. 위 공식이 수렴하기 위해서는 재귀 함수인 $f$가 축약 사상이 되어야한다는 조건이 있으며, 바나흐 고정점 정리를 사용하여 공식을 수렴한다. 자세한 내용은 여기를 참고하면 된다. 

     

    RecGNN의 경우 바나흐 고정점 정리를 통해 똑같은 수식을 반복적으로 계산하며 최적화하고 있기에 비효율적인 계산방식이고, 또한 반복 시행하는 횟수가 아주 큰 경우에는 고정점의 분포가 비슷한 값들을 가지게 되고, 각 노드의 정보가 구별되지 않기 때문에 최적화하기에 적합하지 않다. 이러한 문제점을 해결하기 위해 GCN이 제안되었다. 

     

     

    Convolution Graph Neural Network (ConvGNN, GCN)

    ConvGNN은 Yan Lecun 교수님의 AlexNet을 시작으로 각광받기 시작한 CNN을 기반으로 노드 또는 엣지의 정보를 계산하는 방법을 의미하며, ConvGNN은 Spatial Convolution Network와 Spectral Convolution Network로 분류할 수 있다. Spatial과 Spectral을 모두 다루기에는 내용이 너무 많아서 이번 글에서는 Spatial Convoltuion Network에 대해서만 다룬다.

     

    Spatial Convolution Network 

    각 노드들의 이웃(neighbor)들로 부터 feature information을 합계하여 hidden representation를 도출한다. feature를 합계한 후 non-linear function(ReLU etc..)을 통과하여 최종 결과를 도출한다. 만약 Layer를 여러 층으로 쌓는다면, 더 멀리 있는 이웃으로부터 메세지를 받아(message passing) 최종 hidden representation을 도출한다. 그렇다면 어떤 방식으로 GCN의 Hidden State를 update할 수 있을까?

    \[ \begin{equation} \begin{split} H^{(l+1)}_2 & = \sigma (H^{(l)}_1W^{(l)} + H^{(l)}_2W^{(l)} + H^{(l)}_3W^{(l)} + H^{(l)}_4W^{(l)} + b^{(l)}) \\ & \\ H^{(l+1)}_i & = \sigma ( \sum_{j \in N(i)} H^{(l)}_j W^{(l)} + b^{(l)}) \end{split} \end{equation} \]

    아래의 수식과 같이 노드 1의 hidden state를 update하기 위해서는 자기 자신의 정보(self-connection)와 연결된 노드들의 정보를 합산하여 사용한다. 

     

     

    Spatial Convolution Network에서 대표적인 모델 중 하나인 Message Passing Neural Network (MPNN)의 구조를 살펴보며 확인해보자.

     

    Message Passing Neural Network (MPNN)

    Message Passing

    Message Passing은 각각의 노드의 Representation Vector를 계속해서 update 해가는 과정을 의미한다. 이웃 노드와 엣지의 Representation Vector를 transforming하고 aggregating하면서 계속해서 노드를 Representation Vector를 update하는 것이다.

    Message Function

    Message Function은 노드에 대한 정보를 얻기 위해 정보들을 aggregate하는 역할을 한다. 노드에 대한 message function은 다음과 같은 형태를 갖는다. 

    \[ m^{t+1}_v = \sum_{w \in N(v)} M_t (h^t_v, h^t_w, e_{vw}) \]

    $h_n$는 노드 $n$에 대한 hidden state, $N(v)$는 노드 $v$의 이웃 집합을 의미하고 $w$는 노드 $v$의 이웃들을 의미한다. $e_{vw}$는 노드 $v$와 노드 $w$를 연결하는 엣지(edge)의 feature 벡터를 의미하고 $M_t$는 모든 정보를 aggregate 하는 message function을 의미한다. 이를 통해 노드 $v$의 다음 message인 $m^{t+1}_v$는 노드 $v$를 표현하는 것이다. 특정 노드에 대한 message를 얻는 과정은 다음 그림과 같다.

     

    cs224w

     

     

    Update function

    Update Function은 위 수식을 통해 구한 특정 노드의 message를 활용하여 노드의 다음 hidden state를 update하는 역할을 수행한다. 노드에 대한 update function의 수식은 다음과 같다.

    \[ h^{t+1}_v = U_t(m^{t+1}_v) \]

    Message Function을 통해 구한 $m^{t+1}_v$ 값과 노드의 현재 hidden state인 $h^t_v$를 고려하여 노드의 다음 hidden state를 update하는 것이다. 

     

     

    Readout

    Readout은 그래프 전체의 노드의 Latent Feature Vector를 평균하여 그래프 전체를 표현하는 하나의 벡터를 생성하는 단계이다. 노드의 Label 혹은 그래프의 Label을 도출할 수 있으며, 수식은 다음과 같다. 

    \[ \hat{y} = \text{R}({h^T_v | v \in \text{G}}) \]

    $h^T_v$는 message passing을 T번 반복하여 생성된 노드 $v$의 hidden state를 의미하고, 이를 Readout Function $\text{R}$의 입력으로 사용하여 원하는 노드의 Label을 도출하는 형태이다. 이같은 구조는 GNN뿐만 아니라 ConvGNN에서도 적용되는 방법이다.