일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- 실버
- 그리디 알고리즘
- 설정
- 4948
- end to end
- streamlit
- n과 m
- 기계학습
- Mac
- N-Queen
- 가상환경
- 백준
- 백트래킹
- 9020
- 개발환경
- 1002
- 경사하강법
- pyenv
- 파이싼
- 재귀
- 신경망 학습
- 파이썬
- BOJ
- Python
- 밑바닥부터 시작하는 딥러닝
- 15649
- 손실함수
- 1101
- Today
- Total
파이톨치
[코드리뷰] Falmingo 본문
class Flamingo(nn.Module):
def __init__(
self,
vision_encoder: nn.Module,
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
vis_dim: int,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
):
super().__init__()
self.eoc_token_id = eoc_token_id
self.media_token_id = media_token_id
self.vis_dim = vis_dim
if hasattr(lang_encoder.config, "d_model"):
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
else:
self.lang_dim = lang_encoder.config.hidden_size
self.vision_encoder = vision_encoder.visual
self.perceiver = PerceiverResampler(dim=self.vis_dim)
self.lang_encoder = lang_encoder
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
lang_hidden_size=self.lang_dim,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
)
self._use_gradient_checkpointing = gradient_checkpointing
self.perceiver._use_gradient_checkpointing = gradient_checkpointing
1. vision_encoder (nn.Module): 이미지 특징을 추출하는 데 사용되는 사전 학습된 비전 인코더 모듈. 일반적으로 CLIP(Contrastive Language-Image Pre-training) 모델이 사용.
2. lang_encoder (nn.Module): 언어 모델링을 수행하는 사전 학습된 언어 인코더 모듈. 일반적으로 GPT(Generative Pre-trained Transformer)와 같은 인과적 언어 모델이 사용.
3. eoc_token_id (int): "<|end of chunk|>" 토큰의 ID. 이 토큰은 시퀀스의 끝을 나타내는 데 사용
4. media_token_id (int): "<image>" 토큰의 ID. 이 토큰은 이미지 입력을 나타내는 데 사용
5. vis_dim (int): 비전 인코더에서 추출된 시각적 특징의 차원. 시각적 특징은 마지막 차원을 따라 이 모양과 일치하도록 투영
6. cross_attn_every_n_layers (int, optional): 트랜스포머 레이어 후에 크로스 어텐션을 적용하는 빈도를 지정
7. gradient_checkpointing (bool): 그래디언트 체크포인팅을 사용할지 여부를 지정
def forward(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
clear_conditioned_layers: bool = True,
past_key_values=None,
use_cache: bool = False,
):
assert (
self.lang_encoder.initialized_flamingo
), "Flamingo layers are not initialized. Please call `init_flamingo` first."
assert (
self.lang_encoder._use_cached_vision_x or vision_x is not None
), "Must provide either vision_x or have precached media using cache_media()."
if self.lang_encoder._use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert (
vision_x is None
), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
assert self.lang_encoder.is_conditioned()
else:
# Case: do not use caching (i.e. this is a standard forward pass);
self._encode_vision_x(vision_x=vision_x)
self._condition_media_locations(input_ids=lang_x)
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
)
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
return output
캐시된 비전 입력을 사용하는 경우: vision_x가 None이어야 하며, lang_encoder의 _use_cached_vision_x가 True
캐시를 사용하지 않는 경우(일반적인 순전파): _encode_vision_x 메서드를 호출하여 비전 입력을 인코딩하고, _condition_media_locations 메서드를 호출하여 미디어 위치를 컨디셔닝.
그 후, lang_encoder를 호출하여 언어 모델의 출력을 계산.
clear_conditioned_layers가 True일 경우, 컨디셔닝된 레이어를 초기화
* 컨디셔닝(Conditioning)은 머신러닝, 특히 언어 모델에서 사용되는 기술로, 모델이 특정 조건이나 맥락에 기반하여 출력을 생성하도록 학습하는 과정. Flamingo 모델의 경우, 컨디셔닝은 시각적 정보와 언어 정보를 융합하는 데 사용. 모델은 이미지와 텍스트 입력을 받아 시각적 맥락을 고려하여 언어 출력을 생성
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
num_beams = kwargs.pop("num_beams", 1)
if num_beams > 1:
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
self.lang_encoder._use_cached_vision_x = True
self._encode_vision_x(vision_x=vision_x)
eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
output = self.lang_encoder.generate(
input_ids=lang_x,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
num_beams=num_beams,
**kwargs,
)
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_vision_x = False
return output
lang_encoder의 _use_cached_vision_x를 True로 설정하고, _encode_vision_x 메서드를 호출하여 vision_x를 인코딩
'AI&ML' 카테고리의 다른 글
[부스트캠프] 인공지능 기초 다지기 - 딥러닝 - PyTorch (1) | 2024.07.03 |
---|---|
[부스트캠프] 인공지능 기초 다지기 - 기초 수학 - 확률론 (1) | 2024.07.03 |
[부스트캠프] 인공지능 기초 다지기 - 기초 수학 - 벡터 (1) | 2024.07.03 |
[CVPR 2022 tutorial] unified image-text modeling (0) | 2024.04.03 |
[computer vision] Vision Transformer(ViT) Clone Code (1) | 2024.03.23 |