IT 정리용 블로그!

[논문] Contrastive Adaptation Network for Unsupervised Domain Adaptation 설명, 정리 본문

Computer Vision

[논문] Contrastive Adaptation Network for Unsupervised Domain Adaptation 설명, 정리

집가고시퍼 2021. 7. 16. 22:06
  • Contrastive Domain Discrepancy(CDD) : intra-class discrepancy는 줄이고 inter-class discrepancy는 늘린다. \(P(\phi(X_{s})|Y_{s})\)와 \(Q(\phi(X_{t})|Y_{t})\)의 차이를 측정한다. \(D_{H}(P,Q)\). 평균값의 차를 이용. \(\hat{D}^{c_{1},c_{2}}(\hat{y}_{1}^{t},\hat{y}_{2}^{t},...,\hat{y}_{n_{t}}^{t},\phi) = e_{1}+e_{2}-2e_{3}\). \(c_{1}=c_{2}\)일 때는 intra-class discrepancy 측정. Ø는 feature representation. 첫 feature  \(c_{1} \neq c_{2}\)일 때는 inter-class discrepancy 측정. \(e_{1},e_{2},e_{c}\) 식에서 사용하는 \(hat{y}_{i}^{t},hat{y}_{j}^{t}\)꼴은 예측한 결과 사용. 
    ※ \(e_{1},e_{2},e_{3}\)에 사용되는 \(k()\)꼴은 무엇인가..? kernel mean embedding.
    결국 최종적으로 $$ \hat{D}^{cdd} = \frac{1} {M} \sum_{c=1}^{M}\hat{D}^{cc}(\hat{y}_{1:n_{t}}^{t},\phi)(intra) - \frac{1}{M(M-1)}\sum_{c=1}^{M}\sum_{c'=1,c'\neq c}^{M}\hat{D}^{cc'}(\hat{y}_{1:n_{t}}^{t},\phi)(inter)    (5) $$꼴. 결국 (5) 식에서 D^cc는 target domain을 예측한 것들 중 같은 class에 있는 경우. Intra-class discrepancy. \(D^{cc'}\)는 target domain을 예측한 것들 중 다른 class에 있는 경우. Inter-class discrepancy.
    CDD는 MMD를 기반으로 만들어짐. \(\hat{y}_{i}^{t}\)가 noisy하더라도 어느 정도 robust한 결과를 얻게 됨. 

    하지만, 문제점 1 : Unsupervised Domain Adaptation(UDA)에서는 라벨을 모른다. 문제점 2 : 미니배치가 하나의 domain의 샘플만 가질 수 있음. Intra-class discrepancy 측정 불가.

  • Contrastive Adaptation Network(CAN) 제안. Train할 때 labeled source data의 cross-entropy loss만 줄이는 것이 아니라, clustering을 통해 underlying label hypothesis를 예측한다. 그리고 CDD metric에 따라 feature representation을 adapt한다. Clustering 이후에는 cluster center에서 멀리 벗어난 애매한 target data나 center 주위에 target sample이 거의 없는 모호한 클래스들은 CDD를 0으로 만든다.
    이를 위해, train과정에는 sample의 수를 늘리고, mini batch에서 source와 target domain 모두에서 class-aware sampling을 진행한다. 각 class에서 두 domain 모두에서 뽑는다.

    \(S(Source Domain Samples) = \left \{(x_{1}^{s},y_{1}^{s}),…,(x_{N_{s}}^{s},y_{N_{s}}^{s})\right \}\)
    \(T(Target Domain Samples) = \left \{x_{1}^{t},x_{2}^{t}…,x_{N_{t}^{t}\right \}\)
    \(y^{s} \in \left \{0,1,…,M-1\right \}\)(Source는 M classes), \(y^{t} ∈ \left \{0,1,…,M-1 \right \}\). 
    ※ source와 target의 class수가 같아야 한다..? class의 종류도 같아야 한다..?


  • Maximum Mean Discrepancy(MMD) : \(x_{i}^{s}, x_{i}^{t}\)가 sample됐다고 할 때, 이들은 \(P(X_{s}), Q(X_{t})\)에서 샘플된다. 그렇다면 MMD는 \(P = Q\)라는 가설을 받아들일지 여부를 판단한다. 두 분포의 Distribution을 \(D_{H}(P,Q)    (1)\)로 구한다. Layer l에서 봤을 때 MMD는 \(\hat{D}_{l}^{mmd}    (2)\).


  • 최종적으로는 deep cnn에서 여러 fc layer에서의 CDD(unlabeled, target)를 최소화하려 한다(6). 또한, labeled source data를 train하려 한다 cross entropy loss(7). 결국 둘을 모두 사용한 total l이 나온다(8). CAN에서 CDD loss를 줄이는 얘기를 한다.
    \(y_{1:N_{t}}^{t}\)를 최적화 하기 위해, 현재의 feature representation을 고정하고(ex.파라미터 고정), clustering을 이용해 target label을 update한다. 그리고 고정된 \(\hat{y}^{t}\)를 이용해 CDD를 최소화하며, back propagation으로 파라미터를 업데이트한다.
    \(\phi_{1}(\cdot)\)은 sample을 represent할 첫 task-specific layer.
    우선 spherical K-means를 사용해 target sample의 cluster을 만든다. 클러스터의 수는 class의 수인 M이다. Target cluster \(O^{tc}\)는 Source cluster \(O^{sc}\)로 initialize한다. Clustering process는  \(O^{tc}\)와\(phi_{1}(x_{i}^{t}\)의 거리를 최소화하는 \(c\)를 \(\hat{y}_{i}^{t}\)로 넣는다. 그리고 새 \(hat{y}_{i}^{t}\)를 이용해 \(O^{tc}\)를 업데이트한다. Clustering이 끝난 뒤, cluster center에서 멀리 떨어진 data는 버린다. \(D_{0}\) 이하의 거리에 있는 subset T만 사용함. \(O^{tc}\) 주변에 있는 target data가 일정 수 이상 확보되면, 해당 클래스와 데이터를 사용한다. 만약 대표하지 못하는 \(O^{tc}\)나 target data들은 버린다.
    ※ 그러면 Target data에서 누락되는 class가 있지 않나? 그리고 안그래도 데이터가 적은데 데이터를 더 버린다..?


  • 그런데 train하다 보면 train 초기에는 몇 class들이 누락되지만, train이 진행됨에 따라 class들이 포함되게 된다고 합니다. CDD에 의해, intra-class domain discrepancy는 작아지고, inter-class domain discrepancy는 커진다고 한다.
    Te번째 loop에서, selected subset of classes Cte.

  • Class aware sampling(CAS) : CDD에서, minibatch 안에서 한 domain의 sample만 가질 수 있음(source나 target). 이 때문에 랜덤하게 class의 subset인 \(C_{T_{e}}^{'}\)를 \(C_{T_{e}}\)에서 고른다. 그리고 \(C_{T_{e}}^{'}\)의 각 class에서 source data와 target data를 sample한다.
Comments