일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |
- REST API
- N-Queen
- 백준
- 파이썬
- streamlit
- 밑바닥부터 시작하는 딥러닝
- 손실함수
- end to end
- 티스토리챌린지
- 개발환경
- BOJ
- Retrieval
- 9020
- 경사하강법
- pyenv
- n과 m
- Python
- 그리디 알고리즘
- 재귀
- 1101
- 백트래킹
- 오블완
- video retireval
- 1002
- 신경망 학습
- 15649
- 기계학습
- 파이싼
- 가상환경
- 4948
- Today
- Total
파이톨치
CLIP 코드 뜯어보기 - Transformer 만들고 ViT로 확장하기 본문
https://github.com/openai/CLIP/blob/main/clip/model.py
CLIP/clip/model.py at main · openai/CLIP
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image - openai/CLIP
github.com
이 코드를 참조해서 작성했다.
처음 시작은 attention block을 만드는 것부터 시작한다. 여기서는 CNN을 사용하지 않고, 처음부터 입력 차원을 조절해줄 것이기 때문에 nn.MultiheadAttention 클래스를 그대로 사용하며, 그렇기에 코드도 더 간단하게 나온다. 입출력은 뒤에서 맞추어 주리고 하고, 여기서 해야할 것은 간단하다. x 값에 대한 Attention 의 출력을 Residual 하게 더해주면 된다.
forward 함수를 생각해봤을 때, x = x + self.attention(x)를 하면 될 거 같은데, 부가적으로 들어가는 것들이 있다.
예를 들어서, LayerNorm을 통해서 좀 더 학습 안정성을 높여주고, mlp 층도 residual하게 연결을 해주는 모습이 보인다.
그리고 저기서 사용한 attention 함수에는 attn_mask를 device로 옮겨주는 모습을 볼 수 있다. 확장성을 위해서 저렇게 작성한 모양이다. 그리고 multiheadAttention의 출력으로는 x, attn_mask가 나오니까. [0]을 해주는 모습을 볼 수 있다.
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
참고로 nn.MultiheadAttention의 입력 텐서는 (L, N, E)인데, 시퀀스 길이, 배치 크기, 임베딩 차원을 의미한다.
그래서 결과적으로 이렇게 되는 클래스라고 이해할 수 있을 것이다. 중간에 배치 정규화가 들어가는 디테일도 있다.
이걸 이제 조각 조각 이어 붙이게 되는데, layer 만큼 이를 복제해서 붙이고 nn.Sequntial 함수로 감사주면 끝이다.
이렇게하면 그 유명한 트랜스포머가 생기는 것이다.
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
이제 이걸 ViT -> vison transformer로 확장을 해야 하는 것인데, 방법은 간단하다. 이미지를 잘라서 저 형태에 맞게 넣어주면 된다.
근데 문제가 하나 생기는데 결국 그러면 출력도 L, N, E 그대로 나온다. 우리가 필요한 것은 N, E 이다. N, L, E로 만들어 둔 다음에 N, E로 만들면 된다. 이 코드에서 사용한 방식은 x[:, 0, :]을 하는 방식인데, L 중에서 첫번째만 사용하겠다는 의미이다. 그렇게 하면 N,E 차원이 나오게 된다.
이를 코드로 구현하면 다음과 같다.
처음에는 입력 이미지를 처리해준다. 이는 간단하게 Conv2d로 처리하는데 방식은 다음과 같다.
patch 크기의 커널이 stride도 patch로 해서 원본 이미지를 patch 크기만큼 나눈 것이 입력 해상도가 되는 것이다.
H -> H // patch_size 그리고 이걸 주석에서는 grid로 쓴다.
그리고 차원은 3차원에서 width 차원이 된다.
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
여기서 잘 이해가 안되는게, 이렇게 첫번째 하나만 떼오는게 도움이 되나?
아! 첫번째 떼오는게 도움이 된다. 왜냐면 저게 cls 토큰 이니까!!
즉 트랜스포머를 거치고 나온 cls 토큰만 임베딩 벡터로 쓰겠다는 얘기다!!!