본문 바로가기

Research

[Paper Review] ResNeSt: Split-Attention Networks

본인 이해를 목적으로 하기 때문에 단순 번역기 수준이 될 수도 있음.

 

2021.01.06 Sementic Segmentation 분야를 기준으로, 상위권 논문(모델) 중 하나를 정리하고자 한다.

현 시간 기준 Pascal Context, Cityscapes val, ADE20K val, ADE20K 데이터셋에 대해 Rank 1위

저자들이 각양각색이다. Facebook, UC Davis, Snp, Amazon, ByteDance, SenseTime ... 무슨 조합인지 모르겠다.

2020.4.19에 1차 버전이 제출되었다. 오늘 정리는 두 번째 제출본 v2를 보고 쓴다.

 

1. Introduction

 CNN의 주요 핵심은 convoluation layer에서 filter set을 학습시키는 데 있고, Conv. Layer는 주변 공간적인 방향과 채널 연결된 방향으로 각각 정보들을 더한다. 앞서 제안된 Inception model들은 서로 비의존적인 특징들을 학습하기 위해 여러 갈래의 표현을 사용한다(explore the multi-path representation). 따라서, input feature가 서로 다른 convolution filter로 이루어진 좀 더 저차원들로 나뉘는 방식을 사용한다. 이 전략은 입력 채널 간의 연결을 분리함으로써 좀 더 특징들을 보는데 유리해진다. 이 논문에서는 채널 방향의 attention 전략과 multi-path 네트워크 구성을 결합한 간단한 구조를 제안한다. 또한 학습 가속을 위해 ResNet 스타일의 변형된 연산 블록인 Split-Attention Block을 제안한다.

 

2. Split-Attention Networks

2.1 Split-Attention Block

 이 블록은 featuremap groupsplit attention operation으로 이루어진다. 아래 그림 1을 보면, 왼쪽 SE-Net Block과 SK-Net 블록 그리고 우측에 제안하는 ResNeSt Block이 있다. 결과적으로 cardinal groups이라는 featuremap group들을 만들어내야 되는데, 그림 1에서 나온 바와 같이 주어진 feature들을 여러 개의 그룹으로 나누는 것으로부터 시작한다. 이 featuremap group은 하이퍼 파라미터 K이다. 이 논문에서는 먼저 새롭게 정의한 기수(radix) 하이퍼 파라미터 R을 이용하여 하나의 cardinal group  내에서 몇 개의 split이 있는지를 나타낸다. 따라서, 총 feature group의 수 G = KR 이 된다. 각 그룹에 대해서는 연속된 transformation을 적용한다. 

 

Figure 1. Comparing ResNeSt block with SE-Net and SK-Net

 

 아래 그림 2는 하나의 cardinal group에 있는 Split Attention을 보여준다. 위 그림 1에서 Cardinal 내부 Split들이 연산을 수행하고 결과물들이 모아지는 하단 Split Attention 박스에 해당한다.

 먼저 그림 2의 Split attention을 설명하자면, 우선 input tensor들을 모두 summation 한다. 이 부분에 대해선 concat이 아닐까 생각했으나.. 명백히 (h, w, c)가 있으므로.. summation 한다.(concat을 한다면 BN에만 의미 있을 듯)

 

 그리고, Global Average Pooling으로 1x1 xc로 만든다. 이때 이미 h*w 크기의 feature들이 겨우 1개의 값이 되므로 정말 많은 데이터가 손실된다. 따라서 당연하게도 나중에 원래 데이터를 보존하는 tensor과 더한다.

 이어서, c에 대해 BN + ReLu를 적용한다. 여러 개의 Dense c를 생성한다. 각각 SoftMax를 거치고 나서 처음 Input tensor과 multiply 연산을 한다. 쉽게 말해, input tensor의 각 채널별 element들은(h x w) 1개의 scalar 값이 곱해지게 된다. 그 output tensor들은 다시 최종 합쳐지게 된다.(summation) 그럼 input과 마찬가지로 HxWxC tensor가 생성된다. 코딩 시 각 Input tensor들은 마지막 단계에서 다시 더해지기 위해 각자의 path를 기억해서 연산하는데 주의하면 될 듯하다.

 

Figure 2. Split-Attention within a cardinal group

 

Split Attention in Cardinal Groups.

 이 부분이 사실상 핵심이다. 논문 그림과 식을 매칭 해서 보자.

우선 그림 1에서 Split에서 1x1 Conv와 3x3 Conv를 거친 (h, w, c) tensor들이 이 Split-attention의 input이 된다.

먼저 논문에 언급된 첫 번째 식(그림 3의 가장 상단)은 단순히 element summation 되는 식이다.

그다음은 그냥 Global pooling이다. Average이기 때문에 1/(HxW)를 곱한다. global pooling을 정확히 이해한다면 더 이상 설명이 필요 없다.

마지막 식(2)은 마지막 최종 더해지는 계산 과정을 나타낸다. U는 각 카디널과 split의 인덱스 관련해서 나타내고 이걸 가중치 'a'와 곱한다. a를 구하는 식이 식(3)이다. 

 그림과 연관하여 정리하자면, 결국 마지막에 우리는 채널별 값 1개씩을 input tensor의 채널별 요소에 모두 곱하고 최종적으로 이 tensor들을 다 더하는 것인데, 이 채널별 1개씩의 값을 갖는 1x1 xC tensor는 global average pooling을 거친 s를 mapping g를 거치고 나서 soft max를 거친 것이다. mapping g는 FC layer + BN + Relue + FC Layer로 구성된 듯하다.

a를 구할 때는 R=1일 때와 달리 (R= Split attention 수) 통과하는 게 여러 개인 경우 R>1일 때는 모두 합친 후 그것에 대한 해당 split가 가지는 값의 비율을 구하는 식이다. R=1일 때는 그림 1에서의 SE-NET이나 SK-NET을 뜻하는 것으로 생각된다. 논문에는 이와 관련된 내용이 없다.

 여기서 Dense는 fully connected layer이고, 이에 대한 설명은 3.1. Instantiation and Computational Costs에 언급된다. 

 

Figure 3. Split-Attention and corresponding calculations

(여담으로.. 그림 3은 논문 읽다 보니 헷갈려서 자연스럽게 그림과 매칭하게 되었고 이걸 블로그에 올릴 때 만든 것인데... 우연하게 구글링 중 어느 한 외국 블로그에서도 똑같은 방식으로 편집했더라..... 마치 내가 베낀 것처럼!!! 당황스럽네;;;;;;)

 

ResNeSt Block.

 ResNet과 헷갈린 작명이다.. 다시 그림 1을 참고하여 설명하면, cardinal group들은 결국 concat 되어 그대로 tensor들을 붙인다. 또한 ResNet의 아이디어이자 많은 레이어 및 모듈에 적용되고 있는 그 유명한 shortcut을 이 곳에서도 사용한다.

 만약 stride를 갖는 블록이라면 최종 output Y tensor를 구할 때, cardinal group들의 concat 결과인 V와 처음 Input을 shortcut 한 X를 더할 때 X를 적절하게 transformation 한다.

(shape 맞추는 작업이니 알아서 하면 될 듯. 논문에도 말했듯 예를 들어 strided convolution이나 pooling을 함께 사용하면 됨)

식으로 표현하면  그냥 Y = V + T(X), T=transformation

 

Instantiation and Computational Costs.

 parameters수와 FLOPS(연산량)이 standard residual block과 거의 같다고 한다.

 

Relation to Existing Attention Methods.

(1) SE-Net: squeeze 그리고 attention 아이디어가 적용된 SE-Net.. 해당 논문에서는 attention이 아닌 excitation이라는 용어 사용. channel-wise attention factor를 예측하는 글로벌 콘텍스트를 사용한다. radix=1을 가지고 이 논문에서 제안하는 Split-Attention block이 마찬가지로 SE-Block에도 사용될 수 있다. 

(2) SK-Net: 두 개의 네트워크 stream 사이에서 사용하는 feature attention을 제안한다. radix=2로 세팅하여 Split Attention block이 적용될 수 있다. 그림 1번에 나온 바와 같이 SE-Net SK-Net에 적용된 모습을 볼 수 있다.

 

 

 

2.2 Efficient Radix-major Implementation

 cardinality-major 구현은 간단하고 직관적이지만 표준 CNN 연산자를 사용하여 모듈화하고 가속화하기가 어렵다. 이를 위해 동등한 radix-major implementation을 제안함...

구현할 때 인덱스 관리에 있어서 효율적인 면을 말한다... 중요하진 않다...

앞에서는 단순히 cardinal의 index를 기준으로 했으나 합치는 과정에서 헷갈릴 수 있다 하여 cardinality를 중심으로 인덱스를 관리하겠다는 것...

 솔직히 논문에 새롭게 나온 Radix-major implementation의 레이어 구성이 앞선 그림 1과 달라서 이해가 안 가지만..

인덱스 관리와 함께 블록 내부의 cardinal(또는 split.. 논문에서 마음대로 바꾼다) 별로 fully connected(FC) layer를 연결한다. 필요 FC 레이어는 각각 따로 있어 cardinality 만큼 존재한다.(그림 1을 기준으로 하면... split r이 되겠다..)

 

 FC 대신 Convolution Layer를 사용하고, grouped convolution으로 구현된 것은 RK의 그룹 수만큼 존재한다....

그러므로, Split-Attention block은 기존 CNN에서 사용하는 연산들을 가지고 모듈화 할 수 있다.

 

 

3. Network and Training

패스. 이것도 제안하는 네트워크 디자인과 학습 전략이 있으나 일단 난 패스.

 

 

4. Transfer Learning Result - Segmentic Segmentation

논문의 6.3에 해당한다. 

segmentation의 전이 학습을 할 때는 논문에서는 GluonCV로 구현된 DeepLabV3을 baseline으로 깐다. 

일단 method를 DeepLabV3로 고정하고, backbone을 ResNet을 적용할 때와 ResNst를 적용해본 결과.. 그냥 좋았다.

(per-pixel cross entropy loss 사용. multi-scale evaluation 사용)

 

Citscapes validation set을 기준으로, 

ResNet101(mIoU=79.42) 인 반면,

제안하는 ResNSt-50(mIoU=79.87%) 과 ResNeSt-101(mIoU=80.42%) 를 보여준다.

 

 

 

5. 리뷰 마치며

실험과 관련된 부분이나 결론 등 분명 정리할 내용이 더 있지만 일부러 더 적지 않는다.

 

 

 

 

6. 코드 관련

공식 github? 로 보이는 곳 github.com/zhanghang1989/ResNeSt

영현님이 구현하신 github.com/YeongHyeon/ResNeSt-TF2

심플하기는 영현님이 좋아 참고했다.

 

공식 github의 경우 예제 코드 resnest.py 에서 정의된 resnest-50, 101, 200, 269 모두 cardinal=1, radix=2 만 되어있다.

실제 논문에도 classification 에서는 2s2x(radix=2 caridinality=2) 를 사용하고, 빠른 버전의 경우 4s1x를 사용.

빠른 버전의 경우 4s1x를 사용. 결국..radix와 cardinate를 분류하고 이걸 관리하고 어쩌고 하면서 다 해봤자 4개 뿐;;;

영현님은 4s4x를 만들어 놓았다. 그리고 영현님은 FC를 convolution으로 구현.

논문을 샅샅이 뒤지진 않았지만 ..따로 언급이 없는 것으로 보아..2s2x가 표준인듯...