How multi-task learning helped us create better AI

Motivation

Nature has long been both an inspiration and a benchmark for Artificial Intelligence. The classic test of intelligence in an AI system is its ability to pass for a human being.

Acquisition of intelligence is multi-modal in the natural world. For example, children learn to recognise shapes and learn their names in a single, simultaneous process. In nature, as in AI, separating these learning tasks leads to poorer outcomes. 

Applied AI has also borrowed ideas from nature to improve the systems we develop. One such paradigm is multi-task learning (MTL), in which we train Machine Learning (ML) models on multiple tasks simultaneously. MTL allows a model to learn a shared representation, capturing ideas that are common across all the tasks. In this article, we describe how we re-implemented one of our AI features using multi-task learning. 

Context: Ticket Classification 

Freshservice is a SaaS ITSM product with thousands of customers. Employees of our customers raise a ticket for their IT services on Freshservice. When an IT agent opens a ticket, she sees a screen similar to the screenshot below. The left side of the screen has details about the specific request, and the right side of the screen shows the properties of the ticket and the requester’s information. 

Our customers’ agents use the ticket properties of Priority, Urgency, Impact, Category, Sub-Category, Item-Category and Group to make important decisions about the processing of the ticket. For example, they might assign the ticket a “Category” of Hardware so that the ticket is routed to a specialised agent. Or, they might assign “High” priority to a ticket that needs urgent attention. 

In the first generation of our product, an agent would manually update each ticket field based on her understanding of the ticket. In a later release of the product, machine learning models predicted the values of these fields. There are 7 fields in Ticket Properties – priority, urgency, impact, category, sub-category, item-category, and group. 

We trained 7 independent fastText models on the previous 3 to 6 months of data and made 7 independent predictions for each of these fields. 
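To make this concrete, here is a minimal sketch of the per-field training data in the line-based format fastText's supervised mode expects. The helper name and the ticket text are illustrative, not our production pipeline:

```python
def to_fasttext_line(label: str, subject: str, description: str) -> str:
    """Format one ticket as a line for fastText supervised training.

    fastText expects one example per line, with the target class
    encoded as a `__label__` prefix in front of the text.
    """
    text = f"{subject} {description}".lower().replace("\n", " ")
    return f"__label__{label.replace(' ', '_')} {text}"

# Each of the 7 fields got its own training file, e.g. for "category":
line = to_fasttext_line("Hardware", "Laptop broken", "My screen flickers")
# __label__Hardware laptop broken my screen flickers
```

Seven such files meant seven separate models, none of which could share what it learned with the others.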

Multi-Task Learning

The latest version of Freshservice has models trained using a multi-task learning approach. As detailed in the Performance Evaluation section, this approach led to significant improvements over the fastText models used previously, outperforming the independent models on each of the 7 tasks. 

Caruana, in a 1997 paper, described multi-task learning (MTL) as a set of ideas that improves generalisation by leveraging the domain-specific information contained in the training signals of related tasks. It does this by training the tasks in parallel over a shared representation. In effect, the training signals of the extra tasks serve as an inductive bias. Inductive bias denotes the set of assumptions a learning algorithm makes about unseen data; every learning algorithm carries such a bias implicitly. For example, Occam's razor – prefer the simplest explanation that fits the data – is an inductive bias.

By supplying a useful inductive bias, MTL improves the generalisation of the learned model. Seen in this way, MTL is a regularisation method. Consider the tasks of categorising a ticket as “Hardware” or “Security” and classifying its urgency as “High” or “Medium”. In most organisations, a security ticket gets a higher urgency than a hardware one, so intuitively the two tasks must share some information. Treating the tasks independently, however, allows no information to flow between them. An MTL approach overcomes this problem by learning a “shared representation” of the data. 

Model Architecture & Training

Multi-task learning can be achieved by one of two methods – hard parameter sharing or soft parameter sharing. In hard parameter sharing, shown in the figure below, the hidden layers are shared across the tasks. This reduces the risk of over-fitting to any individual task, as the shared layers must learn a representation that captures information about all the tasks. Intuitively, the more tasks we have, the lower the chance of over-fitting to any one of them.

The second approach to MTL is soft parameter sharing, where each task has its own model and its own parameters. The parameters are, however, constrained to have similar values, typically by regularising the distance between them with the L2 norm or the trace norm. 
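As a minimal sketch (pure Python, with made-up weight values), the soft-sharing constraint can be expressed as an L2 penalty on the distance between two tasks' copies of the same layer, added to the joint training loss:

```python
def l2_distance_penalty(params_a, params_b, strength=0.01):
    """Soft parameter sharing: each task keeps its own copy of the
    weights, but an L2 penalty on their pairwise distance (added to
    the joint training loss) nudges the copies to stay similar."""
    return strength * sum((a - b) ** 2 for a, b in zip(params_a, params_b))

# Two task-specific copies of the "same" layer after some training:
task_a_weights = [0.50, -1.20, 0.30]
task_b_weights = [0.55, -1.00, 0.25]
penalty = l2_distance_penalty(task_a_weights, task_b_weights)
```

The `strength` hyper-parameter controls how tightly the tasks are coupled: at zero the tasks train fully independently, and as it grows the setup approaches hard sharing.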

For our project we used hard parameter sharing on a DistilBERT model. In this architecture, the DistilBERT layers are shared by all the downstream tasks and thus learn a representation that is general across all seven tasks. The task-specific layers are deliberately lightweight, consisting of a linear layer followed by a ReLU activation and a softmax layer. 
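The layout can be sketched as follows. This is a toy illustration, not the production model: a small stand-in vector replaces the pooled DistilBERT embedding, and the task names and class counts are made up.

```python
import math
import random

random.seed(0)

def linear(x, weights, bias):
    """Dense layer: weights is a (num_classes x input_dim) matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]

def relu(x):
    return [max(0.0, v) for v in x]

def softmax(x):
    m = max(x)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in x]
    s = sum(exps)
    return [e / s for e in exps]

# Shared encoder output (stand-in for the pooled DistilBERT embedding).
shared_features = [0.1, -0.4, 0.7, 0.2]

# One lightweight head per field; sizes here are illustrative only.
TASKS = {"category": 3, "priority": 4, "group": 5}
heads = {
    name: {
        "w": [[random.uniform(-1, 1) for _ in shared_features]
              for _ in range(n_classes)],
        "b": [0.0] * n_classes,
    }
    for name, n_classes in TASKS.items()
}

# Every head consumes the SAME shared features: that is hard sharing.
predictions = {
    name: softmax(relu(linear(shared_features, h["w"], h["b"])))
    for name, h in heads.items()
}
```

Because every head reads the same shared features, gradients from all seven tasks flow back into the shared encoder, which is what forces it to learn a task-general representation.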

Input to the model

We experimented with 2 strategies for feeding input data to the network, as shown in the figure below. In the first strategy we concatenated the subject and the description of a ticket and trained a single DistilBERT block. This strategy is simple to implement and reduces training time and cost, since the model has fewer parameters to learn. The alternative is to pass the subject and the description through two separate DistilBERT encoders and concatenate their outputs. This network is more complex, as the gradient is back-propagated through two DistilBERT encoders, and both training and inference are slower and costlier for option 2.  

We chose option 2 as our final model because it performed better on customer-facing metrics such as accuracy.
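The two input strategies can be sketched as follows; `encode` is a toy stand-in for a DistilBERT encoder (the real encoder returns a 768-dimensional embedding, not character statistics):

```python
def encode(text):
    """Toy stand-in for a DistilBERT encoder returning a fixed-size
    vector (here just crude character statistics, for illustration)."""
    return [len(text) / 100.0, text.count(" ") / 10.0]

subject = "VPN not connecting"
description = "I cannot reach the office network since this morning."

# Option 1: concatenate the text and run ONE shared encoder.
features_v1 = encode(subject + " " + description)

# Option 2: run TWO encoders and concatenate their output vectors,
# doubling the feature width fed to the task heads (and the cost).
features_v2 = encode(subject) + encode(description)
```

Option 2 doubles the feature width seen by the task heads and keeps the subject and description signals separate, at the price of a second encoder in both the forward and backward pass.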

Model Training

The overall multitask loss is the ordinary sum of the individual cross-entropy losses. We used AdamW to optimise the network parameters, with differential learning rates (LR) and weight decay for the task-specific layers and the DistilBERT layers: the task layers had an LR of 5e-2 with weight decay 0.1, and the DistilBERT layers an LR of 3e-5 with weight decay 0.05. Using a systematic hyper-parameter optimisation strategy, we determined that a batch size of 16 was the best for us. Because the data was highly imbalanced for many of our accounts, we weighted our cross-entropy loss with a class weight factor.
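The loss described above can be sketched in a few lines of plain Python. The probabilities, target indices, and class weights below are made-up values for three of the seven tasks:

```python
import math

def weighted_cross_entropy(probs, target, class_weights):
    """Cross entropy for one example, scaled by the target class
    weight to compensate for class imbalance."""
    return -class_weights[target] * math.log(probs[target])

def multitask_loss(per_task):
    """Plain sum of the per-task losses, as used for training."""
    return sum(weighted_cross_entropy(p, t, w) for p, t, w in per_task)

# (softmax probabilities, target index, class weights) per task:
batch = [
    ([0.7, 0.2, 0.1], 0, [1.0, 2.0, 2.0]),   # category
    ([0.1, 0.6, 0.3], 1, [1.5, 1.0, 1.0]),   # priority
    ([0.25, 0.75], 1, [3.0, 1.0]),           # group
]
loss = multitask_loss(batch)
```

Because the total is a plain sum, every task's gradient contributes to the shared layers on every step; the class weights up-weight rare classes within each task.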

We developed the training data pipeline and trained these models on the AWS managed service Amazon SageMaker. Training a model for one customer, with an average dataset size of about 100,000 records, took about 4-5 hours. We trained 15 accounts in parallel by allowing SageMaker to start up to 15 P2-type instances.

Performance Evaluation 

The performance of the MTL models was significantly better than that of the earlier fastText-based models. We present the results from the final model we released to our customers. The analysis is based on the improvements for 73 customers, randomly selected as a benchmark for our experiments. 

We evaluated the performance improvements using the macro-F1 score, computed on 3 months of data held out as a test set. The classifiers were trained on one year of ticket data for each account. The histograms (figure below) show the absolute improvement in macro F1 of MTL over fastText across the 73 benchmark customers. As the graphs show, all the fields have a right skew, demonstrating that most fields for most customers benefit from the new MTL models. For example, the graph for ‘Group’ shows that the macro F1 score improved with MTL for all but two customers. 

We further evaluated the aggregate improvements in F1 score across all accounts, summarised in the table below. Group prediction improved the most, with a median gain of 0.21 in F1 score. Category and Sub-Category improved by 0.17 and 0.12, respectively. Impact and Item-Category showed no improvement, with a median gain of 0; both fields are highly imbalanced, as most tickets carry Null values for them.
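For reference, a minimal sketch of the macro-F1 metric used above, on tiny made-up labels; it is shown only to make the metric precise, since macro averaging is what lets rare classes count as much as frequent ones:

```python
def macro_f1(y_true, y_pred):
    """Macro F1: per-class F1 scores averaged with equal weight,
    so rare classes count as much as frequent ones."""
    classes = sorted(set(y_true) | set(y_pred))
    f1s = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        f1s.append(2 * tp / (2 * tp + fp + fn) if tp else 0.0)
    return sum(f1s) / len(f1s)

y_true = ["hw", "hw", "sec", "sec", "net"]
y_pred = ["hw", "sec", "sec", "sec", "net"]
score = macro_f1(y_true, y_pred)
```

On imbalanced fields such as Impact, a model that predicts only the majority (often Null) class scores poorly on macro F1 even when its raw accuracy looks high, which is why we report this metric.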

Conclusion

Sebastian Ruder, a research scientist and blogger on machine learning, has written a great post on why MTL works at all. Some of the key benefits we realised using MTL are: 

  • Implicit Data Augmentation: All 7 tasks carry some amount of noise. When the model is forced to learn a joint representation, the task-specific noise averages out, so the final representation is relatively noise-free. This is similar to random subspace methods (e.g. Random Forest), which train an ensemble of models, each on a subspace of the features. 
  • Attention Focusing: With high-dimensional data, a single-task model can struggle to distinguish the important features from the unimportant ones. MTL acts as an attention mechanism, helping the model learn to focus on the features that matter. 
  • Eavesdropping: One task may learn to model some information better than the others. In MTL, each task can benefit from what the other tasks have learned, by eavesdropping on the shared representation. 
  • Representation Bias: MTL forces the model to learn representations that work for ALL the tasks. This helps the models generalise better and work well for novel data and tasks. 
  • Regularisation: Finally, MTL acts as a higher level regularisation technique by introducing an inductive bias. 

Multi-task learning produces a single general model that, in our case, outperforms the individual specialised models by a large margin.

Saurav Chakravorty, Director-Data Science at Freshworks, co-authored this post.

Cover image: Vignesh Rajan