Text Classification Experiment with TRANS-BLSTM (KLUE Dataset)
- I implemented the TRANS-BLSTM architecture reviewed in an earlier paper-review post and applied it to a classification task that assigns Korean news article titles to 7 categories, comparing its performance against other encoder variants.
❒ Code
1. Tokenize
- SentencePiece
: A unigram-based SentencePiece tokenizer was trained with a vocab size of 6000 and used for tokenization.
import pandas as pd
import sentencepiece as spm

# dump the training titles to a plain-text corpus file
with open('nlp.txt', 'w', encoding='utf8') as f:
    f.write('\n'.join(trainset['input']))

corpus = 'nlp.txt'
prefix = "nlp"
vocab_size = 6000
spm.SentencePieceTrainer.train(
    f"--input={corpus} --model_prefix={prefix} --vocab_size={vocab_size + 5}" +
    " --model_type=unigram" +  # unigram (default), bpe, char, word
    " --max_sentence_length=9999" +
    " --pad_id=0 --pad_piece=[PAD]" +
    " --unk_id=1 --unk_piece=[UNK]" +
    " --bos_id=2 --bos_piece=[BOS]" +
    " --eos_id=3 --eos_piece=[EOS]" +
    " --user_defined_symbols=[MASK]")

# inspect the learned vocabulary and load the trained tokenizer
vocab_list = pd.read_csv('nlp.vocab', sep='\t', header=None)
vocab_list.reset_index(inplace=True)
sp = spm.SentencePieceProcessor()
vocab_file = "nlp.model"
sp.load(vocab_file)
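For reference, a minimal sketch of how the trained tokenizer can be turned into model inputs. The fixed max_len of 32 and the use of [BOS] as the [CLS]-style first token are assumptions for illustration, not part of the original code.

import torch

def encode_title(text, max_len=32):
    pad_id = sp.piece_to_id('[PAD]')   # 0, as configured above
    bos_id = sp.piece_to_id('[BOS]')   # prepended to act as the [CLS]-style token (assumption)
    ids = [bos_id] + sp.encode_as_ids(text)
    ids = ids[:max_len] + [pad_id] * max(0, max_len - len(ids))
    input_ids = torch.tensor([ids])                  # (1, max_len)
    key_padding_mask = input_ids.eq(pad_id)          # True at [PAD]; format expected by nn.MultiheadAttention
    return input_ids, key_padding_mask

input_ids, mask = encode_title('뉴스 타이틀 예시 문장')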
2. Encoder
- Vanilla Transformer Encoder
: Supports both the originally proposed Post-LN scheme and the Pre-LN scheme that is reportedly more common in recent work.
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout):
        super(EncoderBlock, self).__init__()
        self.self_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, embed_dim))
        self.layernorm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.layernorm2 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, input, method, key_padding_mask):
        if method == 'pre_LN':
            # Pre-LN: LayerNorm is applied before each sub-layer
            norm_input = self.layernorm1(input)
            attn_output, _ = self.self_att(norm_input, norm_input, norm_input, key_padding_mask=key_padding_mask)
            attn_output = self.dropout1(attn_output)
            input2 = input + attn_output
            norm_input2 = self.layernorm2(input2)
            output = self.ffn(norm_input2)
            output = self.dropout2(output)
            return input2 + output
        elif method == 'post_LN':
            # Post-LN: LayerNorm is applied after each residual connection
            attn_output, _ = self.self_att(input, input, input, key_padding_mask=key_padding_mask)
            attn_output = self.dropout1(attn_output)
            input2 = self.layernorm1(input + attn_output)
            output = self.dropout2(self.ffn(input2))
            return self.layernorm2(input2 + output)
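A quick sanity check of the block, sketched with the sizes from the Result section (the dummy tensors are purely illustrative):

block = EncoderBlock(embed_dim=256, num_heads=4, ffn_dim=512, dropout=0.1)
x = torch.randn(2, 32, 256)                      # (batch, seq_len, embed_dim)
pad_mask = torch.zeros(2, 32, dtype=torch.bool)  # True would mark [PAD] positions
print(block(x, 'pre_LN', pad_mask).shape)        # torch.Size([2, 32, 256])
print(block(x, 'post_LN', pad_mask).shape)       # torch.Size([2, 32, 256])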
- Transformer Encoder + Bidirectional LSTM
: Implementation of the architecture proposed in the paper "TRANS-BLSTM: Transformer with Bidirectional LSTM for Language Understanding".
class LSTMEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout):
        super(LSTMEncoderBlock, self).__init__()
        self.self_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, embed_dim))
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(embed_dim * 2, embed_dim)  # project 2*embed_dim back to embed_dim
        self.layernorm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.layernorm2 = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, input, method, key_padding_mask):
        if method == 'ver_1':
            # variant 1: the BLSTM replaces the feed-forward sub-layer
            attn_output, _ = self.self_att(input, input, input, key_padding_mask=key_padding_mask)
            attn_output = self.dropout1(attn_output)
            output1 = self.layernorm1(input + attn_output)
            output2 = self.lstm(output1)
            output2 = self.linear(output2[0])
            return self.layernorm2(output1 + output2)
        elif method == 'ver_2':
            # variant 2: the BLSTM runs on the block input, in parallel with attention + FFN
            attn_output, _ = self.self_att(input, input, input, key_padding_mask=key_padding_mask)
            attn_output = self.dropout1(attn_output)
            output1 = self.layernorm1(input + attn_output)
            output2 = self.dropout2(self.ffn(output1))
            output3 = self.lstm(input)
            output3 = self.linear(output3[0])
            return self.layernorm2(output1 + output2 + output3)
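A small sketch showing why the extra linear layer is needed: with bidirectional=True the LSTM doubles the feature dimension, so self.linear projects it back to embed_dim before the residual add (sizes below are illustrative).

lstm_block = LSTMEncoderBlock(embed_dim=256, num_heads=4, ffn_dim=512, dropout=0.1)
x = torch.randn(2, 32, 256)
pad_mask = torch.zeros(2, 32, dtype=torch.bool)
lstm_out, _ = lstm_block.lstm(x)
print(lstm_out.shape)                            # torch.Size([2, 32, 512]) = (batch, seq, 2*embed_dim)
print(lstm_block(x, 'ver_1', pad_mask).shape)    # torch.Size([2, 32, 256])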
- Transformer Encoder + CNN
: Applies a CNN, rather than an LSTM, inside the encoder block.
class CNNEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout, kernel_size, padding):
        super(CNNEncoderBlock, self).__init__()
        self.self_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, embed_dim))
        self.layernorm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.layernorm2 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.linear = nn.Linear(embed_dim * 2, embed_dim)
        self.cnn_1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=kernel_size, padding=padding)

    def forward(self, input, key_padding_mask):
        attn_output, _ = self.self_att(input, input, input, key_padding_mask=key_padding_mask)
        attn_output = self.dropout1(attn_output)
        input1 = self.layernorm1(input + attn_output)
        # Conv1d expects (batch, channels, seq_len), so transpose before and after the convolution
        input2 = self.cnn_1(input1.transpose(-1, -2).contiguous()).transpose(-1, -2).contiguous()
        output = self.dropout2(input2)
        # residual connection from the attention output (input1), mirroring the other blocks
        return self.layernorm2(input1 + output)
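A quick check of the convolutional block with the same illustrative sizes; kernel_size=3 with padding=1 keeps the sequence length unchanged:

cnn_block = CNNEncoderBlock(embed_dim=256, num_heads=4, ffn_dim=512, dropout=0.1, kernel_size=3, padding=1)
x = torch.randn(2, 32, 256)
pad_mask = torch.zeros(2, 32, dtype=torch.bool)
print(cnn_block(x, pad_mask).shape)              # torch.Size([2, 32, 256])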
3. Final Modeling
- Positional Encoding
import math

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, maxlen: int = 128, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        # standard sinusoidal positional encoding, precomputed for maxlen positions
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0)
        self.register_buffer('pos_embedding', pos_embedding)
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_embedding):
        # slice the buffer to the actual sequence length before adding
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        # scale embeddings by sqrt(emb_size), as in the original Transformer
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
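A quick shape check (6005 here is simply the SentencePiece vocab configured above, i.e. 6000 plus the 5 special tokens):

tok_emb = TokenEmbedding(vocab_size=6005, emb_size=256)
pos_enc = PositionalEncoding(256)
tokens = torch.randint(0, 6005, (2, 32))
print(pos_enc(tok_emb(tokens)).shape)            # torch.Size([2, 32, 256])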
- Classifier (Decoder)
1) For TRANS-BLSTM, the final output is produced with a BiLSTM decoder, as described in the paper.
2) Otherwise, as in the Text Classification figure below, the first 'CLS' token, which is trained to aggregate the sequence representation, is used as the classifier input.
import copy

def clones(module, N):
    # stack N deep copies of an encoder block into a ModuleList
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class TransformerClassifier(nn.Module):
    def __init__(self, input_vocab_size, embed_dim, num_heads, ffn_dim, N, dropout, hid_dim, output_dim, max_len, method):
        super(TransformerClassifier, self).__init__()
        self.input_tok_emb = TokenEmbedding(input_vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)
        self.encoder_layers_lstm = clones(LSTMEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout), N)
        self.encoder_layers_cnn = clones(CNNEncoderBlock(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout, kernel_size=3, padding=1), N)
        self.encoder_layers = clones(EncoderBlock(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout), N)
        self.linear1 = nn.Linear(embed_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, output_dim)
        self.linear3 = nn.Linear(hid_dim * 2, output_dim)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(embed_dim, hid_dim, batch_first=True, bidirectional=True)
        self.method = method

    def forward(self, input, input_padding_mask):
        if self.method == 'lstm':
            # TRANS-BLSTM: LSTM encoder blocks followed by a BiLSTM decoder
            x = self.positional_encoding(self.input_tok_emb(input))
            for layer in self.encoder_layers_lstm:
                x = layer(x, 'ver_1', input_padding_mask)
            x = self.lstm(x)
            x = self.linear3(x[0][:, -1, :])  # last time step of the BiLSTM output
            return x
        elif self.method == 'cnn':
            # CNN encoder blocks, classification from the first ('CLS') position
            x = self.positional_encoding(self.input_tok_emb(input))
            for layer in self.encoder_layers_cnn:
                x = layer(x, input_padding_mask)
            x = x[:, 0, :]
            x = self.relu(self.linear1(x))
            x = self.linear2(x)
            return x
        else:
            # vanilla Transformer encoder (Post-LN), classification from the first ('CLS') position
            x = self.positional_encoding(self.input_tok_emb(input))
            for layer in self.encoder_layers:
                x = layer(x, 'post_LN', input_padding_mask)
            x = x[:, 0, :]
            x = self.relu(self.linear1(x))
            x = self.linear2(x)
            return x
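A minimal end-to-end sketch of how the classifier is used. The optimizer, learning rate, max_len of 32, and the dummy batch are illustrative assumptions; output_dim=7 corresponds to the 7 news categories and the other sizes follow the Result section.

model = TransformerClassifier(input_vocab_size=6005, embed_dim=256, num_heads=4,
                              ffn_dim=512, N=4, dropout=0.1, hid_dim=256,
                              output_dim=7, max_len=32, method='lstm')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

input_ids = torch.randint(1, 6005, (8, 32))      # dummy token ids (batch, seq_len)
input_ids[:, -5:] = 0                            # pretend the last 5 positions are [PAD]
pad_mask = input_ids.eq(0)                       # True at [PAD] positions (pad_id=0)
labels = torch.randint(0, 7, (8,))

logits = model(input_ids, pad_mask)              # (8, 7)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()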
❒ Result
Performance was compared across three encoder architectures: the vanilla Transformer, the Transformer encoder + LSTM, and the Transformer encoder + CNN.
Hyperparameters were fixed at batch size 256, embedding dimension 256, 4 attention heads, FFN dimension 512, and LSTM/CNN hidden dimension 256, and F1 scores were compared while increasing the number of encoder layers from 2 to 6 (scores measured with 5-fold cross-validation).
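A sketch of the scoring setup described above. This is illustrative only, not the original experiment script: the label column name and macro averaging are assumptions, and train_one_fold / predict_fold are hypothetical wrappers around the training step shown earlier.

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score

titles = np.array(trainset['input'])
labels = np.array(trainset['label'])             # assumed label column with the 7 category ids
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

scores = []
for train_idx, valid_idx in skf.split(titles, labels):
    # train_one_fold / predict_fold are hypothetical helpers, not part of the original code
    fold_model = train_one_fold(titles[train_idx], labels[train_idx])
    fold_preds = predict_fold(fold_model, titles[valid_idx])
    scores.append(f1_score(labels[valid_idx], fold_preds, average='macro'))
print(np.mean(scores))                           # mean F1 over the 5 folds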
The resulting scores are shown in the table below. The CNN-based encoder performed the worst. The TRANS-BLSTM model proposed by Amazon AWS AI achieved the highest score, but its performance dropped sharply once 5 or more layers were stacked. My guess is that, because the input sequences (news titles) in this dataset are short, running both a bidirectional LSTM and self-attention in every layer may actually distort the learned sentence representation during training.