GNN의 원리 이해해보기
Updated:
GNN의 원리를 이해해보자.
Graph??
- 많은 데이터가 가지고 있는 구조
- 대표적인 것들 : Social Graph, Moelcular Graph
- node와 edge가 연결되어 있는 구조
- social graph를 예로 들면 사람이 node고 관계가 edge가 될 수 있음
이러한 정보를 가진 graph를 neural network에서는 어떻게 사용할까요?
Adjaceny Matrix & Feature Matrix
아래와 같은 관계를 가지는 그래프 구조가 있다고 가정해봅시다.
- 1은 노란색이며 2와 연결되어있다.
- 2는 파란색이며 1,3,4와 연결되어있다.
- 3은 파란색이며 2와 연결되어있다.
- 4는 노란색이며 2와 연결되어있다.
이를 matrix로 표현할 때 node간의 관계(edge)를 표현하는 adjaceny matrix와 node에 대한 담긴 정보를 나타내는 feature matrix가 필요하다.
$ Adjaceny matrix A = \begin{bmatrix} 0 & 1 & 0 & 0 \ 1 & 0 & 1 & 1 \ 0 & 1 & 0 & 0 \ 0 & 1 & 0 & 0 \ \end{bmatrix} $
$ Feature matrix F = \begin{bmatrix} 0 & 1 & 1 \ 1 & 0 & 1 \ 1 & 0 & 1 \ 0 & 1 & 1 \ \end{bmatrix} $
- A의 각 성분 $A_{i,j}$ 는 i번째 node와 j번째 node가 연결되어 있는지를 보여준다.(연결되면 1, 아니면 0.) 따라서 node의 개수만큼 행과 열을 가지는 matrix이다.
- 이런 그래프는 edge에 방향이 있지 않아서 undirect graph라고 하며 A는 symmetric함
- edge에 방향이 ㅇㅆ는 그래프는 directed graph라고 하며 A가 symmetric하지 않음
- F의 각 성분 $F_{i,j}$ 는 i번째 node의 j번째 feature는 무엇인지 보여진다. 1번과 4번 node는 같은 색이고 2와 3번 node가 같은 색이므로 같은 feature를 가진다고 생각하여 행렬을 표현하였다. 각 node는 3개의 feature를 가지고 있다고 가정하여 열(col)은 3이고 node는 4개이므로 행(row)는 4개의 구조를 가지는 matrix다.
Graph Convolution Network(GCN)
그래프는 edge와 node 간의 연결을 이루는 구조이므로 edge로 연결된 node끼리의 정보 교환이 중요하다는 것을 알 수 있다. 이를 위해 graph structure를 고려하여 연산을 하는 방법이 필요한데 그 중 하나가 convolution을 이용한 방법이다. convolution이란 정배열된 node간의 정보를 spatial하게 얻어서 사용하는 방법이고 이러한 특성때문에 Image 계열에서 많이 사용된다. 대표적인 특징 3가지는 아래와 같다.
- Spatial feature extraction : 주변 node의 정보를 가져와 feature를 추출한다.
- Weight Sharing : weight를 각 영역별로 따로 사용하는게 아니라 공유하여 사용한다.
- Translation invariance : image의 경우 그림을 조금 shift하더라도 같은 결과를 가져오게 한다.
이러한 특성을 가져와서 그래프에서도 주변 node의 정보를 spatial하게 얻어 정보를 업데이트를 하는 방법을 고안한 것이 graph convolution network(gnn)이다. 그리고 실제로는 자기자신의 정보도 중요하므로 자신의 정보도 함께 사용하서 업데이트함을 기억하자.(A의 대각성분이 1이 된다.) 실제 연산은 AFW matrix 연산으로 이루어진다.(A:Adjacency matrix, F:Feature matrix, W:Weight matrix)
아까 전에 wegiht matrix는 sharing한다고 말하였다. 하지만 F와 W만을 이용해서는 이런 spatial한 feature를 뽑지 못하여 A matrix를 사용한다. Weight sharing을 사용해 업데이트 된 matrix와 A matrix를 연산하게 되면 각 node의 feature만을 이용하여 다음 업데이트를 하게 된다.
추가로 A matrix와 Weight matrix를 여러 번 반복해서 사용하게 되면 직접 연결된 node의 정보뿐 아니라 건너 연결된 node의 정보까지 이용할 수 있다. 따라서 이런 식으로 여러 번 반복하면 convolution처럼 receptive field가 커지는 효과를 얻을 수 있다.
Graph pooling & message passing
GNN에서 정보를 업데이트하기 위한 방법 중 pooling과 message passing에 대해 알아봅시다.
Graph pooling
CNN과 비슷하게, 복잡한 구조의 graph structure를 단순하게 만드는 데 사용할 수 있다. max 값을 추출하거나 average 값을 추출할 수 있다. 다른 한편으로는 edge나 node의 정보를 서로 전달하고자 할 때도 사용할 수 있다. Social graph를 예로 들어 보면, 각각의 사람들에 대한 정보가 node에 담겨 있고 edge로 node가 연결되어 있다고 할 때, 사람들의 정보를 보고 관계가 좋은지 나쁜지 binary classification(edge prediction)을 하고 싶다고 해보겠다. 이때 pooling을 이용하여 다음의 단계로 prediction을 수행할 수 있다.
- Edge별로 연결된 node의 정보(features)를 모은다.
- 정보(features)를 합쳐서 edge로 보낸다. (pooling)
- Parameter를 이용해서 classification을 수행한다.
Message passing
Message passing은 node 혹은 edge 간의 연관성을 고려하면서 feature를 업데이트하기 위한 방법이다. 예를 들어 node를 주변 node 정보를 이용해서 업데이트하고 싶을 때 message passing은 다음과 같이 이루어진다.
- Edge로 연결되어 있는 node의 정보(features, messages)를 모은다.
- 모든 정보를 aggregate function (sum, average 등)을 이용하여 합친다.
- Update function(parameter)을 이용해서 새로운 정보로 업데이트한다.
Message passing을 여러 번 반복하여 receptive field를 넓힐 수 있고 더 좋은 representation을 얻을 수 있다.
Leave a comment