blk-mq: manage hctx map via xarray

First, the code becomes cleaner by switching from a plain array to an xarray.

Second, this fixes a use-after-free on q->queue_hw_ctx: queue_for_each_hw_ctx()
may run while an update of nr_hw_queues is in progress, and the old array
could be reallocated and freed underneath it. With this patch, q->hctx_table
is defined as an xarray whose lifetime is the same as the request queue's,
so queue_for_each_hw_ctx() can use q->hctx_table to look up each hctx reliably.

Reported-by: Yu Kuai <yukuai3@huawei.com>
Signed-off-by: Ming Lei <ming.lei@redhat.com>
Reviewed-by: Hannes Reinecke <hare@suse.de>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Link: https://lore.kernel.org/r/20220308073219.91173-7-ming.lei@redhat.com
[axboe: fix blk_mq_hw_ctx forward declaration]
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/block/blk-mq.c b/block/blk-mq.c
index bffdd71..e8d6d7a 100644
--- a/block/blk-mq.c
+++ b/block/blk-mq.c
@@ -71,7 +71,8 @@ static int blk_mq_poll_stats_bkt(const struct request *rq)
 static inline struct blk_mq_hw_ctx *blk_qc_to_hctx(struct request_queue *q,
 		blk_qc_t qc)
 {
-	return q->queue_hw_ctx[(qc & ~BLK_QC_T_INTERNAL) >> BLK_QC_T_SHIFT];
+	return xa_load(&q->hctx_table,
+			(qc & ~BLK_QC_T_INTERNAL) >> BLK_QC_T_SHIFT);
 }
 
 static inline struct request *blk_qc_to_rq(struct blk_mq_hw_ctx *hctx,
@@ -573,7 +574,7 @@ struct request *blk_mq_alloc_request_hctx(struct request_queue *q,
 	 * If not tell the caller that it should skip this queue.
 	 */
 	ret = -EXDEV;
-	data.hctx = q->queue_hw_ctx[hctx_idx];
+	data.hctx = xa_load(&q->hctx_table, hctx_idx);
 	if (!blk_mq_hw_queue_mapped(data.hctx))
 		goto out_queue_exit;
 	cpu = cpumask_first_and(data.hctx->cpumask, cpu_online_mask);
@@ -3437,6 +3438,8 @@ static void blk_mq_exit_hctx(struct request_queue *q,
 
 	blk_mq_remove_cpuhp(hctx);
 
+	xa_erase(&q->hctx_table, hctx_idx);
+
 	spin_lock(&q->unused_hctx_lock);
 	list_add(&hctx->hctx_list, &q->unused_hctx_list);
 	spin_unlock(&q->unused_hctx_lock);
@@ -3476,8 +3479,15 @@ static int blk_mq_init_hctx(struct request_queue *q,
 	if (blk_mq_init_request(set, hctx->fq->flush_rq, hctx_idx,
 				hctx->numa_node))
 		goto exit_hctx;
+
+	if (xa_insert(&q->hctx_table, hctx_idx, hctx, GFP_KERNEL))
+		goto exit_flush_rq;
+
 	return 0;
 
+ exit_flush_rq:
+	if (set->ops->exit_request)
+		set->ops->exit_request(set, hctx->fq->flush_rq, hctx_idx);
  exit_hctx:
 	if (set->ops->exit_hctx)
 		set->ops->exit_hctx(hctx, hctx_idx);
@@ -3856,7 +3866,7 @@ void blk_mq_release(struct request_queue *q)
 		kobject_put(&hctx->kobj);
 	}
 
-	kfree(q->queue_hw_ctx);
+	xa_destroy(&q->hctx_table);
 
 	/*
 	 * release .mq_kobj and sw queue's kobject now because
@@ -3945,46 +3955,28 @@ static struct blk_mq_hw_ctx *blk_mq_alloc_and_init_hctx(
 static void blk_mq_realloc_hw_ctxs(struct blk_mq_tag_set *set,
 						struct request_queue *q)
 {
-	int i, j, end;
-	struct blk_mq_hw_ctx **hctxs = q->queue_hw_ctx;
-
-	if (q->nr_hw_queues < set->nr_hw_queues) {
-		struct blk_mq_hw_ctx **new_hctxs;
-
-		new_hctxs = kcalloc_node(set->nr_hw_queues,
-				       sizeof(*new_hctxs), GFP_KERNEL,
-				       set->numa_node);
-		if (!new_hctxs)
-			return;
-		if (hctxs)
-			memcpy(new_hctxs, hctxs, q->nr_hw_queues *
-			       sizeof(*hctxs));
-		q->queue_hw_ctx = new_hctxs;
-		kfree(hctxs);
-		hctxs = new_hctxs;
-	}
+	struct blk_mq_hw_ctx *hctx;
+	unsigned long i, j;
 
 	/* protect against switching io scheduler  */
 	mutex_lock(&q->sysfs_lock);
 	for (i = 0; i < set->nr_hw_queues; i++) {
 		int old_node;
 		int node = blk_mq_get_hctx_node(set, i);
-		struct blk_mq_hw_ctx *old_hctx = hctxs[i];
+		struct blk_mq_hw_ctx *old_hctx = xa_load(&q->hctx_table, i);
 
 		if (old_hctx) {
 			old_node = old_hctx->numa_node;
 			blk_mq_exit_hctx(q, set, old_hctx, i);
 		}
 
-		hctxs[i] = blk_mq_alloc_and_init_hctx(set, q, i, node);
-		if (!hctxs[i]) {
+		if (!blk_mq_alloc_and_init_hctx(set, q, i, node)) {
 			if (!old_hctx)
 				break;
 			pr_warn("Allocate new hctx on node %d fails, fallback to previous one on node %d\n",
 					node, old_node);
-			hctxs[i] = blk_mq_alloc_and_init_hctx(set, q, i,
-					old_node);
-			WARN_ON_ONCE(!hctxs[i]);
+			hctx = blk_mq_alloc_and_init_hctx(set, q, i, old_node);
+			WARN_ON_ONCE(!hctx);
 		}
 	}
 	/*
@@ -3993,21 +3985,13 @@ static void blk_mq_realloc_hw_ctxs(struct blk_mq_tag_set *set,
 	 */
 	if (i != set->nr_hw_queues) {
 		j = q->nr_hw_queues;
-		end = i;
 	} else {
 		j = i;
-		end = q->nr_hw_queues;
 		q->nr_hw_queues = set->nr_hw_queues;
 	}
 
-	for (; j < end; j++) {
-		struct blk_mq_hw_ctx *hctx = hctxs[j];
-
-		if (hctx) {
-			blk_mq_exit_hctx(q, set, hctx, j);
-			hctxs[j] = NULL;
-		}
-	}
+	xa_for_each_start(&q->hctx_table, j, hctx, j)
+		blk_mq_exit_hctx(q, set, hctx, j);
 	mutex_unlock(&q->sysfs_lock);
 }
 
@@ -4046,6 +4030,8 @@ int blk_mq_init_allocated_queue(struct blk_mq_tag_set *set,
 	INIT_LIST_HEAD(&q->unused_hctx_list);
 	spin_lock_init(&q->unused_hctx_lock);
 
+	xa_init(&q->hctx_table);
+
 	blk_mq_realloc_hw_ctxs(set, q);
 	if (!q->nr_hw_queues)
 		goto err_hctxs;
@@ -4075,7 +4061,7 @@ int blk_mq_init_allocated_queue(struct blk_mq_tag_set *set,
 	return 0;
 
 err_hctxs:
-	kfree(q->queue_hw_ctx);
+	xa_destroy(&q->hctx_table);
 	q->nr_hw_queues = 0;
 	blk_mq_sysfs_deinit(q);
 err_poll: