Attention is not always all you need!

Transforming the landscape: Fourier transforms as a replacement for the attention mechanism.

Large language models have been the talk of the town for quite some time now. They allow us to extract deeper insights from text data and pull out meaningful excerpts based on our requirements.

Ever since the launch of such large language models, BERT has been quite efficient and reliable at capturing bidirectional context and delivering high-quality results. In the CRM domain, we use these language models for various tasks like sentiment classification, intent detection, phrase extraction, and out-of-office email detection. BERT models are fine-tuned with our custom CRM data explicitly for each of these tasks. To enhance the capabilities of these large language models, we first focus on CRM-domain pre-training to generate better contextual representations, and then test the resulting models on relevant downstream tasks. However, BERT models come with their own set of disadvantages related to cost, training time, and inference time due to the presence of self-attention layers.

In this blog, we show the advantages of replacing the self-attention layers in BERT with a non-parametric transformation, the Fourier transform, which achieves the same goal as attention layers: mixing tokens so that the embedding of a particular token is influenced by related tokens across the entire text. We study the ability of FNETs to overcome the drawbacks associated with BERT and compare their performance across different tasks.

BERT architecture

A basic transformer consists of an encoder to read the text input and a decoder to produce a prediction for the task. The transformer encoder is composed of multiple layers of self-attention and feed-forward neural networks that transform the input text into a set of contextualized word embeddings. Since BERT’s goal is to generate a language representation model, it only takes advantage of the encoder part. BERT is pre-trained in two versions: 

  • BERT BASE: 12-layer, 768-hidden-nodes, 12-attention-heads, 110M parameters
  • BERT LARGE: 24-layer, 1024-hidden-nodes, 16-attention-heads, 340M parameters

It is quite evident that one of the main disadvantages of the BERT model is the computational complexity that is associated with multiple self-attention layers. Time taken for pre-training or fine-tuning these models with hundreds of millions of parameters can range anywhere between a few hours and several days. This can be a challenge for organizations that need to quickly deploy a model and don’t have the computing power or time to train it for a long time. Additionally, inference time can be slow since BERT models require a lot of computation to produce predictions. Lighter versions like DistilBERT can alleviate this problem, but only to a certain extent.

Replacing attention layers in BERT with Fourier transform layers

The success of BERT in achieving state-of-the-art results across various NLP tasks is often attributed to attention layers. While token-wise weights learned through attention layers are essential for high-quality context, recent research suggests that similar results can be achieved through alternative mechanisms. For example, some studies have replaced attention weights with unparameterized Gaussian distributions or fixed, non-learnable positional patterns and achieved minimal performance degradation while retaining learnable cross-attention weights. Moreover, recent efforts to improve attention efficiency are based on sparsifying the attention matrix or replacing attention with other mechanisms, such as MLPs.

Although the standard attention mechanism has a memory bottleneck with respect to sequence length, efficient transformers with O(N log N) or even O(N) theoretical complexity, such as Longformer, ETC, and BigBird, have been developed. In the Long-Range Arena benchmark, Performer, Linear Transformer, Linformer, and Image Transformer (Local Attention) were found to be the fastest and had the lowest peak memory usage per device. Finally, the team at Google found that replacing attention layers with Fourier transform layers offered similar performance, reduced model size (no learnable parameters in the mixing layer), and simplicity. (Lee-Thorp, James, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon. "FNet: Mixing Tokens with Fourier Transforms." arXiv preprint arXiv:2105.03824 (2021). https://arxiv.org/pdf/2105.03824.pdf)

Enter Fourier nets

Fourier neural networks, or Fourier nets for short, are a class of neural networks that use the Fourier series as the basis of their architecture. These networks have gained popularity in recent years due to their ability to efficiently model complex, high-dimensional functions and their use in various applications such as image and audio processing.

Fourier nets use a Fourier series as the basis of their architecture. A Fourier series is a mathematical representation of a periodic function as a sum of sine and cosine functions of different frequencies. A discrete Fourier transform (DFT) is used to decompose this series into individual frequencies. In particular, for sentence embeddings, the DFT of a sequence of N tokens can be written as:

                    X_k = Σ_{n=0}^{N−1} x_n · e^(−2πi·nk/N),   k = 0, 1, …, N−1                    (1)

x_n is the nth input token of the sentence, and X_k is the kth transformed representation, a weighted sum over all the x_n tokens (the complex exponentials act as the weights). To compute (1), a fast Fourier transform (FFT) is used, which brings the time complexity down to O(N log N), as opposed to the quadratic complexity associated with self-attention layers.
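To make equation (1) concrete, here is a minimal sketch (assuming NumPy; the toy 1D sequence stands in for a single embedding dimension of a sentence) that checks the direct O(N²) sum against the O(N log N) FFT:

```python
import numpy as np

# Toy sequence: one embedding dimension of an N-token sentence.
x = np.random.randn(8)
N = len(x)

# Direct evaluation of equation (1): O(N^2) operations.
n = np.arange(N)
k = n.reshape(-1, 1)
X_naive = (x * np.exp(-2j * np.pi * k * n / N)).sum(axis=1)

# Fast Fourier transform: the same coefficients in O(N log N).
X_fft = np.fft.fft(x)

assert np.allclose(X_naive, X_fft)
```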

A standard BERT transformer encoder layer has a multi-head self-attention sublayer with h heads, which looks like:

                    Y = MultiHead(Q, K, V) = Concat(head_1, …, head_h)·W^O                    (2)

                    head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V)                    (3)

                    Attention(Q, K, V) = softmax(Q·K^T / √d_k)·V                    (4)

Y is the final multi-head attention output. Q, K, and V are different representations of the token embeddings in a sentence with dimension d_k, and W_i^Q, W_i^K, and W_i^V are learnable projection matrices. For a detailed explanation, see this paper.
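For contrast with the Fourier mixing introduced next, here is a minimal sketch of a single attention head from equations (3) and (4) (assuming PyTorch; dimension names are illustrative), showing where the quadratic cost in sequence length comes from:

```python
import torch

def attention_head(x, w_q, w_k, w_v):
    # x: (seq_len, d_model) token embeddings; w_q, w_k, w_v: (d_model, d_k) learnable projections.
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    d_k = q.shape[-1]
    # The (seq_len, seq_len) score matrix is the quadratic-in-sequence-length step.
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    return torch.softmax(scores, dim=-1) @ v
```

In multi-head attention (equation (2)), h such heads run in parallel and their outputs are concatenated and projected with W^O.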

The Google team replaced the self-attention sublayer of each transformer encoder layer with a Fourier transform sublayer. This applies a 2D Fourier transform to the embedding matrix, i.e., a 1D FFT along the sequence dimension and another along the hidden dimension, which yields a complex-valued result; only its real part is kept:

                    Y = ℜ( F_seq( F_hidden(x) ) )                    (5)

Y is the final transform, equivalent to the final multi-head attention output. F_seq and F_hidden are the 1D FFTs along the sequence and hidden dimensions, and ℜ extracts the real part of the 2D transform, meaning the feed-forward and output layers did not need to be modified to handle complex numbers. Equations (2) and (5) clearly show that FNETs perform the desired mixing of tokens, equivalent to the self-attention mechanism, but without the baggage of heavy matrix multiplications and a huge number of learnable parameters.
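A minimal sketch of the Fourier mixing sublayer in equation (5), assuming PyTorch; unlike the attention head above, it has no learnable parameters:

```python
import torch

def fourier_mixing(x):
    # x: (batch, seq_len, d_model) real-valued token embeddings.
    # 1D FFT along the hidden dimension, then along the sequence dimension,
    # keeping only the real part, as in Y = Re(F_seq(F_hidden(x))).
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
```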

Now, we show the performance of BERT vs FNETs across three different tasks that are specific to our CRM.

Tasks

1. MLM for domain pre-training:

The goal here is to randomly mask out 15% of the words in the input, replacing them with a [MASK] token, pass the sequence through the BERT attention-based encoder, and then predict only the masked words based on the context provided by the other, non-masked words in the sequence.

While this approach addresses the unidirectionality constraint, a downside is that it creates a mismatch between pre-training and fine-tuning, since the [MASK] token does not appear during fine-tuning. Hence, during masking, the following is done (a short sketch of this logic follows the list):

  • 80% of the tokens are actually replaced with the token [MASK].
  • 10% of the time, tokens are replaced with a random token.
  • 10% of the time, tokens are left unchanged.
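A minimal sketch of this 80/10/10 masking rule (assuming PyTorch; the mask token id and vocabulary size come from whichever tokenizer is used):

```python
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    input_ids = input_ids.clone()
    labels = input_ids.clone()

    # Select 15% of the positions as prediction targets.
    selected = torch.rand(input_ids.shape) < mlm_prob
    labels[~selected] = -100  # only selected positions contribute to the MLM loss

    # 80% of the selected tokens are replaced with [MASK].
    masked = selected & (torch.rand(input_ids.shape) < 0.8)
    input_ids[masked] = mask_token_id

    # Half of the remainder (10% overall) get a random token; the final 10% stay unchanged.
    randomized = selected & ~masked & (torch.rand(input_ids.shape) < 0.5)
    input_ids[randomized] = torch.randint(vocab_size, input_ids.shape)[randomized]

    return input_ids, labels
```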

We perform pre-training on sales conversation emails using the Masked LM pre-training task. The goal here is to allow the model to learn better representations for tokens in the sales conversation context. For this task, we use a corpus of 8 million emails spanning conversations across different products.

This pre-training was performed on a G5.24xlarge GPU instance with 384 GB memory, with a batch size of 32 for two epochs each. We compare the performance of the following models on the pre-training task:

  • BERT-BASE-MULTILINGUAL-UNCASED
  • FNET-BASE

These experiments are also performed with maximum token lengths of 64 and 128 to explore the convergence at multiple sentence lengths.

Validation loss measures how well a trained model generalizes to unseen data during the validation phase. It quantifies the error, or mismatch, between the model's predicted outputs and the expected outputs on the validation dataset. During training, the training dataset is used to update the model's parameters and optimize its performance; however, it is crucial to evaluate the model on data it has not seen before to assess its ability to generalize and make accurate predictions on new instances. The lower the validation loss, the better the model performs, because it indicates fewer prediction errors.

Perplexity measures how well a language model predicts a given sequence of words. Entropy measures the average amount of information or uncertainty associated with each word in a sequence. Perplexity, on the other hand, is the exponentiation of entropy and provides a more interpretable value. Perplexity can be computed using the following formula:

 

                    Perplexity = exp(cross-entropy)                    (6)

 

A lower perplexity indicates a better language model that is more confident and accurate in predicting the next word in a sequence. It reflects the model’s ability to capture the underlying patterns and structure of the training data.
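As a concrete example of equation (6), assuming the per-token cross-entropy losses (in nats) from a validation pass have been collected:

```python
import torch

# Hypothetical per-token cross-entropy losses from a validation pass.
token_losses = torch.tensor([2.31, 1.87, 3.02, 2.55])

cross_entropy = token_losses.mean()    # average negative log-likelihood per token
perplexity = torch.exp(cross_entropy)  # equation (6)
print(f"eval loss = {cross_entropy.item():.3f}, eval perplexity = {perplexity.item():.2f}")
```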

We observe that FNETs have shown a drastic reduction in training duration, with just a minimal difference in validation loss and eval perplexity. Also, the gap between these evaluation metrics reduces as we increase the token length.

2. Deal classification based on email sentiment
We tested with a training set of 54K emails and a test set of 18K emails. We compared the performance of the following models over the fine-tuning classification (won or lost) task:

  • BERT-BASE-UNCASED
  • FNET-BASE 

This fine-tuning was performed on a G4.12xlarge GPU instance with 100 GB memory, with a batch size of 32 for four epochs each.
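As a rough sketch of how the two classifiers can be set up for this won/lost task, assuming the Hugging Face transformers library and its public bert-base-uncased and google/fnet-base checkpoints (the actual fine-tuning pipeline and data are internal):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def load_classifier(checkpoint):
    # Same binary (won / lost) classification head on top of either encoder.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
    return tokenizer, model

bert_tokenizer, bert_model = load_classifier("bert-base-uncased")
fnet_tokenizer, fnet_model = load_classifier("google/fnet-base")
```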

 

With these results, we observe that FNETs have shown a drastic reduction in training duration, with just minimal difference in validation AUC. The training time gap is starker at higher token lengths. The average inference time is calculated by measuring the individual inference times for each email and taking the mean across 18K emails. We can see the same drastic difference at higher token lengths.
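A minimal sketch of that measurement, assuming a predict function and a list of emails (both names are illustrative):

```python
import time

def average_inference_time(predict, emails):
    # Time each email individually, then take the mean across the set.
    durations = []
    for email in emails:
        start = time.perf_counter()
        predict(email)
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations)
```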

3. Email phrase extraction
We tested with a training set of 200K emails and a test set of 58K emails. We compared the performance of the following models:

  • BERT-BASE-UNCASED
  • FNET-BASE 

This fine-tuning was performed on a G4.12xlarge GPU instance with 100 GB memory, with a batch size of 32 for three epochs each.

 

We used the Jaccard score to evaluate the quality of the phrases predicted by our model. The Jaccard score is the ratio of the number of tokens common to the two strings to the number of unique tokens across the two strings, i.e., |A ∩ B| / |A ∪ B| (a short code sketch follows the list), where:

  • Set A: the set of unique tokens in string 1
  • Set B: the set of unique tokens in string 2
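A minimal sketch of this metric, assuming simple whitespace tokenization:

```python
def jaccard_score(str1, str2):
    # Set A and Set B: the unique tokens of each string.
    a, b = set(str1.lower().split()), set(str2.lower().split())
    if not a and not b:
        return 1.0  # two empty phrases count as a perfect match
    return len(a & b) / len(a | b)

print(jaccard_score("pricing call next week", "call next week"))  # 0.75
```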

With these results, we observe that FNETs have shown a drastic reduction in training duration with just minimal difference in validation Jaccard scores. The training time gap is starker at higher token lengths. Another advantage that can be seen here is that at higher token lengths, BERT-based models throw OOM errors, while FNETs do not because they are lighter.

Can Fourier transforms remove noise in text just like they do in voice signals?

We also observed that the phrases extracted by FNETs were cleaner and more nuanced than the phrases from the BERT models. Fourier transforms decompose signals into their constituent frequencies, and keeping only the important frequencies is a standard way of denoising a signal. Our hypothesis was that Fourier transforms would similarly denoise the extracted phrases. Although we do not yet have a metric to measure this capability, we can see that the hypothesis holds true for some examples. Interestingly, for the last email, the fine-tuned FNET model doesn't extract any phrase at all, which is actually correct!

Summary

FNETs are lighter alternatives to computation-heavy BERT-based models: the self-attention layers are replaced by non-parametric Fourier transform layers. We demonstrated a drastic decrease in training and inference times with only a slight drop across different evaluation metrics. We also observed that the FNET model was close to 40% lighter than the BERT-based models, which makes it especially suitable for pipelines running on small devices. The need to optimize large language models, especially transformer-based ones, grows by the day, and FNETs are a possible alternative.

Aanchal Varma co-authored this piece. Aanchal was a senior data scientist at Freshworks. She’s experienced in solving complex problems and building scalable solutions in the fields of NLP, deep learning, language modeling, and machine learning.