Imbalanced Text Data Resampling 후 huggingface 🤗 로 학습하기

2021. 6. 30. 15:57nlp

반응형

Imbalanced Data가 있으면 이를 Resampling 해주어야 제대로 학습이 가능하다.

Resampling은 크게 (1) Undersampling (2) Oversampling으로 나눌 수 있다

예를 들어 label이 0인 데이터는 1,234개, label이 1인 데이터는 5,678개 있다면,

label 0은 전체 데이터의 17.8%, label 1은 전체 데이터의 82.14%이므로 Imbalanced Data이다.

이를 (1) Undersampling하면 크기가 작은 label 0을 기준으로 데이터의 크기를 통일한다. 

label 0도 1,234개, label 1도 1,234개로 통일하는 것이다. 

(2) Oversampling하면 크기가 큰 label 1을 기준으로 데이터의 크기를 통일한다.

label 0도 5,678개, label 1도 5,678개로 통일한다. 

 

이는 imblearn 이라는 library로 구현 가능하다. 

 

(1) Undersampling 후 huggingface Trainer에 알맞는 train_data 형식으로 만들기 

 

import torch
from imblearn.under_sampling import RandomUnderSampler
from transformers import Trainer 
...

def resampling_data(df):
    inputs = tokenizer(df[input에 해당하는 column].tolist(), return_tensors="pt", padding=True, truncation=True, max_length=256)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    # (# of samples, # of features) 크기의 array, dataframe, sparse matrix 모두 가능 
    # 여기선 그냥 array 
    x = [[input_id, mask] for input_id, mask in zip(input_ids, attention_mask)]
    y = df[label에 해당하는 column].tolist()
    return x, y
    
rus = RandomUnderSampler(random_state=42, replacement=True) 
x, y = resampling_data(train_df)
x_rus, y_rus = rus.fit_resample(x, y)

print('original dataset shape:', Counter(y))
print('Resample dataset shape', Counter(y_rus))

class ResampledDataset(torch.utils.data.Dataset): 
    def __init__(self, x_rus, y_rus):
        self.input_ids = []
        self.attention_mask = []
        for input_id, mask in x_rus:
            self.input_ids.append(input_id)
            self.attention_mask.append(mask)
        self.labels = y_rus

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx] ,'attention_mask': self.attention_mask[idx], 'labels':self.labels[idx]}
        
# huggingface Trainer에 전달할 수 있는 Dataset 
train_data = ResampledDataset(x_rus, y_rus)
   
...

trainer = Trainer(
   model=model,            
   args=training_args,            
   compute_metrics=compute_metrics, 
   train_dataset=train_data,       
   eval_dataset=eval_data      
   )

 

(2) Oversampling 후 huggingface Trainer에 알맞는 train_data 형식으로 만들기

 

import torch
from imblearn.over_sampling import RandomOverSampler
from transformers import Trainer 
...

def resampling_data(df):
    inputs = tokenizer(df[input에 해당하는 column].tolist(), return_tensors="pt", padding=True, truncation=True, max_length=256)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    # (# of samples, # of features) 크기의 array, dataframe, sparse matrix 모두 가능 
    # 여기선 그냥 array 
    x = [[input_id, mask] for input_id, mask in zip(input_ids, attention_mask)]
    y = df[label에 해당하는 column].tolist()
    return x, y

ros = RandomOverSampler(random_state=42)

x, y = resampling_data(train_df)
x_ros, y_ros = ros.fit_resample(x, y)

print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_ros))

class ResampledDataset(torch.utils.data.Dataset): 
    def __init__(self, x_rus, y_rus):
        self.input_ids = []
        self.attention_mask = []
        for input_id, mask in x_rus:
            self.input_ids.append(input_id)
            self.attention_mask.append(mask)
        self.labels = y_rus

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx] ,'attention_mask': self.attention_mask[idx], 'labels':self.labels[idx]}
        
# huggingface Trainer에 전달할 수 있는 Dataset 
train_data = ResampledDataset(x_ros, y_ros)
   
...

trainer = Trainer(
   model=model,            
   args=training_args,            
   compute_metrics=compute_metrics, 
   train_dataset=train_data,       
   eval_dataset=eval_data      
   )

반응형

'nlp' 카테고리의 다른 글

자연어처리 워크샵에 페이퍼 내기  (0) 2021.11.25
GPT-3 API 받았네??  (0) 2021.07.30
Transformer 정리  (0) 2021.02.25
임베딩Embedding 정리  (0) 2021.02.25
알파벳으로 한글 쓰기 0r2#rld7lxolNJ 6rLrlN ^^-7l  (0) 2020.10.27