Programing/Python programming

Deep Attention-Sampling Models: 큰 이미지의 End-to-end 학습

sosal 2019. 8. 10. 01:57
반응형

Processing Megapixel Images with Deep Attention-Sampling Models

 

https://arxiv.org/pdf/1905.03711.pdf

불러오는 중입니다...

(참고 자료) https://icml.cc/media/Slides/icml/2019/halla(11-11-00)-11-11-25-4512-processing_mega.pdf

 

 Image-net에서 pretrained 되는 대표적인 CNN 모델들 (VGG, ResNet, DenseNet 등.. ) 은 기본적으로 224*224 사이즈를 가진다. 그런데 448*448 로만 늘려도 필요한 메모리가 어마어마해진다. 따라서 매우 큰 이미지를 학습할 때, 사이즈를 down-sampling 하거나, 아니면 patch를 뜯어서, 데이터를 나름 전처리 하여 정제를 한 후에 학습에 사용한다.

 

 Patch를 뜯어서 학습하는 방법은 Computational resource를 굉장히 낭비한다거나, 아니면 추가적인 patch들에 대한 각각의 Labeling을 필요로 하게 된다.

 

이 연구는 화질이 낮은, Thumbnail 같은 이미지로부터 'Attention distribution'을 만든다.

'informative patches' 'informative patches'를 뽑아내는 기준이 된다.

그 이후, 뽑힌 이미지의 fraction으로부터 prediction을 수행한다.

 

 

Processing Megapixel Images with Deep Attention-Sampling Models

 

위 그림에서 요약되어 있듯, 이 모델은 'Attention Network'와 'Feature Network' 2가지의 네트워크로 구성되어 있다.

attention_sampling함수는 참조: https://github.com/idiap/attention-sampling/blob/master/ats/core/ats_layer.py

 

attention_sampling 함수로부터, Attention network에서 중요한 patch와, patch의 feature, 그리고 attention map 3가지를 반환한다.

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

def get_model(outputs, width, height, scale, n_patches, patch_size, reg):

        # Define the shapes
        shape_high = (height, width, 1)
        shape_low = (int(height*scale), int(width*scale), 1)
    
 
        # Make the attention and feature models
        attention = Sequential([
            Conv2D(8, kernel_size=3, activation="tanh", padding="same",
                   input_shape=shape_low),
            Conv2D(8, kernel_size=3, activation="tanh", padding="same"),
            Conv2D(1, kernel_size=3, padding="same"),
            SampleSoftmax(squeeze_channels=True, smooth=1e-5)
        ])
        feature = Sequential([
            Conv2D(32, kernel_size=7, activation="relu", input_shape=shape_high),
            Conv2D(32, kernel_size=3, activation="relu"),
            Conv2D(32, kernel_size=3, activation="relu"),
            Conv2D(32, kernel_size=3, activation="relu"),
            GlobalMaxPooling2D(),
            L2Normalize()
        ])
    
 
        # Let's build the attention sampling network
        x_low = Input(shape=shape_low)
        x_high = Input(shape=shape_high)
        features, attention, patches = attention_sampling(
            attention,
            feature,
            patch_size,
            n_patches,
            replace=False,
            attention_regularizer=multinomial_entropy(reg)
        )([x_low, x_high])
        y = Dense(outputs, activation="softmax")(features)
    
 
        return (
            Model(inputs=[x_low, x_high], outputs=[y]),
            Model(inputs=[x_low, x_high], outputs=[attention, patches])
        )
 
 

Colored by Color Scripter

cs

위 소스를 보면, outputs 라는 변수 multi-class classification에 해당하는 label 종류의 수이다.

'attention' network와 'feature' network는 attention_sampling 함수를 통해 patch를 선정하게 된다.

반환되는 features 값은, 이미 flat된 patch로부터 얻어지는 값이다. 이를 통해 label을 예측하는 'y' 라는 최종 output이 나오는 것이다.

 

리턴되는 모델 두가지는, 바로 첫번째, mnist의 숫자를 내뱉는 y를 output으로 가지는 feature network와,

attention map과 patch를 내뱉는 attention network이다.

 

attention network는 attention_sampling 함수 안에서, low resolution image를 받아 attention_map을 뿌리게 되는데, 이 값은 결국 feature network의 input으로 쓰이게 되므로 학습에 의해 더 좋은 patch를 뽑도록 network가 학습된다.

 

학습이 힘든 매우 큰 이미지에서, 똑똑한 방식으로, 큰 이미지로부터 중요한 위치를 찾아내는 네트워크를 통해 필요한 일부분만을 떼어내서 prediction 할 수 있게 해주는 매우 유용한 네트워크이다.