MIL-NCE Loss의 발견
UniVL(A Unified Video and Language Pre-Training Model) 모델 코드를 들여다보면, 일반적으로 많이 알려진 Cross Entropy, LabelSmoother, Focal 로스와 같이 일반적으로 사용되는 Loss와는 다른 MIL-NCE Loss가 사용되고 있음을 알 수 있다.
UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation
With the recent success of the pre-training technique for NLP and image-linguistic tasks, some video-linguistic pre-training works are gradually developed to improve video-text related downstream tasks. However, most of the existing multimodal models are p
arxiv.org
MIL-NCE Loss의 톺아보기
어떤 Loss인지를 확인하기 위해, 먼저 NCE가 어떤 의미인지 찾아보았다. NCE란 'Noise Constrastive Estimation'의 약자로, Word2Vec에서 사용되었던 방법이다. 일반적으로, NLP에서 주어진 단어 다음으로 올 단어를 예측하는 것은 단어 후보군에 대한 multi class classification을 진행하는 것과 동일하다. 하지만, 이때 다음단어로 올 후보군이 굉장히 많기 때문에 일반적인 multi class classification처럼 학습하기엔 무리가 있다. 이러한 상황에서, NCE는 multi class classfication 태스크를 binary classification 태스크로 변환할 수 있도록 해준다. 이는 기존의 구조에서 거대한 output을 결과로 내기 위해 필요했던 Softmax 연산을 없애고, 로지스틱 회귀에 기반한 연산으로 대체하여 학습에서 필요했던 시간과 컴퓨팅 파워를 절감할 수 있다.
이로써, 학습데이터에서 실제 매칭되는 쌍에 대해선 Positive로 학습하고, 매칭되지 않은 노이즈 쌍에 대해선 Negative로 학습할 수 있는 구조이다. 예를 들어, 고양이 이미지에 대해서 학습한다고 할 때, 고양이 이미지에서 부터 augmentation된 쌍에 대해선 positive-pair로 매칭하고, 다른 동물 이미지에 대해선 negative-pair로 인식한다.
이를 UniVL 텍스트기반 비디오 Retrieval 모델에 접목하면, 같은 시점의 동일한 장면과 장면에 대한 캡션에 대해선 positive-pair로 사용하고, 그렇지 않은 경우에 대해선 negaiver-pair로 인식한다고 볼 수 있다. 이로써, 모델을 binary classification 문제로 정의하여, 각 임베딩간 positive-pair에 대해선 유사도를 높여감과 동시에, negative-pair에선 유사도를 낮춰갈 수 있도록 학습을 진행한다.
이를 NCE Loss로써 수식으로 표현하면 아래와 같다. 배치사이즈의 크기가 n이라고 할 때, n개의 쌍중에 실제 매칭되는 하나의 positive-pair와 n-1개의 negative-pair로 구성된 형태로 학습을 진행하게 된다.
그렇다면 NCE Loss에 추가적으로 MIL은 왜 추가되었을까? MIL은 'Multiple-Instance-Learning'의 약자로, 서로 다른 모달리티를 매칭할 때, 1:1이 아닌 1:N으로 학습을 진행한다. 예를 들어, 이미지와 텍스트를 매칭하고자 할 때, 하나의 클립에 여러개의 텍스트를 매칭하는 것이다. UniVL에서 이렇게 매칭하는 이유는 명확하다. UniVL 모델 Pretrain용으로 활용되는 Howto100M의 경우, 온라인상에서 수집된 영상이어서 노이즈가 존재하기 때문에, 특정 시점의 클립과 클립에 대한 캡션이 완벽하게 매칭되지 않는다. 따라서, 이러한 문제를 해결하기 위해 특정 시점 뿐만 아니라, 특정 시점 이전, 이후에 해당하는 주변 시점의 텍스트도 함께 결합하여 하나의 텍스트 임베딩을 형성하여 매칭을 시킨다.
이러한 MIL-NCE Loss를 표현하면 아래 식과 같다. 위 수식에서 볼 때, 비디오 임베딩 Zv,vt은 단일 시점 임베딩으로 구성되어 있으며, 텍스트 임베딩 {Zt,vt}은 주변 시점의 텍스트 임베딩들이 함께 결합된 형태의 임베딩으로 이뤄져 있다고 볼 수 있다. 이로써, Positive-Pair와 Negative-Pair 각각 하나의 비디오 임베딩에 대해서, 현재 시점과 주변시점의 텍스트가 결합된 임베딩을 기준으로 Positve-pair와 Negative-pair로 구성되어 있음을 알 수 있다.
MIL-NCE Loss 파이썬 코드 구현
class MILNCELoss(nn.Module):
def __init__(self, batch_size=1, n_pair=1,):
super(MILNCELoss, self).__init__()
self.batch_size = batch_size
self.n_pair = n_pair
torch_v = float(".".join(torch.__version__.split(".")[:2]))
self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8
def forward(self, sim_matrix):
mm_mask = np.eye(self.batch_size)
mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
from_text_matrix = sim_matrix + mm_mask * -1e12
from_video_matrix = sim_matrix.transpose(1, 0)
new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
logpt = F.log_softmax(new_sim_matrix, dim=-1)
mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
logpt_choice = torch.zeros_like(new_logpt)
mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
logpt_choice[mark_ind] = 1
sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
return sim_loss
[참고문헌]
https://www.kdnuggets.com/2019/07/introduction-noise-contrastive-estimation.html
'Machine Learning' 카테고리의 다른 글
[Knowledge Graph] 지식그래프 구축을 위한 사전지식 (0) | 2022.04.25 |
---|---|
Huggingface에서 AMP를 적용하는 방법 (0) | 2022.04.16 |
2022 국제인공지능대전(AI 엑스포) 방문기 (0) | 2022.04.16 |
BERTAdam 옵티마이저의 진실 (3) | 2022.04.15 |
[CNN에서 DenseNet까지] 컴퓨터 비전 모델 변천사 (0) | 2022.03.21 |