[논문] FeatMatch: Feature-Based Augmentation for Semi-Supervised Learning 설명, 정리
집가고시퍼
2021. 7. 17. 14:31
Introduction
기존에 성공적이였던 Semi-Supervised Learning(SSL) 접근법은 image-based augmentation이였다. 하지만 이는 image space에서의 transformation만 가능하고, 다양한 transformation에서의 dataset의 다른 instance의 지식을 leverage해주지는 못한다. 이를 위해 feature-based refinement와 augmentation을 제안한다. dataset에 있는 다른 image feature에서 추출한 prototype의 작은 set을 이용해 image feature에서 soft-attention을 통해 refine과 augment를 하는 모듈을 사용한다. Memory bank와 k-means clustering을 통해 prototype을 추출하고, 연산을 최소화한다.
정리하자면, 1. Dataset에 있는 모든 class들에 대해 prototype의 small set을 leverage함으로써 abstract feature space에서의 input image feature을 transform하는 refinement & augmentation module을 만들어 준다. 2. Memory bank를 사용해 prototype을 효율적으로 뽑아낸다.
Related Works
1. Consistency Regularization Method : 대부분의 SSL이 이 method 사용. Deep model의 prediction은 같은 data를 semantic-preserving transformation해도 일정해야 한다. feature encoder를 \( f_{x} = Enc(x) \) , classifier을 \( p_{x} = Clf({f_{x}}) \)라 할 때, input image의 pseudo-label은 \(p_{x} = Clf(Enc(x)) \)가 된다. Data augmentation module을 \(AugD(\cdot )\)이라 할 때, x를 augment한 것은 \(\hat{x} = AugD(x)\)이다. 여기서 Consistency loss는 주로 KL-Divergence Loss를 사용하는데, \(\hat{x}\)가 일정한 prediction을 보이도록 한다. \(L_{con} = H(p,Clf(Enc(\hat{x})))\)가 된다.
2. Image-Based Augmentation : Train 과정에서 Data diversity를 늘리고, overfitting을 방지하기 위해 진행.
Feature-Based Augmentation and Consistency
기존의 image-based augmentation은 두가지 한계가 있다. 1. image space에서 작동해서 transformation에 한계가 있다. 2. 하나의 instance안에서 작동하기 때문에 다른 instance의 지식을 참고하여 data transform을 할 수 없다.
다른 class의 지식을 효과적으로 leverage하기 위해서, 각 class에 대해 feature space에서 prototype을 만든다. Image feature는 모든 class의 prototype에서 propagate된 정보를 사용해 refine되고 augment된다. 이를 통해 더 나은 pseudo-label을 만든다. feature refinement와 augmentation은 prototype에 대한 lightweight attention으로 학습되고, 다른 objective(classification loss 등)으로 optimize된다.
Prototype Selection
각 class의 prototype은 매 epoch K-means clustering을 사용해 \(p_{k}\) cluster mean을 추출한다. 하지만 SSL에서는 대부분의 이미지가 unlabeled고, 모든 label이 사용 가능하더라도 연산이 많이 드는 문제가 있다. 이를 해결하기 위해, 각 training loop에서 얻은 feature \(f_{xi}\)와 pseudo label \(\hat{y}\)를 이용한다. recording loop에서, pseudo label과 feature은 computation graph에서 detach되어 memory bank로 들어간다. Feature refinement and augmentation loop는 training loop에서 새로 얻은 정보로 prototype을 update한다.
Fig 2
Learned Feature Augmentation
Learned feature refinement와 augmentation module은 선택된 prototype의 set에 soft-attention을 사용해 만든다. 모듈은 3개의 FC layer로 구성되어 있고, 다른 objective와 optimize되어 있어 classification을 돕기 위해 합리적인 feature-based augmentation을 진행해준다. 각 image feature는 prototype feature을 attention weight(dot product similarity)를 통해 attend한다. Prototype feature는 attention weight에 의해 weight sum되고, residual connect을 통해 input image feature로 fed back되어 feature augmentation과 refinement를 진행한다. Input 이미지의 feature을 추출한 것을 \(f_{x}\), i번째 prototype feature을 \(f_{p,i}\)라 할 때, \(\Phi _{e}\)를 통해 embedding space로 project한다. \(e_{x} = \Phi_{e}(f_{x})\), \(e_{p,i} = \Phi_{e}(f_{p,i})\)이다. 그 뒤, \(e_{x}\)와 \(e_{p,i}\)의 attention weight \(w _{i}\)를 구한다.
가 된다. Softmax는 모든 prototype에 대한 score를 normalize해준다. 이 정보는 이미지에 전달되어 attention weight을 사용한 prototype feature의 합으로 표현된다. $$ f_{a} = relu(\Phi_{a}([e_{x}, \sum_{i}w_{i}e_{p,i}]))\ \ \ \ (2) $$ Fig 3 \(\Phi_{a}\)는 learnable function이고, \([\cdot,\cdot]\)은 feature dimension에서의 concatenation이다.최종적으로, \(f_{x}\)는 residual connection을 통해 \(g_{x} = relu(f_{x} + \Phi_{r}(f_{a}))\)로 refine된다. 간단히 쓰기 위해 \(g_{x} = AugF(f_{x})\)로 쓴다.
Consistency Regularization
Augmentation후엔 unaugmented feature \(f_{x}\)와 augmented feature \(g_{x}\)의 consistency loss를 적용한다. 하지만, \(p=Clf(f)\)라는 classifier를 사용할 때, \(p_{g} = Clf(g_{x})\) 혹은 \(p_{f} = Clf(f_{x})\) 중 어떤 것을 pseudo-label로 사용해야 할지 문제가 생긴다. 하지만 AugF는 input feature를 더 나은 representation으로 refine할 수 있기 때문에, 더 나은 pseudo-label을 만든다. 따라서 \(p_{g} = Clf(g_{x})\)라는 pseudo-label을 사용한다. Feature based consistency loss는 \(L_{con} = H(p_{g},Clf(f_{x})\)로 계산된다. 약하게 augment된 이미지 x와 강하게 augment된 이미지 \(\hat{x}\)를 만든다. x를 이용해 feature-based augmentation과 refinement를 거쳐 더 나은 pseudo-label \(p_{g} = Clf(AugF(Enc(x)))\)을 얻는다. 그리고 두 consistency loss를 적용한다. $$ L_{con-g} = H(p_{g}, Clf(AugF(Enc(\hat{x})))\ \ \ \ (4)$$ $$ L_{con-f} = H(p_{g}, Clf(Enc(\hat{x}))\ \ \ \ (5) $$ 결국 image-based augmentation과 feature-based augmentation 모두에서 L을 사용한다.
Total loss
Consistency regularaization loss \(L_{con_g}\)와 \(L_{con_f}\)는 unlabeled data에 적용된다. 하지만 labeled image(x,y)에 대해서는 regular classification loss를 적용한다. $$ L_{clf} = H(y,Clf(AugF(Enc(x))))\ \ \ \ (6) $$ 결국 total loss는 $$ L_{clf} + \lambda_{g}L_{con-g} + \lambda_{f}L_{con-f} $$ 가 되고, \( \lambda_{g}, \lambda_{f}\)는 \(L_{con-g}\)와 \(L_{con-f}\)의 weight다.