파이톨치

[BoostCamp AI Tech] CLIP 개념부터 코드까지 살펴보기 본문

AI&ML/BoostCamp AI Tech

[BoostCamp AI Tech] CLIP 개념부터 코드까지 살펴보기

파이톨치 2024. 9. 5. 15:04
728x90

CLIP

Multi Modal Model 

최근 LLM에 이어서 LMM이 유행이다. 이것은 Large Multimodal Model의 약자이다. 

하지만, 사람들은 아직 LMM에 대해 잘 알지 못한다. 

 

멀티 모달이라고 하는 것은 하나의 인지분야만 사용하는 것이 아니다. 

예를 들어, 시각, 청각, text에 대한 이해는 모두 하나에 대한 모달리티이다. 

우리의 목적은 이것들을 여러개 사용하겠다는 것이다. 

 

이러한 적용은 대표적으로 text to image 모델들이 있다. 

텍스트를 넣으면 그에 맞는 이미지가 나오는 디퓨전 모델이다. 

 

이러한 멀티 모달 분야에는 Maching, Translating, Referencing 방식의 학습이 있다.

CLIP 모델의 경우 Matching 형태로 학습이 된다. 

(플라밍고 모델의 경우 referencing 형태일 것이다.) 

CLIP Concept

CLIP은 Contrastive Language-Image Pre-training의 약자이다. 그렇다면 추론해 볼 수 있다. 

 

1. 일단 대조학습을 사용하는 모델이다.

2. 이미지와 text에 대해 멀티 모달을 하는 모델이다. 

3. Pre-trained된 모델이다. 

 

1번부터 천천히 보자. 대조학습을 위해서는 네거티브 pair를 가지고 있어야 한다. 

예를 들어, 개와 개에 대한 이미지를 가진다. 그에 맞는 네거티브 pair는 비행기나 컴퓨터가 된다. 

대조학습의 목적은 postive에 대해 유사한 벡터를 가지게 하고 negative에 대해 벡터 유사도가 낮아져야 한다. 

이런 식으로 말이다. 하지만 실제로는 이미와 텍스트가 같은 임베딩 공간에 있진 않을 것이다. 

그렇다면 어떻게 해야 위와 같은 결과를 얻을까? 

 

CLIP 학습

이제 2번을 볼 차례가 왔다. 이미지와 text를 같이 학습하는 멀티 모달이 어떻게 학습하는지 말이다. 

이 벡터들을 Cross Entropy Loss로 학습시킨다. 위에서 본 대조학습도 결국 Cross Entropy의 변형이다. 

 

CLIP은 이미지와 텍스트 쌍을 입력으로 받아서, 두 가지 다른 인코더(이미지 인코더와 텍스트 인코더)를 통해 각각 임베딩을 생성합니다.

이때의 학습 목표는:

  • 같은 이미지-텍스트 쌍은 서로 가까운 임베딩을 가지도록.
  • 다른 이미지-텍스트 쌍은 서로 멀리 떨어진 임베딩을 가지도록.

나는 수학을 별로 안 좋아하니까 코드로 보자. 코드로 봐도 어렵다. 

저 logit 부분이 위에 있는 2차원 matrix일 것이다. 예시를 들어 더 잘 이해해보자. 

좀 더 간단한 형태로 이미지 3개와 텍스트 3개가 있다고 가정해보자. 그리고 그 임베딩 벡터끼리 유사도를 구한 matrix이다. 

정답 레이블 labels

정답 레이블 labels는 각 이미지가 올바르게 매칭되어야 할 텍스트의 인덱스를 나타냅니다.

여기서는 이미지 0 ↔ 텍스트 0, 이미지 1 ↔ 텍스트 1, 이미지 2 ↔ 텍스트 2이므로:

labels=[0,1,2]

 

이미지-텍스트 손실

loss_i = cross_entropy_loss(logits, labels, axis=0)

목표: 각 이미지 ii가 올바른 텍스트 j=i를 찾도록 학습합니다.

 

손실 계산 과정 예시:

  • 이미지 0에 대해:
    • logits[0, :] = [0.9,0.2,0.1]
    • 정답 레이블은 텍스트 0이므로, 정답 텍스트 0의 확률이 높도록 학습해야 합니다.
    • 크로스 엔트로피는 정답 레이블에 해당하는 확률이 높아지도록 손실을 계산합니다.
  • 이미지 1에 대해:
    • logits[1, :] = [0.3,0.8,0.4]
    • 정답 레이블은 텍스트 1이므로, 텍스트 1의 확률이 높아지도록 손실이 계산됩니다.
  • 이미지 2에 대해:
    • logits[2, :] = [0.2,0.1,0.7]
    • 정답 레이블은 텍스트 2이므로, 텍스트 2의 확률이 높아지도록 손실이 계산됩니다.

텍스트-이미지 손실 loss t

loss_t = cross_entropy_loss(logits, labels, axis=1)

 

  • 텍스트 0에 대해:
    • logits[:, 0] = [0.9,0.3,0.2]
    • 정답 레이블은 이미지 0이므로, 정답 이미지 0의 확률이 높도록 학습해야 합니다.
  • 텍스트 1에 대해:
    • logits[:, 1] = [0.2,0.8,0.1]
    • 정답 레이블은 이미지 1이므로, 이미지 1의 확률이 높아지도록 손실이 계산됩니다.
  • 텍스트 2에 대해:
    • logits[:, 2] = [0.1,0.4,0.7]
    • 정답 레이블은 이미지 2이므로, 이미지 2의 확률이 높아지도록 손실이 계산됩니다.

최종 손실 계산

loss = (loss_i + loss_t)/2

 

최종 손실은 이미지-텍스트 손실 loss i와 텍스트-이미지 손실 loss t평균하여 계산합니다. 이렇게 하면, 모델이 이미지텍스트 모두 올바른 매칭을 찾는 데 효과적으로 학습될 수 있습니다.

CLIP 응용

CLIP은 결국 사전학습된 모델이며, 중요한 것은 이를 활용하는 방법이다. 

Image Captioing 

 

clip은 많은 양의 데이터를 때려 넣어서 학습했기 때문에 좀 더 일반적인 captioning을 할 수 있다. 단순히 clip만 사용하는 것이 아니라, LM과 함께 사용한다.

 

일종의 Loss로 사용하는 경우가 많은 것 같다. CLIP이 벡터 공간을 더 잘 표현하게 해주는 모양이다. 

CLIP 코드 살펴보기 

여기부터 토할거 같다. 너무 방대해서 흐름을 잃을 것 같다. 하지만, 천천히 보자. 모르는 부분은 상상으로 메꿔야 한다. 

처음에 알아야 하는 코드 블록은 nn.MultiheadAttention이다. 

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, 
				add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, 
               			batch_first=False, device=None, dtype=None)

 

이런 식으로 구성된 블록이다. torch 코드를 볼 때는 forward도 함께 봐야 어떤 코드인지 이해할 수 있다. 

forward(query, key, value, key_padding_mask=None, need_weights=True, 
	attn_mask=None, average_attn_weights=True, is_causal=False)

아하! 이것은 query, key, value를 사용하는 어텐션 구조이며, 값을 모두 동일하게 넣어주면 self-attention이 되는 것이다. 

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head) ### returns [0]: multihead output [1]: multihead output weight
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        ##### IMPLEMENT HERE
        return attn

    def forward(self, x: torch.Tensor):
        ##### IMPLEMENT HERE
        return x, weights

그렇다면, 저기에 들어갈 코드는 먼저 attention부터 생각해보자. 주석을 보면 반환되는 값이 list 같다. 그 값을 그대로 넘겨준다. 

attn = self.attn(x, x, x, attn_mask=self.attn_mask)

가 될 것이다. (self attention 밖에 할게 없다. 주어진 값이 x 밖에 없잖아요.) 

그 아래 코드는 forward이다. 위에서 작성된 층들을 고려해서 써보자. 

x, weights = self.attn(x)
x = self.ln_1(x)
x = mlp(x)
x = self.ln_2(x)

가 될 것이라고 생각된다. 

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        weights_all_blocks = []

        for block in self.resblocks:
            ###### IMPLEMENT HERE

        return x, torch.stack(weights_all_blocks)

위에서 이어서 트랜스포머 코드를 살펴보자. 또 생각해보자 weight를 어떻게 구해야 할까? 

지금 순회하는 것은 위에서 구현한 ReisualAttentionBlock이다. 이 층을 설계할 때 forward를 할 때, weight를 반환하게 만들었다. 

그렇다면 단순히 forward를 할 때 마다 값을 저장하게 만들어주면 된다. 

위 트랜스포머는 엄밀하게 말하면 BERT 형태라고 생각된다. 

x, weights = block(x)
weights_all_blocks.append(weights)

가 들어가면 좋을 것 같다. 

결국 저 트랜스포머 구조가 이미지 인코더 구조가 된다. 그와 동시에 텍스트 인코더 구조가 된다. 

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attn_output_weights = ###### IMPLEMENT HERE

        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x, attn_output_weights

cnnv1을 통해서 이미지를 패치 단위로 잘라서 사용한다. 

위치 정보는 nn.Parameter로 넣어주는듯한데, 이것으로 이미지 정보끼리 상대적인 위치 정보를 학습하는 것 같다. 

클립은 이제 두 모델을 사용해서 임베딩 벡터를 뽑고 위에서 배운 대조학습을 해준다. 

간략화된 코드를 살펴보자. 

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

 

@property 데코레이터

  • @property는 메서드를 속성처럼 사용할 수 있게 해주는 파이썬의 데코레이터이다.
  • model.dtype == self.visual.conv1.weight.dtype이 된다.

forward

생략했지만, 이미지와 텍스트를 벡터화 시키는 매서드를 통해서 feature값을 얻는다. 

그래서 그냥 내적해서 반환한다. 이렇게 보니 그냥 간단한 구조다. 

그 뒤로 학습할 때 위에서 정리한 대조학습을 사용하는 모양이다. 

 

728x90