Fine-tuning BERT model for arbitrarily long texts, Part 2

Date of publication: 2 years ago

Author: Michał Brzozowski

 

This is part 2 of our series about fine-tuning BERT:

  • if you want to read the first part, go to this link,
  • and if you want to use the code, go to our GitHub.

 

Fine-tuning the pre-trained BERT on longer texts

 

Now it is time to address the elephant in the room of the previous approach. We were lucky to find an already fine-tuned model for our IMDB dataset. More often, however, we are in a less fortunate situation: we have a labelled dataset and need to fine-tune the classifier ourselves. In this case, we take the supervised approach: we download a general pre-trained model, put a classification head on top of it and train it on our labelled data.

 

Three roads to follow

 

The procedure of fine-tuning a classifier model is described in detail in the amazing Hugging Face tutorial.

Let us briefly summarise the main steps. First, we tokenise our texts. Again, the standard approach is to truncate each text to 512 tokens. After this preprocessing stage, there are three ways to fine-tune the model:

  • with the Transformers Trainer,
  • in TensorFlow with Keras,
  • in native PyTorch.

 

We will follow the last way because it is the most explicit one and therefore the easiest to adjust to our needs. The main goal is to modify the procedure so that longer texts are not truncated.

 

The main idea

 

The problem of using and fine-tuning BERT for longer texts was discussed here. The main idea for solving it was described in a comment by Jacob Devlin, one of the authors of BERT.

Let us emphasise the following part of the comment:

 

So from BertModel’s perspective this is a 3×6 minibatch

 

It captures the crucial difference between what we did in the previous part, where we only applied the model, and what we need to do now.

Recall that to apply the fine-tuned classifier model to a single long text, we first tokenised the entire sequence, then split it into chunks, got the model prediction for each chunk and calculated the mean/max of these predictions. There was no problem with doing it sequentially, that is:

  • Feed the 1st chunk into the model, get the 1st prediction.
  • Feed the 2nd chunk into the model, get the 2nd prediction.
  • And so on…
  • Take the mean/max of these predictions and stop.
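To make the sequential procedure concrete, here is a minimal PyTorch sketch. It assumes the text has already been tokenised and split into chunks of equal length, and it uses a tiny hypothetical model, ToyChunkClassifier (masked mean of token embeddings followed by a linear layer and a sigmoid), as a stand-in for the fine-tuned BERT classifier, so the example runs without downloading any weights. The function predict_long_text is likewise an illustrative name, not part of the original code.

```python
import torch
from torch import nn


class ToyChunkClassifier(nn.Module):
    """Hypothetical stand-in for a fine-tuned BERT classifier:
    maps every chunk of token ids to a single probability."""

    def __init__(self, vocab_size=1000, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 1)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()            # (K, L, 1)
        summed = (self.embed(input_ids) * mask).sum(dim=1)     # (K, dim)
        mean = summed / mask.sum(dim=1).clamp(min=1.0)         # masked mean over tokens
        return torch.sigmoid(self.head(mean)).squeeze(-1)      # (K,) probabilities


def predict_long_text(model, chunks, pooling="mean"):
    """Sequential inference: feed the chunks one by one, then pool."""
    model.eval()
    preds = []
    with torch.no_grad():  # no training here, so no gradients are needed
        for input_ids, attention_mask in chunks:
            # each chunk is a 1-D tensor; add a batch dimension of size 1
            preds.append(model(input_ids.unsqueeze(0), attention_mask.unsqueeze(0)))
    preds = torch.cat(preds)                                    # (K,)
    return preds.mean() if pooling == "mean" else preds.max()


torch.manual_seed(0)
model = ToyChunkClassifier()
# three hypothetical chunks of 512 token ids each
chunks = [(torch.randint(0, 1000, (512,)), torch.ones(512, dtype=torch.long)) for _ in range(3)]
print(predict_long_text(model, chunks, "mean"), predict_long_text(model, chunks, "max"))
```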

 

However, training the model sequentially on each chunk leads to a myriad of problems and questions:

  • Put the 1st chunk of the 1st text into the model, calculate the loss between the prediction and the label…
  • What label?
  • We have only one binary label for the entire text… Then maybe run backpropagation? But when?
  • Should we really update the model weights after each chunk?

 

Instead, we must process all the chunks of a text at once by putting them into one mini-batch. This solves all our problems:

  • From the K chunks obtained for the 1st text, create 1 mini-batch and obtain K predictions.
  • Pool the predictions using the mean/max function to obtain a single prediction for the entire text.
  • Calculate the loss between this single prediction and the single label.
  • Run backpropagation with loss.backward(). Before doing so, make sure that all the tensor operations between the model output and the loss (including the pooling) are performed on tensors with attached gradients.
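The four steps above can be sketched as a single training step for one text. This is a minimal illustration under stated assumptions: the hypothetical ToyChunkClassifier again replaces BERT, the K chunks of the text are already stacked into (K, L) tensors, and the loss is binary cross-entropy on the pooled probability. The function name train_on_single_text is ours, chosen for illustration.

```python
import torch
from torch import nn


class ToyChunkClassifier(nn.Module):
    """Hypothetical stand-in for the BERT classifier: one probability per chunk."""

    def __init__(self, vocab_size=1000, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 1)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        mean = (self.embed(input_ids) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return torch.sigmoid(self.head(mean)).squeeze(-1)     # (K,)


def train_on_single_text(model, optimizer, input_ids, attention_mask, label, pooling="mean"):
    """input_ids, attention_mask: (K, L) tensors holding ALL K chunks of ONE text.
    label: the single 0/1 label of the whole text."""
    model.train()
    optimizer.zero_grad()
    chunk_probs = model(input_ids, attention_mask)   # K predictions from one mini-batch
    # pool with torch ops so the result stays attached to the computational graph
    text_prob = chunk_probs.mean() if pooling == "mean" else chunk_probs.max()
    target = torch.tensor(float(label))
    loss = nn.functional.binary_cross_entropy(text_prob, target)
    # gradients flow back through the pooling to the chunk predictions
    # (with max pooling, only the highest-scoring chunk receives a gradient)
    loss.backward()
    optimizer.step()   # one weight update per text, not per chunk
    return loss.item()


torch.manual_seed(0)
model = ToyChunkClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
input_ids = torch.randint(0, 1000, (3, 512))            # hypothetical text split into K=3 chunks
attention_mask = torch.ones(3, 512, dtype=torch.long)
losses = [train_on_single_text(model, optimizer, input_ids, attention_mask, label=1) for _ in range(20)]
print(losses[0], losses[-1])
```

Since all K chunks go through the model together, memory use grows with the length of the text; this is the price of having one label and one weight update per text.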

 

Usual fine-tuning

 

Now, we will sketch how to modify the procedure from the tutorial. The basic steps of fine-tuning are:

  • Tokenize the texts of the training set with truncation. Roughly speaking, the tokenized set is a dictionary with keys input_ids and attention_mask, whose values are tensors with exactly 512 tokens per text.
  • Create the DataLoader object with the selected batch_size, which lets us iterate over batches of data. In what follows, assume batch_size=N.
  • In the training loop for batch in train_dataloader, we get the object batch. It is again a dictionary with keys input_ids and attention_mask, but this time its values are stacked tensors of size N x 512.
  • Put each loaded batch into the model with outputs = model(**batch), calculate the loss with loss = outputs.loss and run backpropagation with loss.backward().
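Steps 2 and 3 can be checked without any model. In the hypothetical sketch below, ten pre-tokenised and truncated examples are stored as dictionaries of fixed-size tensors (random ids stand in for real tokens), and the default collate function of the DataLoader stacks them into tensors of shape N x 512. With a Hugging Face sequence-classification model, such a batch (including the labels key) would then be passed as model(**batch).

```python
import torch
from torch.utils.data import DataLoader

MAX_LEN = 512
# hypothetical tokenized training set: every text already truncated/padded to 512 tokens
examples = [
    {
        "input_ids": torch.randint(0, 1000, (MAX_LEN,)),
        "attention_mask": torch.ones(MAX_LEN, dtype=torch.long),
        "labels": torch.tensor(i % 2),
    }
    for i in range(10)
]

# batch_size=N=4; the default collate_fn stacks the dictionaries key by key
train_dataloader = DataLoader(examples, batch_size=4, shuffle=False)
shapes = [tuple(batch["input_ids"].shape) for batch in train_dataloader]
print(shapes)  # [(4, 512), (4, 512), (2, 512)]
```

The last batch is smaller because 10 is not divisible by 4; every batch is still a rectangular tensor, which is exactly what breaks once texts are split into a varying number of chunks.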

 

In what follows, we will describe how to change each stage.

 

Tokenization with splitting

 

Recall the function transform_single_text we used to tokenize a single text. Now we want to have a vectorized version of this to work with a list of texts:

GitHub: https://gist.github.com/MichalBrzozowski91/d6b0c0c170e7be3957dff65a1449ded3#file-transform_list_of_texts-py
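The gist above contains the actual implementation. As a rough, self-contained illustration of the idea (not the author's code), the sketch below works on already tokenised ids, for example the output of a Hugging Face tokenizer called with add_special_tokens=False and truncation=False. It splits the ids into chunks of chunk_size tokens with a given stride (a stride smaller than chunk_size gives overlapping chunks), adds [CLS] and [SEP], and pads every chunk to the same length. The defaults 101, 102 and 0 are the [CLS], [SEP] and [PAD] ids of the bert-base-uncased vocabulary; the function names and parameters are hypothetical.

```python
import torch


def split_into_chunks(token_ids, chunk_size=510, stride=510, cls_id=101, sep_id=102, pad_id=0):
    """Split one tokenized text (without special tokens) into chunks.
    Returns input_ids and attention_mask of shape (K, chunk_size + 2).
    Assumes 0 < stride <= chunk_size."""
    max_len = chunk_size + 2                      # room for [CLS] and [SEP]
    ids_rows, mask_rows = [], []
    start = 0
    while True:
        piece = token_ids[start:start + chunk_size]
        row = [cls_id] + piece + [sep_id]
        pad = max_len - len(row)                  # only the last chunk needs padding
        ids_rows.append(row + [pad_id] * pad)
        mask_rows.append([1] * len(row) + [0] * pad)
        if start + chunk_size >= len(token_ids):  # this chunk reached the end of the text
            break
        start += stride
    return torch.tensor(ids_rows), torch.tensor(mask_rows)


def transform_token_id_lists(list_of_token_ids, **kwargs):
    """Vectorised over texts: returns lists of tensors whose first
    dimension (the number of chunks K) differs from text to text."""
    all_ids, all_masks = [], []
    for token_ids in list_of_token_ids:
        ids, mask = split_into_chunks(token_ids, **kwargs)
        all_ids.append(ids)
        all_masks.append(mask)
    return {"input_ids": all_ids, "attention_mask": all_masks}


# toy example with tiny chunks: a 5-token "review" and a 25-token "review"
out = transform_token_id_lists([list(range(1000, 1005)), list(range(2000, 2025))],
                               chunk_size=10, stride=10)
print([tuple(t.shape) for t in out["input_ids"]])  # [(1, 12), (3, 12)]
```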

 

As always, it is instructive to get our hands dirty with an example. Let’s look at the result of this function for one short and one long review and compare it with the usual truncation approach:

GitHub: https://gist.github.com/MichalBrzozowski91/4d82bfdc3271f52c11e8600a496d6f8a#file-truncation_vs_splitting-ipynb

 

The key observation here is that our tokenization returns lists of tensors of different sizes: the number of chunks depends on the length of each text. Unfortunately, we cannot stack tensors of different sizes together, just as we cannot stack two vectors of different lengths into a rectangular matrix (as we all remember from kindergarten).

From now on, we must be very careful not to make what philosophers call the category mistake and what mortal programmers call the type error.
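The type error can be reproduced in a couple of lines. The shapes below are hypothetical: one review producing 1 chunk and another producing 3 chunks of 512 tokens.

```python
import torch

short_ids = torch.zeros(1, 512, dtype=torch.long)  # short review: 1 chunk
long_ids = torch.zeros(3, 512, dtype=torch.long)   # long review: 3 chunks

try:
    torch.stack([short_ids, long_ids])  # requires identical shapes, so this fails
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("torch.stack failed:", err)

# Keeping the tensors in a plain Python list is always possible...
batch = [short_ids, long_ids]
# ...and concatenating along the chunk dimension also works, since only dim 0 differs
all_chunks = torch.cat(batch, dim=0)
print(all_chunks.shape)  # torch.Size([4, 512])
```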

 

Creating the dataset and the dataloader

 

The next step is to put the tokenized texts into the torch Dataset object. We define it as follows:

GitHub: https://gist.github.com/MichalBrzozowski91/ece08785c958691fd8c6fbb2188ab2c5#file-tokenized_dataset-py
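For readers who prefer not to open the gist, here is a minimal sketch of such a Dataset under our own assumptions: each item is one text, returned as a tuple of its (K_i, L) input_ids tensor, its attention_mask tensor and a float label. The class name ChunkedTextDataset and the tuple layout are hypothetical and may differ from the author's version.

```python
import torch
from torch.utils.data import Dataset


class ChunkedTextDataset(Dataset):
    """Hypothetical sketch: item i is one text, i.e. its (K_i, L) input_ids
    and attention_mask tensors plus its label. K_i varies between texts."""

    def __init__(self, input_ids, attention_mask, labels):
        if not len(input_ids) == len(attention_mask) == len(labels):
            raise ValueError("input_ids, attention_mask and labels must have the same length")
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (self.input_ids[idx],
                self.attention_mask[idx],
                torch.tensor(float(self.labels[idx])))


ds = ChunkedTextDataset(
    input_ids=[torch.zeros(1, 512, dtype=torch.long), torch.zeros(3, 512, dtype=torch.long)],
    attention_mask=[torch.ones(1, 512, dtype=torch.long), torch.ones(3, 512, dtype=torch.long)],
    labels=[0, 1],
)
print(len(ds), ds[1][0].shape)  # 2 torch.Size([3, 512])
```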

 

Again, let us try it with our toy example of two reviews:

GitHub: https://gist.github.com/MichalBrzozowski91/d9bcfbed6b7aa1f8ec9eda850dd1cb82#file-datasets_and_dataloaders-ipynb

 

That is no good… It turns out that the default collate function of the torch DataLoader tries to stack the tensors of all texts in a batch, which fails for tensors of different sizes! After some googling, we found the following discussion of precisely that problem.
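The failure is easy to reproduce with two hypothetical pre-chunked texts (1 chunk and 3 chunks) passed to a DataLoader that uses its default collate_fn. With default settings, the collate step calls torch.stack on the per-text tensors, which raises a RuntimeError.

```python
import torch
from torch.utils.data import DataLoader

# two hypothetical texts: (input_ids, attention_mask, label) with 1 and 3 chunks
items = [
    (torch.zeros(1, 512, dtype=torch.long), torch.ones(1, 512, dtype=torch.long), torch.tensor(0.0)),
    (torch.zeros(3, 512, dtype=torch.long), torch.ones(3, 512, dtype=torch.long), torch.tensor(1.0)),
]

loader = DataLoader(items, batch_size=2)  # default collate_fn
try:
    next(iter(loader))
    default_collate_failed = False
except RuntimeError as err:
    # the default collate_fn stacks the (1, 512) and (3, 512) tensors and fails
    default_collate_failed = True
    print("default collate_fn failed:", err)
```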

 

Overriding the default dataloader

 

After analysing the linked discussion, we decided to override the default behaviour of the DataLoader by passing it a custom collate_fn function. Again, let’s look at the code:

GitHub: https://gist.github.com/MichalBrzozowski91/3b5642894d4d5889bafaa46e09d50210#file-collate_fn-ipynb

 

The custom function collate_fn_pooled_tokens simply makes the DataLoader return each batch as a list of (potentially differently sized) tensors and prevents it from trying to stack them.
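A minimal sketch in the same spirit, assuming each dataset item is an (input_ids, attention_mask, label) tuple: the chunk tensors are kept in plain lists, and only the scalar labels, which all have the same shape, are stacked. The name collate_as_lists is hypothetical; the author's collate_fn_pooled_tokens may organise the batch differently.

```python
import torch
from torch.utils.data import DataLoader

# two hypothetical texts with 1 and 3 chunks of 512 tokens
items = [
    (torch.zeros(1, 512, dtype=torch.long), torch.ones(1, 512, dtype=torch.long), torch.tensor(0.0)),
    (torch.zeros(3, 512, dtype=torch.long), torch.ones(3, 512, dtype=torch.long), torch.tensor(1.0)),
]


def collate_as_lists(batch):
    """Hypothetical collate_fn: keep per-text chunk tensors in lists instead of stacking."""
    ids = [item_ids for item_ids, _, _ in batch]
    masks = [item_mask for _, item_mask, _ in batch]
    labels = torch.stack([label for _, _, label in batch])  # scalars stack fine: shape (N,)
    return ids, masks, labels


loader = DataLoader(items, batch_size=2, collate_fn=collate_as_lists)
input_ids, attention_mask, labels = next(iter(loader))
print([tuple(t.shape) for t in input_ids])  # [(1, 512), (3, 512)]
```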

We are finally ready to look at the training loop.

 

Modifying the training loop

 

The standard torch training loop for the classifier model looks like this:

GitHub: https://gist.github.com/MichalBrzozowski91/19ddf78c1112b5b2ee75edb03c6b6051#file-train_single_epoch-py

 

where the crucial method _evaluate_single_batch is defined as:

GitHub: https://gist.github.com/MichalBrzozowski91/5b870014208bc4b7e8e7e83ecf4f7bc0#file-evaluate_single_batch_trauncated-py

 

Here, self.neural_network is the classifier model, which returns a single probability for each text.

To adapt the training loop to the situation where each batch is a list of tensors of different sizes, we made the following adjustments:
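For comparison, here is a minimal sketch of the truncated case, assuming each batch is a tuple of stacked (N, L) tensors plus N float labels. The hypothetical ToyChunkClassifier stands in for self.neural_network and returns one probability per row; train_one_epoch and evaluate_truncated_batch are illustrative names, not the functions from the gists.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyChunkClassifier(nn.Module):
    """Hypothetical stand-in for the BERT classifier: one probability per row."""

    def __init__(self, vocab_size=1000, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 1)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        mean = (self.embed(input_ids) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return torch.sigmoid(self.head(mean)).squeeze(-1)


def evaluate_truncated_batch(model, batch):
    """Every text is truncated to L tokens, so the batch is already stacked."""
    input_ids, attention_mask, labels = batch         # (N, L), (N, L), (N,)
    probs = model(input_ids, attention_mask)          # (N,) one probability per text
    return nn.functional.binary_cross_entropy(probs, labels)


def train_one_epoch(model, optimizer, dataloader):
    model.train()
    total = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        loss = evaluate_truncated_batch(model, batch)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(dataloader)                    # mean loss over batches


torch.manual_seed(0)
n_texts, max_len = 10, 512
dataset = TensorDataset(
    torch.randint(0, 1000, (n_texts, max_len)),       # hypothetical truncated token ids
    torch.ones(n_texts, max_len, dtype=torch.long),
    torch.randint(0, 2, (n_texts,)).float(),          # hypothetical 0/1 labels
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
model = ToyChunkClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
print(train_one_epoch(model, optimizer, loader))
```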

GitHub: https://gist.github.com/MichalBrzozowski91/728dfb28f33cb01f5291067ceaeef3fc#file-evaluate_single_batch_splitted-py

 

Some comments are in order:

  • During training, we basically do the same steps as during prediction. The crucial part is that all the operations of the type cat/stack/split/mean/max are done on tensors with attached gradients.
  • For that, we use built-in torch tensor operations. Converting the intermediate predictions to Python lists or NumPy arrays is not allowed: it cuts them off from the computational graph, and then the key backpropagation command loss.backward() won’t work.
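A minimal sketch of the adjusted step, again with the hypothetical ToyChunkClassifier in place of self.neural_network. It assumes that a batch arrives as two lists of (K_i, L) tensors plus a tensor of N labels, the layout a list-returning collate_fn produces. All chunks of all N texts are concatenated with torch.cat into a single forward pass, the predictions are cut back per text with torch.split, pooled with mean or max, and stacked with torch.stack. Every step is a torch operation, so the loss stays attached to the computational graph. This is one possible arrangement, not necessarily the one in the linked gist.

```python
import torch
from torch import nn


class ToyChunkClassifier(nn.Module):
    """Hypothetical stand-in for the BERT classifier: one probability per chunk."""

    def __init__(self, vocab_size=1000, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 1)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        mean = (self.embed(input_ids) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return torch.sigmoid(self.head(mean)).squeeze(-1)


def evaluate_pooled_batch(model, input_ids, attention_mask, labels, pooling="mean"):
    """input_ids / attention_mask: lists of (K_i, L) tensors, one per text;
    labels: (N,) float tensor. Returns the loss with gradients attached."""
    num_chunks = [t.shape[0] for t in input_ids]         # K_1, ..., K_N
    all_ids = torch.cat(input_ids, dim=0)                # (K_1 + ... + K_N, L)
    all_mask = torch.cat(attention_mask, dim=0)
    chunk_probs = model(all_ids, all_mask)               # one forward pass over all chunks
    per_text = torch.split(chunk_probs, num_chunks)      # N tensors of sizes K_i
    if pooling == "mean":
        text_probs = torch.stack([p.mean() for p in per_text])
    else:
        text_probs = torch.stack([p.max() for p in per_text])
    return nn.functional.binary_cross_entropy(text_probs, labels)


torch.manual_seed(0)
model = ToyChunkClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# hypothetical batch of N=2 texts with K_1=1 and K_2=3 chunks
input_ids = [torch.randint(0, 1000, (1, 512)), torch.randint(0, 1000, (3, 512))]
attention_mask = [torch.ones(1, 512, dtype=torch.long), torch.ones(3, 512, dtype=torch.long)]
labels = torch.tensor([0.0, 1.0])

optimizer.zero_grad()
loss = evaluate_pooled_batch(model, input_ids, attention_mask, labels)
loss.backward()   # works because the loss is still attached to the graph
optimizer.step()
print(loss.item())
```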

 

Conclusions

 

In this article, we’ve learnt how to extend BERT to inputs longer than 512 tokens, both when applying a fine-tuned model and when fine-tuning it. I invite you to check out our repository, where you can find all the code used in this tutorial.

If you have any questions, please reach me via my LinkedIn.
