Transfer learning을 수행해야할 때,
torch의 hub를 활용해서 기존의 pretrained model을 가져오는 경우가 많이 있다.
그러나, 단순히 마지막 fully connected layer만을 없애고 싶은게 아니라,
중간의 feature부터 활용하고 싶은 경우가 있는데, 이런 경우는 forward 함수를 건드리면 제일 간편하다.
예를 들어, vision transformer에서, 마지막 cls token의 값을 가져오는게 아닌
patch의 정보를 가져오고 싶을때?
단순히 모델의 architecture를 수정한다고 해결할 수 있는 문제는 아니다.
forward 함수에서, cls token만 짚어서 return하고 있기 때문이다.
이런 경우, python의 inspect를 활용하면 매우 간편하다.
import torch
import inspect
from torchvision.models import VisionTransformer, vit_b_16, ViT_B_16_Weights
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()
ln [4] forward_source = inspect.getsource(model.forward)
ln [5] print(forward_source)
def forward(self, x: torch.Tensor):
# Reshape and permute the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
이제 forward 함수를 직접 확인할 수 있으니,
여기서 cls token만 가져오는 x[:, 0] 을 수행하지 않고 x를 바로 return하면 patch의 feature를 가져올 수 있다.
'Major Study. > Computer Science' 카테고리의 다른 글
LLM을 medical text에 활용해보면서 느낀점 및 정리 (0) | 2024.09.02 |
---|---|
JupyterLab 에서 함수 클래스 숨기기 (Toggle, Folding ) (0) | 2024.01.12 |
Google Drive 파일 Linux 서버에서 다운로드 (0) | 2023.12.23 |
AI 연구원이 ChatGPT 활용하는 꿀팁 (0) | 2023.12.22 |
윈도우 탐색기에 SSH 서버 폴더로 등록하기 (1) | 2023.12.22 |