Paper review/Graph Mining

Learning Fair Graph Representations via Automated Data Augmentations (ICLR'23)

언킴 2023. 3. 23. 02:14
반응형

Contents

     

     기존 Data Augmentation 방식은 Heuristic하게 적용하기 때문에 도메인에 따라 성능이 달라진다는 문제점이 존재한다. 본 연구에서는 Fairness-aware한 방식을 통해 Data Augmentation을 수행하는 Graphair 기법을 제안하였다. 

     

    Introduction

    Graph Neural Network (GNN) 기법은 Knowledge Graph, Social Media, Molecular Prediction 등과 같은 다양한 분야에서 우수한 성능을 보이고 있다. 그러나, 대부분의 GNN 기법은 인종, 성별 등과 같은 민감한 정보에 따라 다른 예측값을 도출하는 문제가 존재한다. 이와 같은 문제를 해결하고자 Node Feature Masking, Edge Perturbation 등과 같은 Heuristic Data Augmentation 방법을 주로 사용하는데, 이와 같은 방법은 도메인에 따라 서로 다른 결과가 도출된다는 단점이 존재한다.

     

    따라서 본 연구에서는 Heuristic Data Augmentation 방식이 아닌 Automated Data Augmentation 방식인 Graphair 기법을 제안하고자 한다. Graphair의 핵심은 Fairness와 Informativeness를 같이 향상시키기 위해 Adversary Model과 Contrastive Learning 기법을 사용한다. 

     

     

    Background and Related Work

    Fair Graph Representation Learning

    본 연구에서는 그래프 $\mathcal{G} = \{ A, X, S \} $가 주어지면, $A$는 Adjacency Matrix를 의미하고, $X = [x_1, \cdots, x_n ]^T \in \mathcal{R}^{n \times d}$는 Node Feature Matrix를 의미한다. 각 $x_i \in \mathbb{R}^d$는 d 차원의 Feature Vector를 의미하고, $S \in \{0, 1\}^n$는 Sensitive Attributes (Gender or Race)를 의미한다. 

     

    연구의 목적은 Fiar Graph Representation Model $f : (A, X) \rightarrow H \in \mathcal{R}^{n \times d^{\prime}} $을 학습하는 것이며, 학습된 Representation $H = f(A, X)$는 Classification Model $\theta : H \rightarrow \hat{Y} \in \{0, 1 \}^n$의 입력으로 사용된다. 이상적인 Fairness Model $f$은 아래의 기준을 만족하는 결과를 도출하여야 한다.

    \[ \mathbb{P}(\hat{Y}_i | S_i = 0) = \mathbb{P} ( \hat{Y}_i | S_i = 1 ) , \quad i = 1, \cdots, n \]

     

     

    Fairness via Automated Data Augmentations

    본 연구에서는 새로운 그래프 $\mathcal{G}^{\prime} = \{ A^{\prime}, X^{\prime}, S \}$를 생성하기 위한 Automated Data Augmentation Model $g$를 아래와 같이 정의한다.

    \[ T_A, T_X = g(A, X), \quad A^{\prime} = T_A (A), \quad X^{\prime} = T_X(X) \]

    $T_A$는 Edge Perturbation Transformation이며, 이는 Edge를 생성하거나, 제거해 새로운 Adjacency Matrix $A^{\prime}$을 생성한다. $T_X$는 Node Feature Masking Transformation이며, X의 어떤 값을 0으로 변환하여 새로운 Node Feature Matrix $X^{\prime}$을 생선한다. $g_{\text{enc}} : (A, X) \rightarrow Z \in \mathbb{R}^{n \times d_r}$는 먼저 $d_r$ 차원의 Embedding $Z$를 추출하기 위해 사용되고, 그 다음으로 Edge Perturbation, Node feature Masking은 아래와 같은 수식을 통해 계산된다.

     

    먼저 Edge Perturbation은 Embedding $Z$가 주어졌을 때, Multi-Layer Perceptron (MLP) 모델 $MLP_A$를 먼저 계산하여, $Z_A \in \mathbb{R}^{n \times d_{r^{\prime}} }$를 도출한다. 그 다음, Inner-Product Decoder를 통해 Edge Probability Matrix $\tilde{A^{\prime}} \in \mathbb{R}^{n \times n}$을 도출한다. $\tilde{A^{\prime}}$는 각 노드 $i$와 $j$가 연결되어 있을 확률을 나타낸다. 최종적으로, 출력된 Probability Matrix $\tilde{A^{\prime}}$의 확률을 기반으로 Bernoulli Distribution을 통해 최종 Adjacency Matrix $A^{\prime}$가 계산된다.

    \[ Z_A = MLP_A (Z), \quad \tilde{A^{\prime}} = \sigma (Z_A Z^T_A ), \quad A^{\prime}_{ij} \sim \text{Bernoulli} (\tilde{A^{\prime}}_{ij}) \]

    \[\ \text{for} \ i, j = 1, \cdots, n \]

     

    Node Feature Masking은 Embedding $Z$가 주어졌을 때, $MLP_B$를 먼저 계산하여 Mask Probability Matrix $\tilde{M} \in \mathbb{R}^{n \times d}$를 도출한다. 그런 다음, Edge Perturbation과 마찬가지로 Bernoulli Distribution을 통해 해당 Node의 Masking 여부를 나타내는 Masking Matrix $M_{ij}$를 도출하고, Hadamard Product를 통해 새로운 Node Feature Matrix $X^{\prime}$을 도출한다.

    \[ Z_X = MLP_X (Z), \quad \tilde{M} = \sigma (Z_X), \quad M_{ij} \sim \text{Bernoulli} (\tilde{M}_{ij}) \]

    \[ \ \text{for} \ i, j = 1, \cdots, n, \quad X^{\prime} = M \odot X \]

     

    그러나, Bernoulli Distribution과 Mask Matrix $M$은 Discrete한 값을 가지고 있기 때문에 미분이 불가능(Non-differentiable)하다. Augmetnation Model $g$를 End-to-End로 학습하기 위해 일반적으로 사용하는 Gumbel-Softmax Reparameterization Trick을 사용한다. 이는 Discrete한 값을 Continuous 하도록 만들어주는 Trick이며, 아래와 같은 수식으로 계산된다.

    \[ \hat{P} = \frac{1}{1 + \text{exp} ( - ( \log \tilde{P} + G ) / \tau ) }, \quad G \sim \text{Gumbel}(0, 1) \]

    $G \sim \text{Gumbel}(0, 1)$은 Standard Gumbel Distribution에서 샘플링된 Random Variable을 의미한다. 위와 같은 Trick을 사용하면, 샘플링된 $G$에 따라서 $\hat{P}$의 값이 달라지며, 이를 통해 Continuous 하도록 만들어준다. Forward Propagation을 수행할 때는 Discrete Value $P = \lfloor \hat{P} + \frac{1}{2} \rfloor$를 통해 값을 계산하고, Backward Propagation을 수행할 때는 $\nabla_{\varphi} \hat{P} \approx \nabla_{\varphi} P $를 통해 근사한다. 해당 방식은 선행 연구를 바탕으로 진행한 방식이다.

     

    Adversarial Training

    본 연구에서의 목적은 편향을 줄이기 위해 Fair Agumentation을 생성하는 것이다. 따라서, Augmentation Model $g$는 "Fairness"를 만족하여야 한다. 다시 말하자면, Prediction Bias를 발생시키는 Edge, Node Feature에 낮은 확률을 부여하여야 한다. 그러나, 어떤 Edge, Node Feature가 Prediction Bias를 발생시키는지에 대한 정답 라벨인 Ground Truth가 없기 때문에 Supervised Learning을 통해 학습하는 것이 불가능하다. 따라서, 본 연구에서는 해당 이슈를 다루기 위해 Adversarial Learning 방법을 사용한다. Adversarial Model $k : (A^{\prime}, X^{\prime} ) \rightarrow \hat{S} \in [0, 1]^n$는 $A^{\prime}$과 $X^{\prime}$를 통해 Sensitive Attribute $S$를 예측하는 모델이다. $g$는 둘 간의 차이가 없도록, $k$는 0과 1을 정확하게 예측하도록 하는 방식과 동일하다. 각각 Generator와 Discriminator의 역할을 한다고 볼 수 있다.

    \[ \underset{g}{\text{min}} \ \underset{k}{\text{max}} \ L_{\text{adv}} = \underset{g}{\text{min}} \ \underset{k}{\text{max}} \frac{1}{n} \sum^n_{i=1} [ S_i \log \hat{S}_i + (1 - S_i) \log (1 - \hat{S}_i ) ] \]

    이때 본 연구에서 사용하는 Adversary Model과 Augmentation Encoder $g$는 Two-Layer GCN Model을 사용하였다.

     

    Contrastive Learning

    마지막으로 Contrastive Learning을 사용하였다. 이번 단계는 Node Representation $ H = f(A, X)$과  Augmetation된 $H^{\prime} = f ( A^{\prime}, X^{\prime} ) $을 바탕으로 Contrastive Learning을 수행한다. $(h_i, h^{\prime}_i)$는 Positive Pair로 설정하고, $(h_i, h_j)$ 그리고 $(h_i, h^{\prime}_j)$는 Negative Pair로 설정하였다.

    \[ l(h_i, h^{\prime}_i) = - \log \frac{ \exp (\text{sim} ( h_i, h^{\prime}_i ) / \tau ) }{ \sum^n_{j=1} \exp (\text{sim} (h_i, h^{\prime}_j ) / \tau ) + \sum^n_{j=1} \mathbb{1}_{[j \neq i ]} \exp ( \text{sim} (h_i, h_j) / \tau ) } \]

    최종 Contrastive Learning의 Loss Function은 $(h_i, h^{\prime}_i)$와 $ (h^{\prime}_i, h_i)$의 Positive Pair에 대한 Loss로 계산되기 때문에 두 값의 평균을 사용한다.

    \[ L_{\text{con}} = \frac{1}{2n} \sum^n_{i=1} [ l(h_i, h^{\prime}_i ) + l (h^{\prime}_i, h_i ) ] \]

    본 연구에서는 Augmentation Model $g$가 입력 그래프에서 너무 벗어나는 값을 가지는 것을 방지하기 위해 Reconstruction-based Regularization Term을 설정해주었다.

    \[ \begin{equation} \begin{split} L_{\text{reconst}} & = L_{\text{BCE}} (A, \tilde{A^{\prime}}) + \lambda L_{\text{MSE}} (X, X^{\prime}) \\ \\ & = - \sum^n_{i=1} \sum^n_{j = 1} \left [ A_{ij} \log (\tilde{A^{\prime}}_{ij} ) + ( 1 - A_{ij} ) \log (1 - \tilde{A^{\prime}}_{ij} ) \right ] + || X - X^{\prime}||^2_F \end{split} \end{equation} \]

    위 전체 과정을 종합하여 최종 Loss Function은 아래와 같이 정의할 수 있다.

    \[ \underset{f, g}{\text{min}} \ \underset{k}{\text{max}} \ L = \underset{f, g}{\text{min}} \ \underset{k}{\text{max}} \ \alpha L_{\text{adv}} + \beta L_{\text{con}} + \gamma L_{\text{reconst}} \]

     

     

    Experiments

    본 연구에서는 Fiarness와 Informativeness에 대한 성능을 검증하기 위해 Accuracy와 Demographic Parity $\Delta_{DP}$와 Equal Opportunity $\Delta_{EO}$를 평가지표로 사용하였으며, $\Delta_{DP}$와 $\Delta_{EO}$는 낮을수록 Fairness하다는 것을 의미한다. 

    \[ \Delta_{DP} = | \mathbb{P} (\hat{Y} = 1 | S = 0 ) - \mathbb{P}(\hat{Y} = 1 | S = 1)| \]

    \[ \Delta_{EO} = | \mathbb{P}(\hat{Y} = 1 | S = 0, Y = 1) - \mathbb{P} (\hat{Y} = 1 | S = 1, Y = 1 )| \]

    Baseline Model은 Fairwalk, Fairwalk + Node Feature, GRACE, GCA, NIFTY, FairDrop, FairAug를 사용하고, NBA, Pokec_z, Pokec_n 데이터를 사용하였다. 실험은 1) Experimental Results, 2) Ablation Studies, 3) Analysis of Fair View 총 3가지로 구성하였다. 먼저, Experimental Results에서는 Fairness와 Accuracy Performance는 Trade-off 관계에 있다는 것을 시각화하고, 본 연구에서 제안하는 Graphair 기법과 다른 Graph 기반 기법 간의 성능 및 Fairness를 비교분석한 실험이다 (Figure 2, Table 1). Ablation Studies는 본 연구에서 제안하는 Graphair 기법 내에 Feature Mask (FM), Edge Perturbation (EP)를 제거했을 때 모델의 성능을 비교분석한 실험이다 (Table 2). Analysis of Fair View는 Node Sensitive Homophily를 보여주고 Node 간의 Spearman Correlation을 보여준다. 기존 선행 연구에 따르면 Node Sensitive Homophily가 높을수록 Prediction Bias가 발생한다고 한다. 본 연구의 기법은 다른 기법에 비해 Node Sensitive Homophily가 낮은 부분에 Density가 높기 때문에 상대적으로 Prediction Bias가 낮다고 볼 수 있다.