Classification with TabNet: Deep Dive

Antons Tocilins-Ruberts
Published in Ravelin Tech Blog
7 min read, Jan 14, 2022

Photo by Mika Baumeister on Unsplash

Tabular data is the bread and butter for training fraud detection algorithms at Ravelin. We extract transaction, identity, product, and network attributes (read this blog if you’re interested in our network features) and place them into a big table of features which can be easily used by different Machine Learning models for training and inference. Decision-tree-based models (e.g. Random Forest or XGBoost) are the go-to algorithms for dealing with tabular data because of their performance, interpretability, speed of training and robustness.

On the other hand, neural networks are considered state of the art in many fields and perform particularly well on large datasets with minimal feature engineering. Many of our clients have large transaction volumes and deep learning is a potential path towards improving model performance in terms of fraud detection.

In this blog, we’re going to take a deep dive into the neural network architecture called TabNet (Arik & Pfister (2019)), which was designed to be interpretable and to work well with tabular data. After explaining the key building blocks and ideas behind it, you’ll see how to implement it in TensorFlow and how it can be applied to a fraud detection dataset. Most of the code was taken from this implementation, so make sure to leave a star there if you use it in your work!

TabNet

TabNet mimics the behaviour of decision trees using the idea of Sequential Attention. Simplistically speaking, you can think of it as a multi-step neural network that applies two key operations at each step:

  1. An Attentive Transformer selects the most important features to process at the next step
  2. A Feature Transformer processes the features into a more useful representation

The output of the Feature Transformer is later used in the prediction. Using both the Attentive and Feature Transformers, TabNet is able to simulate the decision-making process of tree-based models. Consider the high-level overview of the model’s prediction on the Adult Census Income dataset below. The model is able to select and process the features that are most useful for the task at hand, which improves interpretability and learning.

TabNet prediction intuition on Income dataset. Source: https://arxiv.org/pdf/1908.07442.pdf

The key building blocks of both the Attentive and Feature Transformers are the so-called Feature Blocks, so let’s explore them first.

Feature Blocks

Feature Blocks consist of a Fully-Connected (FC, or Dense) layer followed by Batch Normalisation (BN). In addition, in Feature Transformers the output is passed through a GLU (Gated Linear Unit) activation layer.

TabNet block with GLU

The main purpose of the GLU (as opposed to a plain sigmoid gate) is to let hidden units propagate deeper into the model and to help prevent exploding or vanishing gradients.
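To make the gating concrete, here is a minimal, framework-free sketch of a GLU. It assumes the common formulation in which the input is split into two halves, `a` and `b`, and the output is `a * sigmoid(b)`; the function names are illustrative and are not taken from the implementation referenced in this post.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Logistic function, squashes z into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def glu(x: List[float]) -> List[float]:
    """Gated Linear Unit: split x into halves (a, b) and return a * sigmoid(b)."""
    if len(x) % 2 != 0:
        raise ValueError("GLU needs an even number of inputs")
    half = len(x) // 2
    a, b = x[:half], x[half:]
    # 'a' is the linear path, 'sigmoid(b)' is the gate that scales it
    return [ai * sigmoid(bi) for ai, bi in zip(a, b)]


# a = [2.0, -1.0], b = [0.0, 10.0]
glu_out = glu([2.0, -1.0, 0.0, 10.0])
print(glu_out)
```

Because the first half passes through without a squashing non-linearity, an open gate (sigmoid close to 1) lets the value through almost unchanged, while a half-open gate (sigmoid of 0) halves it. Note that the output has half as many units as the input, which is why the FC layer in front of a GLU is usually twice as wide as the desired output.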

In addition, the original paper uses Ghost Batch Normalisation to improve the convergence speed during training. You can find a TensorFlow implementation here if you are interested, but we’ll use the default Batch Normalisation layer in this tutorial.

Feature Transformers

A Feature Transformer (FT) is basically a collection of feature blocks applied sequentially. In the paper, one Feature Transformer consists of two shared blocks (i.e. their weights are reused across steps) and two step-dependent blocks. Shared weights reduce the number of parameters in the model and lead to better generalisation.

Feature Transformers

With the implementation of Feature Block in the previous section in mind, here’s how you would use it to build a Feature Transformer.
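As a rough, framework-free sketch of that composition (not the TensorFlow code this post is based on), the example below chains two shared blocks and two step-dependent blocks. The `make_toy_block` helper is a hypothetical stand-in for a trained FC → BN → GLU block, and the residual connections scaled by sqrt(0.5) follow the description in the TabNet paper.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[Vector], Vector]


def make_toy_block(scale: float) -> Block:
    """Hypothetical stand-in for a trained FC -> BN -> GLU feature block."""
    def block(x: Vector) -> Vector:
        return [math.tanh(scale * v) for v in x]
    return block


class FeatureTransformer:
    """Applies the shared blocks, then the step-dependent blocks, in sequence."""

    def __init__(self, shared_blocks: Sequence[Block], step_blocks: Sequence[Block]):
        self.shared_blocks = list(shared_blocks)  # same block objects in every step
        self.step_blocks = list(step_blocks)      # owned by this step only

    def __call__(self, x: Vector) -> Vector:
        blocks = self.shared_blocks + self.step_blocks
        h = blocks[0](x)  # the first block has no residual connection
        for block in blocks[1:]:
            out = block(h)
            # Residual connection, scaled by sqrt(0.5) to keep the variance stable
            h = [(a + b) * math.sqrt(0.5) for a, b in zip(h, out)]
        return h


# Two decision steps reuse the same shared blocks but own their step blocks
shared = [make_toy_block(1.0), make_toy_block(0.5)]
ft_step_1 = FeatureTransformer(shared, [make_toy_block(2.0), make_toy_block(0.3)])
ft_step_2 = FeatureTransformer(shared, [make_toy_block(0.7), make_toy_block(1.5)])

x = [0.2, -1.0, 3.0]
print(ft_step_1(x))
print(ft_step_2(x))
```

In a real model the blocks would be trainable layers, so passing the same shared block objects to every step is what makes their weights shared.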

Attentive Transformer

The Attentive Transformer (AT) is responsible for feature selection at each step. The feature selection is done by applying a sparsemax activation (instead of GLU) while taking into account the prior scales. A prior scale lets us control how frequently a feature can be selected by the model and is determined by how much the feature was used in the previous steps (more about it later).

Attentive Transformers

The output of the previous Attentive Transformer feeds into the prior scales, passing along information about which features were used in the previous step. Similar to the Feature Transformer, the Attentive Transformer can be implemented as a TensorFlow model which will later be integrated into a larger architecture.
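The selection step itself can be sketched without any framework. The snippet below implements sparsemax using the standard sort-and-threshold algorithm and applies it to logits multiplied by the prior scales, mirroring the paper's M[i] = sparsemax(P[i−1] · h(a[i−1])). The `attentive_mask` helper and the example numbers are illustrative assumptions; in the real model the logits come from the FC and BN layers.

```python
from typing import List


def sparsemax(z: List[float]) -> List[float]:
    """Project z onto the probability simplex; unlike softmax, entries can be exactly 0."""
    z_sorted = sorted(z, reverse=True)
    cumsum = 0.0
    tau = 0.0
    for k, zk in enumerate(z_sorted, start=1):
        cumsum += zk
        # The k largest entries stay in the support while this condition holds
        if 1.0 + k * zk > cumsum:
            tau = (cumsum - 1.0) / k  # threshold for a support of size k
    return [max(zi - tau, 0.0) for zi in z]


def attentive_mask(logits: List[float], prior_scales: List[float]) -> List[float]:
    """Hypothetical helper: scale the logits by the prior scales, then apply sparsemax."""
    return sparsemax([p * l for p, l in zip(prior_scales, logits)])


logits = [3.0, 2.5, 0.5]
mask_fresh = attentive_mask(logits, [1.0, 1.0, 1.0])  # no feature used yet
mask_reuse = attentive_mask(logits, [0.1, 1.0, 1.0])  # feature 0 was used heavily before
print(mask_fresh)
print(mask_reuse)
```

With uniform prior scales the mask concentrates on the first two features and zeroes out the third. Once the prior scale of the first feature drops to 0.1, the mask moves entirely to the second feature.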

Since Feature and Attentive Transformer blocks can become quite parameter-heavy, TabNet uses a few mechanisms to control the complexity and prevent overfitting.

Regularisation

Prior Scales Calculation

Prior scales (P) allow us to control how frequently a feature can be selected by the model. They are calculated from the previous Attentive Transformer activations (the masks M) and the relaxation factor (γ). Here’s the formula presented in the paper:

P[i] = ∏_{j=1..i} (γ − M[j])

This equation shows how the prior scales get updated: the prior scale at step i is a product over all the steps up to and including step i. Intuitively, if a feature was used heavily in the previous steps, its prior scale shrinks and the model pays more attention to the remaining features, which helps reduce overfitting.

For example, with γ=1, a feature with a large mask activation (e.g. 0.9) will have a small prior scale (1 − 0.9 = 0.1). A small prior scale makes it unlikely that the feature gets selected again in the next step.
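The update can be written as a one-line running product. The sketch below is a plain-Python illustration with made-up masks; `update_prior` is a hypothetical helper name, and the prior scales are assumed to start at 1 for every feature.

```python
from typing import List


def update_prior(prior: List[float], mask: List[float], gamma: float) -> List[float]:
    """One step of the running product: P[i] = P[i-1] * (gamma - M[i])."""
    return [p * (gamma - m) for p, m in zip(prior, mask)]


prior_0 = [1.0, 1.0, 1.0]   # before the first step every feature is equally available
mask_1 = [0.9, 0.1, 0.0]    # made-up mask from step 1
mask_2 = [0.0, 0.5, 0.5]    # made-up mask from step 2

# gamma = 1: a feature that was used is strongly suppressed afterwards
strict_1 = update_prior(prior_0, mask_1, gamma=1.0)
strict_2 = update_prior(strict_1, mask_2, gamma=1.0)

# gamma = 1.5: the same feature can still be picked again
relaxed_1 = update_prior(prior_0, mask_1, gamma=1.5)

print(strict_1, strict_2, relaxed_1)
```

With γ=1 the first feature drops to a prior scale of roughly 0.1 after step 1, whereas with γ=1.5 it keeps a scale of roughly 0.6, so larger relaxation factors make feature reuse across steps easier.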

Sparsity regularisation

A sparsity regularisation term in the loss is used to encourage the attention masks to be sparse. The entropy of the mask activations, scaled by the hyperparameter λ, is added to the overall model loss.
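As a small illustration, the entropy term can be computed as below. The sketch follows the form used in the paper, where −M · log(M + ε) is summed over features and averaged over decision steps and samples; the nested-list layout `masks[step][sample][feature]`, the task loss value and the chosen λ are assumptions made for this example.

```python
import math
from typing import List


def sparsity_loss(masks: List[List[List[float]]], eps: float = 1e-15) -> float:
    """Entropy of the masks, averaged over steps and samples.

    masks is indexed as masks[step][sample][feature].
    """
    n_steps = len(masks)
    batch_size = len(masks[0])
    total = 0.0
    for step_masks in masks:
        for row in step_masks:
            total += sum(-m * math.log(m + eps) for m in row)
    return total / (n_steps * batch_size)


one_hot = [[[1.0, 0.0, 0.0, 0.0]]]      # fully sparse mask: entropy close to 0
uniform = [[[0.25, 0.25, 0.25, 0.25]]]  # dense mask: entropy log(4)

sparse_penalty = sparsity_loss(one_hot)
dense_penalty = sparsity_loss(uniform)

# The regulariser is added to the task loss, weighted by lambda
task_loss = 0.35        # made-up value for illustration
lambda_sparse = 1e-3    # made-up sparsity coefficient
total_loss = task_loss + lambda_sparse * dense_penalty
print(sparse_penalty, dense_penalty, total_loss)
```

A mask that puts all its weight on one feature is barely penalised, while a mask spread evenly over all features receives the maximum penalty, which is what pushes the model towards sparse feature selection.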

Next let’s find out how these components are used to build a TabNet model.

TabNet Architecture

Putting It All Together

The main idea behind TabNet is that the Feature and Attentive Transformer components are applied sequentially, which allows the model to mimic the decision-making process of a decision tree. The Attentive Transformer performs feature selection and the Feature Transformer performs transformations that allow the model to learn complex patterns in the data. Below you can see a diagram that summarises the data flow for a 2-step TabNet model.

2-step TabNet Architecture

To begin with, we pass initial input features through the Feature Transformer to get our initial feature representations. The output of this Feature Transformer is then used as an input to the Attentive Transformer which selects a subset of features to pass to the next step. This process is repeated for the required number of steps. You can see a TensorFlow implementation using the classes defined above in this code snippet.

The model produces final predictions by using the Feature Transformer outputs from each decision step. In addition, we can aggregate the attention masks at each step to understand which features were used to make a prediction. These masks can be used to obtain local feature importances as well as global importances.
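One simple way to turn the masks into importances is sketched below. It is a deliberate simplification: the masks are summed over steps with equal weight and normalised per sample, whereas the paper additionally weights each step's mask by how much that step contributes to the decision. The function names and the `masks[step][sample][feature]` layout are assumptions for this example.

```python
from typing import List


def local_importance(masks: List[List[List[float]]]) -> List[List[float]]:
    """Per-sample importances: sum the masks over steps, then normalise to 1.

    masks is indexed as masks[step][sample][feature].
    """
    n_samples = len(masks[0])
    n_features = len(masks[0][0])
    result = []
    for s in range(n_samples):
        summed = [sum(step[s][f] for step in masks) for f in range(n_features)]
        total = sum(summed)
        result.append([v / total for v in summed])
    return result


def global_importance(local: List[List[float]]) -> List[float]:
    """Dataset-level importances: the average of the per-sample importances."""
    return [sum(column) / len(local) for column in zip(*local)]


# Made-up masks for 2 steps, 2 samples and 3 features
masks = [
    [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0]],  # step 1
    [[0.0, 1.0, 0.0], [0.5, 0.0, 0.5]],  # step 2
]
local = local_importance(masks)
overall = global_importance(local)
print(local)
print(overall)
```

Here the first sample relies equally on the first two features, the second sample leans on the first feature, and averaging the two gives the global picture.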

Congratulations! Now you understand what TabNet does. Let’s see how to train this model on the example fraud detection dataset from Kaggle.

Fraud Detection with TabNet

Data

The data that we’ll use in this blog can be found in this Kaggle competition. Not all the code will be shown below, so please make sure to check this notebook for a full working example. The notebook pre-processes the data, trains a TabNet model and evaluates it. You can do a lot more feature engineering, parameter-tuning and pre-processing, so consider it a starting point.

The training dataset is quite large with ~590k observations and 420 features. The pre-processing performed is very basic since it’s not the goal of this blog. Here’s a small summary of steps which you’ll find in the notebook:

  • Non-informative columns are dropped
  • Missing values are imputed
  • Categorical variables are encoded
  • The data is split into train and validation sets based on time

Make sure to follow the code and instructions in the notebook to get the data prepared for TabNet. The resulting inputs into the model should be tf.data Dataset objects with all-numeric features and a binary target.
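The time-based split from the list above is worth spelling out, because a random split would let the model train on transactions that happen after the ones it is validated on. The sketch below is a minimal plain-Python illustration, not the notebook's code; the `timestamp` key, the toy rows and the 80/20 ratio are assumptions.

```python
from typing import Dict, List, Tuple

Row = Dict[str, float]


def time_based_split(
    rows: List[Row], time_key: str, train_fraction: float = 0.8
) -> Tuple[List[Row], List[Row]]:
    """Train on the earliest rows and validate on the most recent ones."""
    ordered = sorted(rows, key=lambda row: row[time_key])
    cut = int(len(ordered) * train_fraction)
    return ordered[:cut], ordered[cut:]


# Toy rows in shuffled order; 'timestamp' is a hypothetical column name
toy_rows = [
    {"timestamp": t, "amount": 10.0 * t}
    for t in [7, 2, 9, 4, 1, 8, 3, 10, 6, 5]
]
train_rows, valid_rows = time_based_split(toy_rows, "timestamp")
print(len(train_rows), len(valid_rows))
```

Every validation row is more recent than every training row, which mimics how a fraud model is used in production: it is trained on the past and scored on the future.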

Hyperparameter Tuning

TabNet (like any neural network) is very sensitive to its hyperparameters, so tuning is essential to get a good model. Here are the variables (and suggested ranges) that we found to have the largest impact on the model’s performance:

  • Feature dimensions/output dimensions: from 32 to 512 (we usually set these parameters equal to each other, as suggested by the paper)
  • Number of steps: from 2 (simple model) to 9 (very complex model)
  • Relaxation factor: from 1 (forces each feature to be used at only 1 step) to 3 (relaxes the restriction)
  • Sparsity coefficient: from 0 (no regularisation) to 0.1 (hard regularisation)

A simple example of HP tuning is also given in the notebook.
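For a sense of what a search space over these ranges could look like, here is a small random-search sampler in plain Python. It is only a sketch and not the tuning code from the notebook: the parameter names are hypothetical, the dimensions are restricted to powers of two, and the sparsity coefficient is drawn on a log scale between 1e-6 and 0.1, which is an assumption (a value of exactly 0 would have to be added as a separate option).

```python
import random
from typing import Dict, Union

Params = Dict[str, Union[int, float]]


def sample_tabnet_params(rng: random.Random) -> Params:
    """Draw one random configuration from the ranges suggested above."""
    dim = rng.choice([32, 64, 128, 256, 512])
    return {
        "feature_dim": dim,                  # kept equal to output_dim
        "output_dim": dim,
        "num_steps": rng.randint(2, 9),      # both ends inclusive
        "relaxation_factor": rng.uniform(1.0, 3.0),
        "sparsity_coefficient": 10 ** rng.uniform(-6.0, -1.0),  # log-uniform
    }


rng = random.Random(42)  # fixed seed so the trials are reproducible
trials = [sample_tabnet_params(rng) for _ in range(20)]
for params in trials[:3]:
    print(params)
```

Each sampled configuration would then be used to build and train a model, keeping the one with the best validation score.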

Training and Evaluation

Training TabNet is similar to any other Keras model. You can further improve the performance by experimenting with learning rate scheduling and decay.
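As an illustration of learning rate decay, the sketch below computes an exponentially decayed learning rate in plain Python. The formula is the same one used by Keras's ExponentialDecay schedule; the initial rate, decay rate and decay steps are made-up values and not settings taken from the notebook.

```python
def exponential_decay(
    step: int,
    initial_lr: float = 0.02,
    decay_rate: float = 0.9,
    decay_steps: int = 1000,
    staircase: bool = False,
) -> float:
    """Learning rate after 'step' optimiser steps: lr0 * decay_rate ** (step / decay_steps)."""
    exponent = step / decay_steps
    if staircase:
        exponent = step // decay_steps  # decay in discrete jumps instead of smoothly
    return initial_lr * decay_rate ** exponent


for step in (0, 500, 1000, 5000):
    print(step, exponential_decay(step))
```

With these values the learning rate shrinks by 10% every 1000 steps, which lets the model take large steps early on and smaller, more careful ones later in training.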

We usually evaluate our models using the ROC AUC and PR AUC scores (because the target is imbalanced), so here are the metrics on the validation set for this model.

Test ROC AUC 0.8505
Test PR AUC 0.464

Not bad, but not great either. You wouldn’t get far on the leaderboard with these results, but at least now you know how to build and train a TabNet model.
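For readers who want to see what the two metrics above measure, here is a small plain-Python sketch. ROC AUC is computed as the probability that a random fraudulent example is scored above a random genuine one, and PR AUC is approximated by average precision, which is one common estimate and may differ slightly from other ways of integrating the precision-recall curve. The labels and scores are toy values, and the functions assume at least one example of each class and no tied scores for average precision.

```python
from typing import List


def roc_auc(y_true: List[int], scores: List[float]) -> float:
    """Share of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(y_true: List[int], scores: List[float]) -> float:
    """Mean of the precision values measured at the rank of each positive."""
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits


y_toy = [0, 0, 1, 1]
score_toy = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(y_toy, score_toy))
print(average_precision(y_toy, score_toy))
```

On imbalanced data the precision-based metric is usually the harsher of the two, because every genuine transaction that is ranked above a fraudulent one directly lowers precision.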

Conclusion

Great job on getting this far! By now you should know what is happening under the hood of TabNet and how it uses Attentive and Feature Transformers to make predictions. There’s a lot more to TabNet that we haven’t covered (e.g. Attention Masks or self-supervised pre-training), so let us know if you’d like further exploration of this model. Feel free to use any code shared here or in the notebook for your own deep learning projects.
