본문 바로가기

Research

[대충 정리] Loss for Active Learning

동영상을 보면 저자가 잘 설명하니.. 동영상 보는 걸 강력 추천한다...

www.youtube.com/embed/9oWC8S2gpyk?start=1

 

 

간단하게만 정리.......실제 논문 설명이 아닌.. 논문, 동영상 보고 난 후 생각을 정리함.

*주의: 논문 내용과 좀 다른 견해가 발생할 수 있음. 

 

1. 컨셉

Active learning은 결국 prediction 후 각 클래스에 대한 확률 값을 confidence로 활용한다. 하지만 이 confidence는 정확할까?? 진짜 정확도를 예측하기 위해 학습 시 발생하는 Loss를 예측하는 레이어를 추가하겠다는 것.

결국 Multi-task learning의 갈래이기도 하다.

 

2. 구조

학습 네트워크들은 보통 반복적인 블록으로 연결되어있다. 그 중간중간의 feature map을 빼와서 정보들을 결합하여 (예측-정답) 사이에서 발생하는 Loss를 또 예측하겠다는 것. 이때 이 Loss를 예측하기 위한 레이어 구성은, 학습에 방해되지 않도록, 즉 기존 레이어들의 학습 영향을 크게 주지 않도록 global average pooling을 우선적으로 사용하고 이것들을 fully connected layer-Relu로 처리한 후 concat 한다. 다시 fully connected 해서 1개의 scalar value인 Loss prediction 결과를 낳는다.

 

3. Loss 계산

 mean square error (MSE) 같은걸 간단하게 생각해볼 수 있는데.. 원래 일반적인 학습방식에서 prediction과 label과의 Loss를 계산하고 그 Loss를 또 동시에 예측하는 loss2를 가정해보자.

 학습이 진행되면서 이 본래 Loss는 점점 수렴하면서 아주 작은 값이 되어갈 테고.. 이 Loss값을 예측하는 loss2는 사실 이걸 줄여가는데만 집중할 뿐.. 어느 이미지에 대한 예측 Loss가 더 높은지 우위 관계에는 관심 없게 된다... (내가 틀릴 수도 있으나 아마 예측하는 loss 값이 고만고만해질 듯)

 Active learning의 컨셉을 생각해보면 많은 unlabeled dataset으로부터 Psuedo Label을 생성하고 그중 자신 없는 데이터를 골라내겠다는 건데.. 학습이 진행되면서 점점 작아지는 loss를 예측하는데만 관심 있고 어떤 데이터에 더 자신 있는지에 대해 판별하는 능력이 증가하지 않게 된다. 그래서 결국 MSE를 이용한 loss prediction은 성능이 다른 method들에 비해 좋지 않았다는 결론을 냈다. 그래서 저자는 ranking loss를 차용했고, 논문에서 수식(2)이 있는데 이걸 정리해보면.. 결국 서로 다른 데이터 A, B에 대해 각각 예측(inference) 값과 라벨을 가지고 Loss를 계산하면, l_i와 l_j가 도출되는데 이것의 우위를 계산하고, 예측 라벨 두 개간의 실제 우위와 다른지를 확인하는 것.

그러면 모델은 학습을 진행하면서 이 우위 차이가 틀리지 않도록 학습하게 된다. 즉 loss값 자체가 점점 줄어드는 걸 따라가는 scale에 초점을 두는 loss2가 아닌, 어떤 게 더 신뢰도가 있는지를 구분하는데 초점을 두게 된다.

 

 

핵심 요약

(1) Active learning에서 자신이 없는(low confidence or uncertain) 데이터를 판별하기 위해 softmax 통과된 proability를 사용하는데, 이건 결국 학습 시 Loss와 연관 지을 수 있고, 이 Loss를 예측하는 Layer를 추가해서 학습단계부터 함께 학습하겠다. (Multi-task learning)

 

(2) Loss를 예측하기 위한 Prediction Layer와 여기서 {예측된 값 <->Loss}에서 발생하는 Loss2는 단순 distance를 계산해서는 잘 학습이 안 된다. 그냥 줄어드는 loss를 따라가는 느낌이고 작은 scale 속에 세밀한 차이를 찾지 못한다. 아무래도 scalar값이다보니.. 전체 image를 아우르기 힘들 듯. 결국 실제 데이터간의 신뢰도 우위를 정하는데 도움이 안된다. 따라서 랭킹 loss를 도입함으로써, mini batch 내에 있는 서로 다른 데이터에서 발생하는 loss들을 비교하는 방식으로 loss2를 학습하겠다.