일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- 1101
- n과 m
- 9020
- pyenv
- 밑바닥부터 시작하는 딥러닝
- Python
- 15649
- 재귀
- 티스토리챌린지
- 손실함수
- 개발환경
- end to end
- 4948
- streamlit
- 1002
- 그리디 알고리즘
- video retireval
- 백트래킹
- 기계학습
- Retrieval
- N-Queen
- REST API
- 경사하강법
- 백준
- 파이싼
- 신경망 학습
- 파이썬
- 오블완
- 가상환경
- BOJ
- Today
- Total
파이톨치
[BoostCamp AI Tech] smp 라이브러리 사용 본문
지금까지는 라이브러리를 사용하지 않고, 파이토치만으로 구현했다.
하지만 그렇게 하면 시간이 많이 들고, backbone 모델을 갈아끼울 때 마다 코드를 수정해주어야 하는 번거로움이 생기게 된다.
이러한 불편한 점을 방지하기 위해서 smp 와 같은 라이브러리를 사용하여 시간을 단축시킬 수 있다.
하지만, 그럼에도 데이터셋과 같은 클래스들을 우리가 작성해주어야 한다는 점은 잊지 말자.
왜냐하면 우리 폴더의 경로 등을 라이브러리가 알 수 없을 뿐만 아니라, 우리가 원하는 입출력을 위해서는
데이터셋 클래스의 __getitem__ 내부 함수를 작성해주어 데이터로더에 넣어주어야 한다.
라이브러리를 사용하면 특이한 점은 model(images)를 했을 떄, 곧바로 출력값이 나오는 것이 아니라 dict 형태로 나오게 되는데 이를 위해 outputs = model(images)['out']과 같이 명시해 주어야 한다는 점이다. 또한 grond truth와 크기가 다른 경우 interpolation (보간)을 해주어야 한다. 그리고 출력을 sigmoid를 지나서 outputs를 계산하고 thr 보다 클 때만 계산해준다. 왜 그런지는 모르겠다. 출력을 안정화 하고 쓸데없는 계산을 줄이는 용도일까? 암튼 model을 학습할 때도 비슷한 과정을 거친다.
단순한 pytorch를 사용해서 모델을 구현할 때는, 모델의 구조를 이해하고 어떤 층이 들어가는지 직접 작성해야 했다. 하지만 라이브러리를 사용하면 그 과정을 단순화 할 수 있게 된다. 사용방법은 다음과 같다.
model = smp.Unet(
encoder_name="efficientnet-b0", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=29, # model output channels (number of classes in your dataset)
)
이렇게 하면 모델을 사용할 수 있게 된다.
이 때 정보를 압축하기 위해서 RLE라는 방법을 사용하게 된다. 공부를 대충했더니 몰랐던 개념인데, 아래 블로그를 참고하자.
https://www.kaggle.com/code/leahscherschel/run-length-encoding
Run-Length Encoding
Explore and run machine learning code with Kaggle Notebooks | Using data from HuBMAP - Hacking the Kidney
www.kaggle.com
핵심만 말하자면, 연속된 픽셀에 대한 정보를 저장하는 것이다.
특정한 픽셀에서 아래 몇 픽셀까지가 연속되었는지를 저장하는 것이다. 그렇다면 그 픽셀들은 의미론적으로 같은 것일까?
1이면 mask이고 0이면 background라고 한다. 코드는 다음과 같다.
# mask map으로 나오는 인퍼런스 결과를 RLE로 인코딩 합니다.
def encode_mask_to_rle(mask):
'''
mask: numpy array binary mask
1 - mask
0 - background
Returns encoded run length
'''
pixels = mask.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return ' '.join(str(x) for x in runs)
# RLE로 인코딩된 결과를 mask map으로 복원합니다.
def decode_rle_to_mask(rle, height, width):
s = rle.split()
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(height * width, dtype=np.uint8)
for lo, hi in zip(starts, ends):
img[lo:hi] = 1
return img.reshape(height, width)