Set up training logic

We now have all the ingredients for the machine learning model, so it is time to specify the training logic through which the model learns.

To recap, these are the components we have built so far:

  1. Raw data: we loaded the SMILE dataset

  2. Split data: we used stratified sampling so that the class distribution is the same in the training and validation sets

  3. Data for machine learning: we tokenized the data into X_train and X_val

  4. The model we are going to use: model = BertForSequenceClassification.from_pretrained

  5. Tools to feed data: dataloader_train = DataLoader

  6. Tools to train: optimizer = AdamW; scheduler = get_linear_schedule_with_warmup

  7. Tools to measure performance: f1_score_func; accuracy_per_class
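f1_score_func and accuracy_per_class come from an earlier step and are not shown in this section. For reference, here is a minimal sketch of what they might look like, assuming both take the raw logits returned by evaluate plus the integer labels, and assuming the F1 score is support-weighted across classes (which should give the same number as scikit-learn's f1_score with average='weighted'). The bodies and the optional label_names argument are hypothetical.

```python
import numpy as np


def f1_score_func(preds, labels):
    """Support-weighted F1 score computed from logits (hypothetical sketch)."""
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = np.asarray(labels).flatten()

    weighted_sum = 0.0
    for c in np.union1d(labels_flat, preds_flat):
        tp = np.sum((preds_flat == c) & (labels_flat == c))
        fp = np.sum((preds_flat == c) & (labels_flat != c))
        fn = np.sum((preds_flat != c) & (labels_flat == c))
        denom = 2 * tp + fp + fn
        # per-class F1 = 2TP / (2TP + FP + FN), guarded against an empty denominator
        f1_c = 2 * tp / denom if denom > 0 else 0.0
        # weight each class by its support (number of true examples of that class)
        weighted_sum += f1_c * np.sum(labels_flat == c)
    return float(weighted_sum / len(labels_flat))


def accuracy_per_class(preds, labels, label_names=None):
    """Print and return (correct, total) per true class; label_names maps id -> name."""
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = np.asarray(labels).flatten()

    results = {}
    for c in np.unique(labels_flat):
        mask = labels_flat == c
        correct = int(np.sum(preds_flat[mask] == c))
        total = int(np.sum(mask))
        name = label_names[int(c)] if label_names is not None else int(c)
        results[name] = (correct, total)
        print(f"Class: {name}  Accuracy: {correct}/{total}")
    return results


logits = np.array([[2.0, 1.0], [0.0, 3.0], [0.0, 1.0], [1.0, 0.0]])  # argmax -> [0, 1, 1, 0]
labels = np.array([0, 0, 1, 1])
print(f1_score_func(logits, labels))  # 0.5
accuracy_per_class(logits, labels)    # prints 1/2 for each class
```

Taking the argmax inside the helpers means evaluate can return raw logits unchanged, and the same arrays can be fed to both metrics.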

Check the environment so the model knows whether to use a GPU (through CUDA) or fall back to the CPU:

import torch

# use the GPU if CUDA is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(device)

Customize the evaluate function for logging:

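model.eval() puts the model into evaluation mode, which turns off dropout so that predictions are deterministic; model.train() switches it back to training mode, and the actual training happens in the training loop further down. Note that model.eval() does not stop PyTorch from tracking gradients, which is why evaluate also wraps the forward pass in torch.no_grad().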
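To see concretely what the mode switch changes, here is a small NumPy sketch of inverted dropout, the scheme torch.nn.Dropout uses: in training mode each activation is zeroed with probability p and the survivors are scaled by 1/(1-p); in evaluation mode the layer is the identity. The function name and signature are illustrative only, not part of PyTorch.

```python
import numpy as np


def dropout(x, p, training, rng):
    """Inverted dropout (illustrative sketch, not PyTorch's implementation)."""
    if not training or p == 0.0:
        # evaluation mode: activations pass through unchanged
        return x
    # training mode: keep each unit with probability 1 - p ...
    keep = rng.random(x.shape) >= p
    # ... and rescale the survivors so the expected activation is unchanged
    return x * keep / (1.0 - p)


rng = np.random.default_rng(0)
x = np.ones(8)
print(dropout(x, 0.5, training=True, rng=rng))   # a random mix of 0.0 and 2.0
print(dropout(x, 0.5, training=False, rng=rng))  # all ones, identical to x
```

For a model like BERT, whose only stochastic layers are dropout, removing this randomness means evaluating the same data twice gives the same predictions.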

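We write a custom evaluate function that runs the model over the validation set, computes the average validation loss, and collects the predictions and true labels so we can log the validation performance.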

import numpy as np
import torch
from tqdm import tqdm

def evaluate(dataloader_val):
    """
    Run the model over the validation set and return the average
    validation loss, the raw logits, and the true labels.
    Uses the global `model` and `device` defined earlier.
    """

    # switch to evaluation mode (disables dropout)
    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    # loop through the validation set, collecting losses and predictions
    for batch in tqdm(dataloader_val):
        
        # move every tensor in the batch to the same device as the model
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }

        # disable gradient calculation, because we only evaluate here
        with torch.no_grad():
            outputs = model(**inputs)
        
        # record each batch loss and aggregate
        loss = outputs[0]
        loss_val_total += loss.item()

        # record all predictions (logits) and true labels
        logits = outputs[1].detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    # average of the per-batch losses
    loss_val_avg = loss_val_total/len(dataloader_val)
    
    # stack the per-batch arrays into one array each
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals
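evaluate averages the per-batch losses, so every batch counts equally even when the last batch is smaller than the others. A framework-free sketch of the same aggregation makes the difference visible; the helper name, its list-of-batches inputs, and the weight_by_size option (which weights each batch loss by its number of examples) are hypothetical.

```python
import numpy as np


def aggregate_batches(batch_losses, batch_logits, batch_labels, weight_by_size=False):
    """Combine per-batch results the way evaluate does (hypothetical helper)."""
    # stack the per-batch arrays into one array each, as evaluate does
    predictions = np.concatenate(batch_logits, axis=0)
    true_vals = np.concatenate(batch_labels, axis=0)

    if weight_by_size:
        # weight each batch's mean loss by the number of examples in that batch
        sizes = np.array([len(labels) for labels in batch_labels])
        loss_avg = float(np.dot(batch_losses, sizes) / sizes.sum())
    else:
        # plain mean over batches, matching loss_val_total / len(dataloader_val)
        loss_avg = float(np.mean(batch_losses))
    return loss_avg, predictions, true_vals


# two batches: 2 examples with mean loss 1.0, then a last batch of 1 example with loss 0.4
losses = [1.0, 0.4]
logits = [np.zeros((2, 3)), np.ones((1, 3))]
labels = [np.array([0, 2]), np.array([1])]

print(aggregate_batches(losses, logits, labels)[0])                       # about 0.7
print(aggregate_batches(losses, logits, labels, weight_by_size=True)[0])  # about 0.8
```

With many batches of equal size the two averages are nearly identical. The weighted version equals the per-example mean loss only if each batch loss is itself a mean over its examples, which is the case for the default cross-entropy loss of BertForSequenceClassification.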

Training Loop:

tqdm: a progress bar library. tqdm derives from the Arabic word taqaddum (تقدّم) which can mean "progress," and is an abbreviation for "I love you so much" in Spanish (te quiero demasiado).

Logic of the loop:

  1. Iterate through all epochs

  2. Within each epoch, train on every batch

  3. At the end of each epoch, save a checkpoint and log the training and validation performance
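The three steps can be mirrored in a framework-free sketch. Here a toy logistic-regression model stands in for BERT and plain gradient descent stands in for AdamW (a deliberate simplification); linear_warmup_lr reproduces the shape of get_linear_schedule_with_warmup (linear warmup from 0, then linear decay to 0), and the gradient is clipped by its global norm, as torch.nn.utils.clip_grad_norm_ does. The data, function names, and hyperparameters are all hypothetical.

```python
import numpy as np


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear warmup from 0, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def train(X, y, epochs=3, batch_size=4, base_lr=0.5, max_norm=1.0, seed=0):
    rng = np.random.default_rng(seed)
    w, b = np.zeros(X.shape[1]), 0.0
    n_batches = int(np.ceil(len(X) / batch_size))
    total_steps = epochs * n_batches
    step, history = 0, []

    # 1. iterate through all epochs
    for epoch in range(1, epochs + 1):
        order = rng.permutation(len(X))
        loss_total = 0.0

        # 2. within the epoch, train on each batch
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X[idx], y[idx]

            # forward pass: sigmoid probabilities and mean binary cross-entropy
            p = 1.0 / (1.0 + np.exp(-(xb @ w + b)))
            loss = -np.mean(yb * np.log(p + 1e-12) + (1 - yb) * np.log(1 - p + 1e-12))
            loss_total += loss

            # backward pass: gradients of the mean loss
            grad_w = xb.T @ (p - yb) / len(idx)
            grad_b = np.mean(p - yb)

            # clip the global gradient norm to max_norm
            norm = np.sqrt(np.sum(grad_w ** 2) + grad_b ** 2)
            scale = min(1.0, max_norm / (norm + 1e-6))
            grad_w, grad_b = grad_w * scale, grad_b * scale

            # update with the scheduled learning rate, then advance the schedule
            lr = linear_warmup_lr(step, base_lr, 0, total_steps)
            w -= lr * grad_w
            b -= lr * grad_b
            step += 1

        # 3. log the training performance for this epoch
        loss_avg = loss_total / n_batches
        history.append(loss_avg)
        print(f"Epoch {epoch}: training loss {loss_avg:.4f}")

    return w, b, history


X = np.array([[-2.0, -1.0], [-1.0, -2.0], [-1.5, -1.5], [-2.0, -2.0],
              [2.0, 1.0], [1.0, 2.0], [1.5, 1.5], [2.0, 2.0]])
y = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0])
w, b, history = train(X, y)
```

The order inside the batch loop (forward, backward, clip, optimizer step, scheduler step) is the same order the real loop uses.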

import os

import torch
from tqdm import tqdm

# make sure the checkpoint folder exists before saving into it
os.makedirs("Models", exist_ok=True)

for epoch in tqdm(range(1, epochs+1)):
    # for each epoch, train on every instance, batch by batch,
    # then log the training and validation performance for that epoch

    # put the model in training mode (re-enables dropout)
    model.train()
    
    # running sum of the per-batch training losses
    loss_train_total = 0
    
    # progress bar over the batches of this epoch
    progress_bar = tqdm(dataloader_train, 
                        desc=f"Epoch {epoch}",
                        leave=False,
                        disable=False
                       )
    
    # similar to the evaluation function, but with gradient updates
    for batch in progress_bar:
    
        # clear gradients left over from the previous step
        model.zero_grad()
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {
            'input_ids': batch[0],
            'attention_mask': batch[1],
            'labels': batch[2]
        }
        
        outputs = model(**inputs)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()
        
        # clip the gradient norm to 1.0 to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # update the weights, then advance the learning-rate schedule
        optimizer.step()
        scheduler.step()
        
        # the loss is already averaged over the batch, so report it directly
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})
    
    # save a checkpoint for this epoch
    torch.save(model.state_dict(), f"Models/BERT_ft_epoch{epoch}.model")
    
    # write training information
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)
    
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(dataloader_val)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 score: {val_f1}')
