SUNRPC: Add basic load balancing to the transport switch

For now, just count the queue length. This is less accurate than
counting the number of bytes queued, but it is easier to implement.

Signed-off-by: Trond Myklebust <trond.myklebust@primarydata.com>
diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c
index b03bfa0..976eab68 100644
--- a/net/sunrpc/clnt.c
+++ b/net/sunrpc/clnt.c
@@ -968,13 +968,47 @@ struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old,
 }
 EXPORT_SYMBOL_GPL(rpc_bind_new_program);
 
+static struct rpc_xprt *
+rpc_task_get_xprt(struct rpc_clnt *clnt)
+{
+	struct rpc_xprt_switch *xps;
+	struct rpc_xprt *xprt = xprt_iter_get_next(&clnt->cl_xpi);
+
+	if (!xprt)	/* iterator gave no transport: nothing to count */
+		return NULL;
+	rcu_read_lock();	/* needed for rcu_dereference() below */
+	xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+	atomic_long_inc(&xps->xps_queuelen);	/* count on the switch */
+	rcu_read_unlock();
+	atomic_long_inc(&xprt->queuelen);	/* count on the transport */
+
+	return xprt;
+}
+
+static void
+rpc_task_release_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+{
+	struct rpc_xprt_switch *xps;
+
+	atomic_long_dec(&xprt->queuelen);	/* uncount on the transport */
+	rcu_read_lock();	/* needed for rcu_dereference() below */
+	xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch);
+	atomic_long_dec(&xps->xps_queuelen);	/* uncount on the switch */
+	rcu_read_unlock();
+
+	xprt_put(xprt);	/* presumably drops the transport reference */
+}
+
 void rpc_task_release_transport(struct rpc_task *task)
 {
 	struct rpc_xprt *xprt = task->tk_xprt;
 
 	if (xprt) {
 		task->tk_xprt = NULL;
-		xprt_put(xprt);
+		if (task->tk_client)	/* client set: undo queuelen counts too */
+			rpc_task_release_xprt(task->tk_client, xprt);
+		else	/* no client, so no switch counter to adjust */
+			xprt_put(xprt);
 	}
 }
 EXPORT_SYMBOL_GPL(rpc_task_release_transport);
@@ -983,6 +1017,7 @@ void rpc_task_release_client(struct rpc_task *task)
 {
 	struct rpc_clnt *clnt = task->tk_client;
 
+	rpc_task_release_transport(task);
 	if (clnt != NULL) {
 		/* Remove from client task list */
 		spin_lock(&clnt->cl_lock);
@@ -992,14 +1027,13 @@ void rpc_task_release_client(struct rpc_task *task)
 
 		rpc_release_client(clnt);
 	}
-	rpc_task_release_transport(task);
 }
 
 static
 void rpc_task_set_transport(struct rpc_task *task, struct rpc_clnt *clnt)
 {
 	if (!task->tk_xprt)
-		task->tk_xprt = xprt_iter_get_next(&clnt->cl_xpi);
+		task->tk_xprt = rpc_task_get_xprt(clnt);	/* also bumps queuelen counters */
 }
 
 static