파이톨치

[computer vision] Vision Transformer(ViT) Clone Code 본문

AI&ML

[computer vision] Vision Transformer(ViT) Clone Code

파이톨치 2024. 3. 23. 12:21
728x90

https://kimbg.tistory.com/31

 

[ML] ViT(20.10); Vision Transformer 코드 구현 및 설명 with pytorch

AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE https://arxiv.org/pdf/2010.11929.pdf vit 논문에 관련한 양질의 리뷰는 상당히 많아서, 코드 구현에 관한 설명만 정리하고자 했습니다. 그래도 중

kimbg.tistory.com

## input ##
x = torch.randn(8, 3, 224, 224)
print('x :', x.shape)

patch_size = 16 # 16x16 pixel patch
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', 
                    s1=patch_size, s2=patch_size)
print('patches :', patches.shape)

rearrange는 입력으로 들어간 값을 reshape해주는 것인데 조절이 신기하다. 

여기서 patch_size가 나눌 격자의 가로 개수인데, 224 * 224 이미지가 들어갔을 때, 이를 14 * 16으로 자동으로 나눈다.

여기서 8 3 244 244 => 8 3 (14*16) (14*16)으로 바꾼 다음에, 8 (14 14) (16 16  3)으로 바꾼다. 그러면 배치당 196 크기의 작은 이미지가 생기고 이를 768개 가진다. 

 

CNN 을 사용하면 더 편하게 구현할 수 있다. (근데 저 그림보고 이걸 어케 알아;;)

patch_size = 16
in_channels = 3
emb_size = 768 # channel * patch_size * patch_size 

projection = nn.Sequntial(
	nn.Conv2d(in_channels, emb_size, 
    		  kernel_size=patch_size, stride=patch_size),
              Rearrange('b e (h) (w) -> b (h w) e'))

여기서 in_channels가 RGB 3개를 의미하는 것이고, emb_size가 output_channel을 의미한다. 이 때, 커널 사이즈와 스트라이드를 16으로 지정해서 16크기의 박스가 지나가게 만든다. 이 때 출력의 결과는 batch_size emb_size height weight 인데, 이것을 내가 원하는 결과인 batch_size 196 768로 바꿔준다. 결국 위와 동일해진다. 그러면 196개의 작은 상자가 생긴다. 768은 상자 픽셀 수 * RGB임. 이 768이 embedding이 되는 것이다. 개 신기하다. 작은 이미지 * RGB가 표현이 되는 것이다. 몇개로 짤렸는지는 궁금하지 않다. 내가 자른 사이즈 (patch_size)와 RGB가 합쳐져서 작은 표현이 되는 것이다. 그게 여러개 있는 것인데, 여기서는 224 // patch_size 이니까 14 * 14개라서 196개가 되는 것이다. 

class patchEmbedding(nn.Module):
  def __init__(self, in_channels:int=3, patch_size:int=16,
               emb_size:int=768, img_size:int=224):
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=emb_size,
                  stride=patch_size, kernel_size=patch_size),
        Rearrange('b e (h) (w) -> b (h w) e')
    )
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
    self.positions = nn.Parameter(torch.randn((img_size//patch_size)**2+1, emb_size))
  def forward(self, x:Tensor)->Tensor:
    b, _, _, _ = x.shape
    x = self.projection(x)
    # 위에선 배치 사이즈를 모르니까 나중에 배치 사이즈만큼 늘려주는 거구나!
    cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
    x = torch.cat([cls_tokens, x], dim=1)
    x += self.positions
    return x

PE = patchEmbedding()
summary(PE, (3, 224, 224), device='cpu')

여기서 cls_token은 BERT에서 아이디어를 얻은 모양이다. 아까 얻은 표현에 cls_token을 붙여서 입력을 완성해준다. 그러면 batch, 3, 244, 244 -> batch, 197, embedding_size(768)이 되는 것이다.  

 

class MultiHeadAttention(nn.Module):
  def __init__(self, emb_size:int=768, num_heads:int=8, dropout:float=0):
    super().__init__()
    self.emb_size = emb_size
    self.num_heads = num_heads
    # 딥러닝은 기본적으로 합칠때, linear층을 쓰는듯?
    self.qkv = nn.Linear(emb_size, emb_size*3)
    self.att_drop = nn.Dropout(dropout)
    self.projection = nn.Linear(emb_size, emb_size)

  def forward(self, x:Tensor, mask:Tensor=None) -> Tensor:
    qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
    quries, keys, values = qkv[0], qkv[1], qkv[2]
    energy = torch.einsum('bhqd, bhkd -> bhqk', quries, keys)
    if mask is not None:
      fill_value = torch.finfo(torch.float32).min
      energy.mask_fill(~mask, fill_value)
    scaling = self.emb_size ** (1/2)
    att = F.softmax(energy / scaling, dim=-1)
    att = self.att_drop(att)
    out = torch.einsum('bhal, bhlv -> bhav', att, values)
    out = rearrange(out, 'b h n d -> b n (h d)')
    out = self.projection(out)
    return out

 

여기가 이제 트랜스포머 구조를 모르면 아예 이해를 못한다. 여기 이해하면 나머지는 쉽게 이해할 수 있다. 일단 멀티헤드 어텐션 구조는 임베딩 벡터를 쪼개서 각각의 헤드를 지나가게 된다. 그런 후에 linear 층을 지나서 합쳐주게 된다. 이게 트랜스포머 구조의 핵심이다. 또한 qkv구조를 사용한다. 

 

근데 linear층을 왜 3개 안 만들고 하나로 만들고 rearrange하는거지? 귀찮은거 아닌가? 병렬화 때문에 그런가? 

암튼 선형층을 지난 값을 하나씩 quries, keys, values에 넣어준다. 이 때 torch.einsum을 통해서 내적은 해주는 것 같다. (아니 einsum은 뭐 만능이네 그냥) transformer 논문과 마찬가지로 mask를 쓰냐 안 쓰냐에 따라서 mask를 채워준다. 나머지는 그냥 흐름따라 진행됨. 

 

근데 여기서 차원이 어떻게 넘어가는지 머릿속으로 연산을 해야한다. 입력으로 batch, 197, 786이 들어간다. 이것을 잘 나누어 주어야 한다. 786이 임베딩 값인데, 이것을 h d qkv로 나누어 준다. qkv는 3이다. h는 헤드의 개수이다. d는 남는 값. 즉, 768 * 3// 3 // 헤드 수(8) 인데, 96가 된다. 이게 작은 차원이 되는 것이다. 아니 근데 무슨 기준으로 잘라주는거야? (h d qkv) 같은 연산 할 때 어떻게 잘라주는거지? 이게 energy를 softmax해서 어텐션 스코어를 얻는건가? 이걸 안 보고 그냥  슥슥슥 할 수 있어야 할듯.  

 

class FeedFowardBlock(nn.Sequential):
  def __init__(self, emb_size:int, expansion:int=4, drop_p:float=0.):
    super().__init__(
      nn.Linear(emb_size, expansion * emb_size), 
      nn.GELU(), 
      nn.Dropout(drop_p), 
      nn.Linear(expansion*emb_size, emb_size),        
    )

여기서 코딩 스타일 하나 배웠다. 상속 받을 때, nn.Module로 받는게 아니라, nn.Sequntial로 받으면 귀찮게 forward 함수를 작성하지 않고, __init__ 에서 다 끝낼 수 있다. 근데 생각해 봐야하는게, 순차적으로 진행되는 경우만 사용할 수 있는 것 같다. 

class ResidualAdd(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn 
  def forward(self, x, **kwargs):
    res = x 
    x = self.fn(x, **kwargs)
    x += res 
    return x

ResidualAdd에서는 정보를 층을 지나서 넘겨주는 클래스이다.  fn으로는 그냥 nn.Sequential 클래스가 들어가는데, 파이프 라인을 지나서 나온 값이랑 이전 값을 더해주어서 값을 반환해준다. fn에 들어가는 인자들은 뭐가 들어올지 모르기 때문에, **kwargs로 넘겨주게 된다. 그러면 클래스 입장에서는 알바노이다.  

 

class  TransformerEncoderBlock(nn.Sequential):
  def __init__(self, emb_size:int=768, drop_p:float=0.,
               forward_expansion:int=4, forward_drop_p:float=0., 
               **kwargs):
    super().__init__(
        ResidualAdd(nn.Sequential (
           nn.LayerNorm(emb_size), 
           MultiHeadAttention (emb_size, **kwargs),
           nn.Dropout(drop_p) 
        )), 
        ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size), 
            FeedFowardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p)
        ))
    )

ResidualAdd로 랩핑 해준다. 트랜스포머 논문에서 나온 것처럼 LayerNorm을 지나서 MultiHeadAttention을 지나고, FeedFowardBlock을 지난다. 

 

class TransformerEncoder(nn.Sequential): 
  def __init__(self, depth:int=12, **kwargs):
    super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])


class ClassificationHead(nn.Sequential): 
  def __init__(self, emb_size:int=768, n_classes:int=1000):
    super().__init__(
        Reduce('b n e -> b e', reduction='mean'), 
        nn.LayerNorm(emb_size), 
        nn.Linear(emb_size, n_classes)
    )

트랜스포머 인코더는 위에서 만든 트랜스포머 블록을 12개 쌓아서 만든다. 여기서 *은 인자를 함수에 넘길 때 리스트로 넘기는게 아니라, 리스트 안에 있는 값을 인자 중 하나로 넘긴다. 싱기방기 함. 여기서 Reduce는 잘 모르겠다. 평균으로 합쳣거 줄이는건가? 저기서 n은 내가 쪼갠 이미지의 개수이다. 상식적으로 이걸 합쳐줘야 하는데, 평균을 내서 합쳐주는 모양이다. 오! 이해했다. 

 

class ViT(nn.Sequential): 
  def __init__(self, in_channels:int=3, 
               patch_size:int=16, emb_size:int=768, img_size:int=224, 
               depth:int=12, n_classes=1000, 
               **kwargs):
    super().__init__(
        patchEmbedding(in_channels, patch_size, emb_size, img_size),
        TransformerEncoder(depth, emb_size=emb_size, **kwargs), 
        ClassificationHead(emb_size, n_classes)
    )

summary(ViT(), (3, 224, 224), device='cpu')

 

마지막으로 앞에서 만들었던 클래스들을 다 합쳐준다. patchEmbedding으로 이미지를 쪼개서 이해하기 쉬운 형태로 변환해준다. batch_size, 3, 224, 224 크기의 값이 들어오게 된다면, 이를 batch_size, 197(쪼개진 이미지 수), 786(Embedding)가 된다. 트랜스포머 인코더를 지나면 이 구조가 유지된채로 나온다. batch_size, 197(쪼개진 이미지 수), 786(Embedding)그대로 말이다. 이를 classificationHead에서 batch_size,  786(Embedding)  => batch_size, n_classes로 나오게 된다. 변환하는건 그냥 선형층 지나면 된다. 학습할 때 그걸 잘 학습하나보다 딥러닝 개신기하네. 

728x90