RDMA/cma: Increment port number after close to avoid re-use

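Randomize the starting port number and avoid re-using port values
immediately after they are closed.  Instead, keep track of the next
port value to try (one past the last port assigned) and advance it
every time a new port number is assigned, wrapping back to the start
of the local port range once the end is reached.  This better
replicates the behavior of other port spaces.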

Signed-off-by: Roland Dreier <rolandd@cisco.com>
diff --git a/drivers/infiniband/core/cma.c b/drivers/infiniband/core/cma.c
index 9e0ab04..bc31b54 100644
--- a/drivers/infiniband/core/cma.c
+++ b/drivers/infiniband/core/cma.c
@@ -71,6 +71,7 @@
 static DEFINE_IDR(sdp_ps);
 static DEFINE_IDR(tcp_ps);
 static DEFINE_IDR(udp_ps);
+static int next_port;
 
 struct cma_device {
 	struct list_head	list;
@@ -1722,33 +1723,74 @@
 			  unsigned short snum)
 {
 	struct rdma_bind_list *bind_list;
-	int port, start, ret;
+	int port, ret;
 
 	bind_list = kzalloc(sizeof *bind_list, GFP_KERNEL);
 	if (!bind_list)
 		return -ENOMEM;
 
-	start = snum ? snum : sysctl_local_port_range[0];
-
 	do {
-		ret = idr_get_new_above(ps, bind_list, start, &port);
+		ret = idr_get_new_above(ps, bind_list, snum, &port);
 	} while ((ret == -EAGAIN) && idr_pre_get(ps, GFP_KERNEL));
 
 	if (ret)
-		goto err;
+		goto err1;
 
-	if ((snum && port != snum) ||
-	    (!snum && port > sysctl_local_port_range[1])) {
-		idr_remove(ps, port);
+	if (port != snum) {
 		ret = -EADDRNOTAVAIL;
-		goto err;
+		goto err2;
 	}
 
 	bind_list->ps = ps;
 	bind_list->port = (unsigned short) port;
 	cma_bind_port(bind_list, id_priv);
 	return 0;
-err:
+err2:
+	idr_remove(ps, port);
+err1:
+	kfree(bind_list);
+	return ret;
+}
+
+static int cma_alloc_any_port(struct idr *ps, struct rdma_id_private *id_priv)
+{
+	struct rdma_bind_list *bind_list;
+	int port, ret;
+
+	bind_list = kzalloc(sizeof *bind_list, GFP_KERNEL);
+	if (!bind_list)
+		return -ENOMEM;
+
+retry:
+	do {
+		ret = idr_get_new_above(ps, bind_list, next_port, &port);
+	} while ((ret == -EAGAIN) && idr_pre_get(ps, GFP_KERNEL));
+
+	if (ret)
+		goto err1;
+
+	if (port > sysctl_local_port_range[1]) {
+		if (next_port != sysctl_local_port_range[0]) {
+			idr_remove(ps, port);
+			next_port = sysctl_local_port_range[0];
+			goto retry;
+		}
+		ret = -EADDRNOTAVAIL;
+		goto err2;
+	}
+
+	if (port == sysctl_local_port_range[1])
+		next_port = sysctl_local_port_range[0];
+	else
+		next_port = port + 1;
+
+	bind_list->ps = ps;
+	bind_list->port = (unsigned short) port;
+	cma_bind_port(bind_list, id_priv);
+	return 0;
+err2:
+	idr_remove(ps, port);
+err1:
 	kfree(bind_list);
 	return ret;
 }
@@ -1811,7 +1853,7 @@
 
 	mutex_lock(&lock);
 	if (cma_any_port(&id_priv->id.route.addr.src_addr))
-		ret = cma_alloc_port(ps, id_priv, 0);
+		ret = cma_alloc_any_port(ps, id_priv);
 	else
 		ret = cma_use_port(ps, id_priv);
 	mutex_unlock(&lock);
@@ -2448,6 +2490,10 @@
 {
 	int ret;
 
+	get_random_bytes(&next_port, sizeof next_port);
+	next_port = (next_port % (sysctl_local_port_range[1] -
+				  sysctl_local_port_range[0])) +
+		    sysctl_local_port_range[0];
 	cma_wq = create_singlethread_workqueue("rdma_cm_wq");
 	if (!cma_wq)
 		return -ENOMEM;