Jaideep Ray · Published in Better ML · 3 min read · Dec 11, 2022

Recsys model serving: Embeddings

Context:

Production recommendation models have a large number of sparse features with high cardinality. For example, a video recommendation model will have a sparse feature for watched video ids, and the cardinality of such a feature is O(billions). High cardinality results in large embedding tables, which creates storage and efficiency bottlenecks for both training and inference. Embedding lookup and processing are the most expensive operations, in both computation and communication, across training and inference.

Let’s walk through the steps of an embedding lookup during training and inference:

Training :

  1. Embedding table lookup & pooling on the devices holding the tables. [Forward computation]
  2. Send the pooled embeddings to target devices such as GPUs. [Forward communication]
  3. Send gradients of the embedding vectors back from target devices to the devices holding the embedding tables. [Backward communication]
  4. Apply gradients to the selected embedding vectors. [Backward computation]

Inference :

During inference only the forward pass (steps 1 and 2) happens. In this note we will focus on steps 1 and 2.
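As a concrete sketch of steps 1 and 2, here is lookup plus pooling using PyTorch’s nn.EmbeddingBag. The table size, feature name, and id values are illustrative, not from any real system:

```python
import torch
import torch.nn as nn

# Illustrative sizes: a "watched video ids" sparse feature.
# Real tables are O(billions) of rows; this is a toy.
NUM_VIDEOS = 10_000
EMB_DIM = 64

# EmbeddingBag fuses lookup + pooling (step 1) into one op.
table = nn.EmbeddingBag(NUM_VIDEOS, EMB_DIM, mode="sum")

# Two examples with variable-length id lists, flattened with
# offsets (the standard EmbeddingBag input format).
ids = torch.tensor([1, 5, 42, 7, 7])   # all ids, concatenated
offsets = torch.tensor([0, 3])         # example i starts at offsets[i]

pooled = table(ids, offsets)           # shape: [2, EMB_DIM]

# In a distributed setting, `pooled` is what gets sent to the
# target devices (step 2); in training, gradients flow back only
# to the rows of `table.weight` that were looked up (steps 3-4).
```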

[Figure: Embedding operator during training]

What are the main challenges in serving embeddings?

a. Embedding table compression:

This comprises techniques to reduce embedding table size. The two most common are:

  1. Pruning: discarding less informative rows of the table without significantly impacting model quality. A common approach is filtering out rows that were updated/seen below a specific frequency during training; heuristic frequency filters are widely used. Pruning works well in recsys because sparse features such as clicked content ids or viewed video ids generally have a long-tail distribution.
  2. Quantization: embedding vectors are trained in fp32 but can be quantized to a lower-bit representation such as fp16/fp8/int8/int4 for serving, provided there isn’t a significant impact on model quality (see the sketch below).
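A minimal sketch of post-training int8 quantization; per-row max-abs scaling is one common scheme, and the table sizes here are toy values:

```python
import torch

def quantize_rowwise_int8(table_fp32: torch.Tensor):
    """Quantize each embedding row to int8 with its own scale."""
    # Per-row max-abs scaling: map [-max, max] -> [-127, 127].
    scales = table_fp32.abs().amax(dim=1, keepdim=True) / 127.0
    scales = scales.clamp(min=1e-8)      # avoid divide-by-zero rows
    q = torch.round(table_fp32 / scales).to(torch.int8)
    return q, scales

def dequantize_rows(q_rows: torch.Tensor, scales: torch.Tensor):
    """Recover approximate fp32 vectors at lookup time."""
    return q_rows.to(torch.float32) * scales

table = torch.randn(1000, 64)            # toy fp32 table
q, s = quantize_rowwise_int8(table)
ids = torch.tensor([3, 42, 7])
approx = dequantize_rows(q[ids], s[ids])  # ~4x smaller storage than fp32
```

int8 cuts storage roughly 4x relative to fp32; whether the accuracy loss is acceptable has to be validated per model.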

b. Embedding table sharding:

For sparse features with high cardinality, the embedding table can be too large to fit in memory even after compression. One then has to shard the embedding table across ranks (devices such as GPUs). To understand how sharding affects the cost of the embedding operator, define the following parameters:

N: Number of devices / ranks such as GPU
B: Global batch size
D: Avg embedding dimension
M: Number of embedding tables in a model

In row-wise sharding, each rank holds a partition of the rows of every table. For pooling, the intermediate partial sums at each rank take M × B × D storage, so the total communication in pooling is M × B × D × N. As you can see, pooling happens at two levels: first within each rank, then globally across ranks.

In column-wise sharding, each embedding vector is split along the embedding dimension and stored across a subset of ranks. The pooled intermediates across all ranks total M × B × D (each rank contributes M × B × D/N). Pooling happens only once, and the per-rank results are then concatenated.
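A back-of-the-envelope comparison of the two schemes, using made-up but plausible parameter values (every number below is an assumption, not a measurement):

```python
# Illustrative parameter values (not from any real deployment).
N = 8          # ranks (GPUs)
B = 4096       # global batch size
D = 128        # avg embedding dimension
M = 50         # embedding tables
BYTES = 4      # fp32

# Row-wise: each rank produces a full M x B x D partial sum,
# which must be reduced across all N ranks.
rowwise_comm = M * B * D * N * BYTES

# Column-wise: each rank produces M x B x (D/N) columns; the
# concatenated output totals M x B x D across all ranks.
colwise_comm = M * B * D * BYTES

print(f"row-wise   : {rowwise_comm / 1e9:.2f} GB per iteration")
print(f"column-wise: {colwise_comm / 1e9:.2f} GB per iteration")
```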

c. Embedding table placement:

Recommendation models have a wide range of embedding tables that must be placed on the devices (GPUs) available for serving. The overall objective is to minimize communication cost and maximize performance.
A simple baseline is a greedy heuristic: sort embedding tables/shards by size or cost in descending order and assign the largest shard first, one per worker (see the sketch below).
ML techniques such as multi-armed bandits and reinforcement learning have also been applied to embedding table placement.
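A minimal sketch of that greedy heuristic, as a standard bin-packing variant that always gives the next-largest shard to the least-loaded worker. The cost model (table size) and the shard names are illustrative assumptions:

```python
import heapq

def greedy_placement(shard_costs, num_workers):
    """Assign shards to workers, largest cost first, each to the
    currently least-loaded worker."""
    # Min-heap of (current_load, worker_id).
    heap = [(0.0, w) for w in range(num_workers)]
    heapq.heapify(heap)
    placement = {}
    for shard, cost in sorted(shard_costs.items(), key=lambda kv: -kv[1]):
        load, worker = heapq.heappop(heap)
        placement[shard] = worker
        heapq.heappush(heap, (load + cost, worker))
    return placement

# Toy example: costs could be table sizes in GB.
shards = {"video_ids": 12.0, "user_ids": 8.0, "ad_ids": 5.0, "query_ids": 2.0}
print(greedy_placement(shards, num_workers=2))
# {'video_ids': 0, 'user_ids': 1, 'ad_ids': 1, 'query_ids': 0}
```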

d. Real time updates:

User data and preferences change fast (concept drift), so recsys models benefit greatly from real-time model updates/refreshes. Updating whole embedding tables in real time can be quite expensive; instead, only the delta to the embedding table can be shipped, say every 10 minutes.
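A hedged sketch of such a delta refresh, assuming the serving side allows in-place row writes (the function names and snapshot mechanics here are made up for illustration):

```python
import torch

def compute_delta(old_table, new_table, atol=1e-6):
    """Return the ids and rows that changed between two snapshots."""
    changed = ~torch.isclose(old_table, new_table, atol=atol).all(dim=1)
    ids = changed.nonzero(as_tuple=True)[0]
    return ids, new_table[ids]

def apply_delta(serving_table, ids, rows):
    """In-place row update on the serving copy."""
    serving_table[ids] = rows

old = torch.randn(1000, 64)
new = old.clone()
new[[3, 17, 256]] += 0.1           # pretend training updated 3 rows
ids, rows = compute_delta(old, new)
apply_delta(old, ids, rows)         # ships 3 rows instead of 1000
```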

Optimizing embedding operations for recsys model training and serving is a highly impactful area under active research and development.
Here are some great references:

  1. DreamShard: an embedding table placement framework.
  2. Software-Hardware Co-design for Fast and Scalable Training of Deep Learning Recommendation Models.
