파이톨치

[코드리뷰] Falmingo 본문

AI&ML

[코드리뷰] Falmingo

파이톨치 2024. 4. 2. 10:41
728x90
class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        self.vision_encoder = vision_encoder.visual
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing

1. vision_encoder (nn.Module): 이미지 특징을 추출하는 데 사용되는 사전 학습된 비전 인코더 모듈. 일반적으로 CLIP(Contrastive Language-Image Pre-training) 모델이 사용.

2. lang_encoder (nn.Module): 언어 모델링을 수행하는 사전 학습된 언어 인코더 모듈. 일반적으로 GPT(Generative Pre-trained Transformer)와 같은 인과적 언어 모델이 사용.

3. eoc_token_id (int): "<|end of chunk|>" 토큰의 ID. 이 토큰은 시퀀스의 끝을 나타내는 데 사용

4. media_token_id (int): "<image>" 토큰의 ID. 이 토큰은 이미지 입력을 나타내는 데 사용

5. vis_dim (int): 비전 인코더에서 추출된 시각적 특징의 차원. 시각적 특징은 마지막 차원을 따라 이 모양과 일치하도록 투영

6. cross_attn_every_n_layers (int, optional): 트랜스포머 레이어 후에 크로스 어텐션을 적용하는 빈도를 지정

7. gradient_checkpointing (bool): 그래디언트 체크포인팅을 사용할지 여부를 지정

 

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

 

캐시된 비전 입력을 사용하는 경우: vision_x가 None이어야 하며, lang_encoder의 _use_cached_vision_x가 True

캐시를 사용하지 않는 경우(일반적인 순전파): _encode_vision_x 메서드를 호출하여 비전 입력을 인코딩하고, _condition_media_locations 메서드를 호출하여 미디어 위치를 컨디셔닝.

 

그 후, lang_encoder를 호출하여 언어 모델의 출력을 계산.

clear_conditioned_layers가 True일 경우, 컨디셔닝된 레이어를 초기화

 

* 컨디셔닝(Conditioning)은 머신러닝, 특히 언어 모델에서 사용되는 기술로, 모델이 특정 조건이나 맥락에 기반하여 출력을 생성하도록 학습하는 과정. Flamingo 모델의 경우, 컨디셔닝은 시각적 정보와 언어 정보를 융합하는 데 사용. 모델은 이미지와 텍스트 입력을 받아 시각적 맥락을 고려하여 언어 출력을 생성 

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_vision_x = True
        self._encode_vision_x(vision_x=vision_x)

        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
        return output

lang_encoder의 _use_cached_vision_x를 True로 설정하고, _encode_vision_x 메서드를 호출하여 vision_x를 인코딩

728x90