sbitmap: fix missed wakeups caused by sbitmap_queue_get_shallow()

The sbitmap queue wake batch is calculated such that once allocations
start blocking, all of the bits which are already allocated must be
enough to fulfill the batch counters of all of the waitqueues. However,
the shallow allocation depth can break this invariant, since we block
before the full depth is utilized: fewer bits may be in flight than the
wake batch assumes, so some waitqueues may never be woken up. Add
sbitmap_queue_min_shallow_depth(), which saves the minimum shallow depth
the sbq will use, and update sbq_calc_wake_batch() to take it into
account.

Acked-by: Paolo Valente <paolo.valente@linaro.org>
Signed-off-by: Omar Sandoval <osandov@fb.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/lib/sbitmap.c b/lib/sbitmap.c
index e6a9c06e..d21473b 100644
--- a/lib/sbitmap.c
+++ b/lib/sbitmap.c
@@ -270,18 +270,33 @@ void sbitmap_bitmap_show(struct sbitmap *sb, struct seq_file *m)
 }
 EXPORT_SYMBOL_GPL(sbitmap_bitmap_show);
 
-static unsigned int sbq_calc_wake_batch(unsigned int depth)
+static unsigned int sbq_calc_wake_batch(struct sbitmap_queue *sbq,
+					unsigned int depth)
 {
 	unsigned int wake_batch;
+	unsigned int shallow_depth;
 
 	/*
 	 * For each batch, we wake up one queue. We need to make sure that our
-	 * batch size is small enough that the full depth of the bitmap is
-	 * enough to wake up all of the queues.
+	 * batch size is small enough that the full depth of the bitmap,
+	 * potentially limited by a shallow depth, is enough to wake up all of
+	 * the queues.
+	 *
+	 * Each full word of the bitmap has bits_per_word bits, and there might
+	 * be a partial word. There are depth / bits_per_word full words and
+	 * depth % bits_per_word bits left over. In bitwise arithmetic:
+	 *
+	 * bits_per_word = 1 << shift
+	 * depth / bits_per_word = depth >> shift
+	 * depth % bits_per_word = depth & ((1 << shift) - 1)
+	 *
+	 * Each word can be limited to sbq->min_shallow_depth bits.
 	 */
-	wake_batch = SBQ_WAKE_BATCH;
-	if (wake_batch > depth / SBQ_WAIT_QUEUES)
-		wake_batch = max(1U, depth / SBQ_WAIT_QUEUES);
+	shallow_depth = min(1U << sbq->sb.shift, sbq->min_shallow_depth);
+	depth = ((depth >> sbq->sb.shift) * shallow_depth +
+		 min(depth & ((1U << sbq->sb.shift) - 1), shallow_depth));
+	wake_batch = clamp_t(unsigned int, depth / SBQ_WAIT_QUEUES, 1,
+			     SBQ_WAKE_BATCH);
 
 	return wake_batch;
 }
@@ -307,7 +322,8 @@ int sbitmap_queue_init_node(struct sbitmap_queue *sbq, unsigned int depth,
 			*per_cpu_ptr(sbq->alloc_hint, i) = prandom_u32() % depth;
 	}
 
-	sbq->wake_batch = sbq_calc_wake_batch(depth);
+	sbq->min_shallow_depth = UINT_MAX;
+	sbq->wake_batch = sbq_calc_wake_batch(sbq, depth);
 	atomic_set(&sbq->wake_index, 0);
 
 	sbq->ws = kzalloc_node(SBQ_WAIT_QUEUES * sizeof(*sbq->ws), flags, node);
@@ -327,9 +343,10 @@ int sbitmap_queue_init_node(struct sbitmap_queue *sbq, unsigned int depth,
 }
 EXPORT_SYMBOL_GPL(sbitmap_queue_init_node);
 
-void sbitmap_queue_resize(struct sbitmap_queue *sbq, unsigned int depth)
+static void sbitmap_queue_update_wake_batch(struct sbitmap_queue *sbq,
+					    unsigned int depth)
 {
-	unsigned int wake_batch = sbq_calc_wake_batch(depth);
+	unsigned int wake_batch = sbq_calc_wake_batch(sbq, depth);
 	int i;
 
 	if (sbq->wake_batch != wake_batch) {
@@ -342,6 +359,11 @@ void sbitmap_queue_resize(struct sbitmap_queue *sbq, unsigned int depth)
 		for (i = 0; i < SBQ_WAIT_QUEUES; i++)
 			atomic_set(&sbq->ws[i].wait_cnt, 1);
 	}
+}
+
+void sbitmap_queue_resize(struct sbitmap_queue *sbq, unsigned int depth)
+{
+	sbitmap_queue_update_wake_batch(sbq, depth);
 	sbitmap_resize(&sbq->sb, depth);
 }
 EXPORT_SYMBOL_GPL(sbitmap_queue_resize);
@@ -403,6 +425,14 @@ int __sbitmap_queue_get_shallow(struct sbitmap_queue *sbq,
 }
 EXPORT_SYMBOL_GPL(__sbitmap_queue_get_shallow);
 
+void sbitmap_queue_min_shallow_depth(struct sbitmap_queue *sbq,
+				     unsigned int min_shallow_depth)
+{
+	sbq->min_shallow_depth = min_shallow_depth;
+	sbitmap_queue_update_wake_batch(sbq, sbq->sb.depth);
+}
+EXPORT_SYMBOL_GPL(sbitmap_queue_min_shallow_depth);
+
 static struct sbq_wait_state *sbq_wake_ptr(struct sbitmap_queue *sbq)
 {
 	int i, wake_index;
@@ -528,5 +558,6 @@ void sbitmap_queue_show(struct sbitmap_queue *sbq, struct seq_file *m)
 	seq_puts(m, "}\n");
 
 	seq_printf(m, "round_robin=%d\n", sbq->round_robin);
+	seq_printf(m, "min_shallow_depth=%u\n", sbq->min_shallow_depth);
 }
 EXPORT_SYMBOL_GPL(sbitmap_queue_show);