af_rxrpc: Keep rxrpc_call pointers in a hashtable

Keep track of rxrpc_call structures in a hashtable so they can be
found directly from the network parameters which define the call.

This allows incoming packets to be routed directly to a call without walking
through the hierarchy of peer -> transport -> connection -> call and all the
spinlocks that that entailed.

Signed-off-by: Tim Smith <tim@electronghost.co.uk>
Signed-off-by: David Howells <dhowells@redhat.com>
diff --git a/net/rxrpc/ar-call.c b/net/rxrpc/ar-call.c
index 6e4d58c..a9e05db 100644
--- a/net/rxrpc/ar-call.c
+++ b/net/rxrpc/ar-call.c
@@ -12,6 +12,8 @@
 #include <linux/slab.h>
 #include <linux/module.h>
 #include <linux/circ_buf.h>
+#include <linux/hashtable.h>
+#include <linux/spinlock_types.h>
 #include <net/sock.h>
 #include <net/af_rxrpc.h>
 #include "ar-internal.h"
@@ -55,6 +57,145 @@
 static void rxrpc_ack_time_expired(unsigned long _call);
 static void rxrpc_resend_time_expired(unsigned long _call);
 
+static DEFINE_SPINLOCK(rxrpc_call_hash_lock);	/* held to add/remove calls */
+static DEFINE_HASHTABLE(rxrpc_call_hash, 10);	/* 1 << 10 buckets of calls */
+
+/*
+ * Hash function for rxrpc_call_hash: sum the call's identifying parameters
+ * and addr_size bytes of peer address into the full (unreduced) key.
+ */
+static unsigned long rxrpc_call_hashfunc(
+	u8		clientflag,
+	__be32		cid,
+	__be32		call_id,
+	__be32		epoch,
+	__be16		service_id,
+	sa_family_t	proto,
+	void		*localptr,
+	unsigned int	addr_size,
+	const u8	*peer_addr)
+{
+	unsigned int i;
+	unsigned long key;
+	u32 hcid = ntohl(cid);
+
+	_enter("");
+
+	key = (unsigned long)localptr;
+	/* Only a sum is wanted, so the raw big-endian values will do as-is */
+	key += (__force u32)epoch;
+	key += (__force u16)service_id;
+	key += (__force u32)call_id;
+	key += (hcid & RXRPC_CIDMASK) >> RXRPC_CIDSHIFT;
+	key += hcid & RXRPC_CHANNELMASK;
+	key += clientflag;
+	key += proto;
+	/* Add the peer address as 16-bit words built from bytes (low byte
+	 * first); a bare u8 pointer isn't known to be u16-aligned.
+	 */
+	for (i = 0; i < addr_size >> 1; i++)
+		key += peer_addr[2 * i] | (peer_addr[2 * i + 1] << 8);
+	_leave(" key = 0x%lx", key);
+	return key;
+}
+
+/*
+ * Hash a call on its identifying parameters and insert it into the call
+ * hashtable.  This reads the proto, peer_ip, epoch, service_id and
+ * in_clientflag fields of the call, so they must be filled in beforehand.
+ */
+static void rxrpc_call_hash_add(struct rxrpc_call *call)
+{
+	unsigned int addr_size;
+	unsigned long key;
+
+	_enter("");
+	/* Number of peer address bytes that take part in the hash */
+	if (call->proto == AF_INET)
+		addr_size = sizeof(call->peer_ip.ipv4_addr);
+	else if (call->proto == AF_INET6)
+		addr_size = sizeof(call->peer_ip.ipv6_addr);
+	else
+		addr_size = 0;
+
+	key = rxrpc_call_hashfunc(call->in_clientflag, call->cid,
+				  call->call_id, call->epoch,
+				  call->service_id, call->proto,
+				  call->conn->trans->local, addr_size,
+				  call->peer_ip.ipv6_addr);
+	/* Remember the unreduced key; lookups compare against it */
+	call->hash_key = key;
+	spin_lock(&rxrpc_call_hash_lock);
+	hash_add_rcu(rxrpc_call_hash, &call->hash_node, key);
+	spin_unlock(&rxrpc_call_hash_lock);
+	_leave("");
+}
+
+/* Take a call out of the call hashtable, holding the hash lock to do so. */
+static void rxrpc_call_hash_del(struct rxrpc_call *call)
+{
+	struct hlist_node *node = &call->hash_node;
+
+	_enter("");
+	spin_lock(&rxrpc_call_hash_lock);
+	hash_del_rcu(node);
+	spin_unlock(&rxrpc_call_hash_lock);
+	_leave("");
+}
+
+/*
+ * Look up a call by the parameters that define it.  Returns the matching
+ * call, or NULL if none is hashed.  No reference is taken on the result.
+ * NOTE(review): nothing here takes rcu_read_lock(); presumably the caller
+ * guarantees the call can't go away during or after the walk - confirm.
+ */
+struct rxrpc_call *rxrpc_find_call_hash(
+	u8		clientflag,
+	__be32		cid,
+	__be32		call_id,
+	__be32		epoch,
+	__be16		service_id,
+	void		*localptr,
+	sa_family_t	proto,
+	const u8	*peer_addr)
+{
+	struct rxrpc_call *call = NULL, *found = NULL;
+	unsigned int addr_size;
+	unsigned long key;
+
+	_enter("");
+
+	/* Number of peer address bytes that are hashed and compared */
+	if (proto == AF_INET)
+		addr_size = sizeof(call->peer_ip.ipv4_addr);
+	else if (proto == AF_INET6)
+		addr_size = sizeof(call->peer_ip.ipv6_addr);
+	else
+		addr_size = 0;
+
+	key = rxrpc_call_hashfunc(clientflag, cid, call_id, epoch,
+				  service_id, proto, localptr, addr_size,
+				  peer_addr);
+	hash_for_each_possible_rcu(rxrpc_call_hash, call, hash_node, key) {
+		/* Calls can share a bucket or even a key: check each field */
+		if (call->hash_key != key ||
+		    call->call_id != call_id ||
+		    call->cid != cid ||
+		    call->in_clientflag != clientflag ||
+		    call->service_id != service_id ||
+		    call->proto != proto ||
+		    call->local != localptr ||
+		    call->epoch != epoch)
+			continue;
+		if (memcmp(call->peer_ip.ipv6_addr, peer_addr, addr_size) != 0)
+			continue;
+		found = call;
+		break;
+	}
+	_leave(" = %p", found);
+	return found;
+}
+
 /*
  * allocate a new call
  */
@@ -136,6 +277,26 @@
 		return ERR_PTR(ret);
 	}
 
+	/* Record copies of information for hashtable lookup */
+	call->proto = rx->proto;
+	call->local = trans->local;
+	switch (call->proto) {
+	case AF_INET:
+		call->peer_ip.ipv4_addr =
+			trans->peer->srx.transport.sin.sin_addr.s_addr;
+		break;
+	case AF_INET6:
+		memcpy(call->peer_ip.ipv6_addr,
+		       trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
+		       sizeof(call->peer_ip.ipv6_addr));
+		break;
+	}
+	call->epoch = call->conn->epoch;
+	call->service_id = call->conn->service_id;
+	call->in_clientflag = call->conn->in_clientflag;
+	/* Add the new call to the hashtable */
+	rxrpc_call_hash_add(call);
+
 	spin_lock(&call->conn->trans->peer->lock);
 	list_add(&call->error_link, &call->conn->trans->peer->error_targets);
 	spin_unlock(&call->conn->trans->peer->lock);
@@ -328,9 +489,12 @@
 		parent = *p;
 		call = rb_entry(parent, struct rxrpc_call, conn_node);
 
-		if (call_id < call->call_id)
+		/* The tree is sorted in order of the __be32 value without
+		 * turning it into host order.
+		 */
+		if ((__force u32)call_id < (__force u32)call->call_id)
 			p = &(*p)->rb_left;
-		else if (call_id > call->call_id)
+		else if ((__force u32)call_id > (__force u32)call->call_id)
 			p = &(*p)->rb_right;
 		else
 			goto old_call;
@@ -355,6 +519,28 @@
 	list_add_tail(&call->link, &rxrpc_calls);
 	write_unlock_bh(&rxrpc_call_lock);
 
+	/* Record copies of information for hashtable lookup */
+	call->proto = rx->proto;
+	call->local = conn->trans->local;
+	switch (call->proto) {
+	case AF_INET:
+		call->peer_ip.ipv4_addr =
+			conn->trans->peer->srx.transport.sin.sin_addr.s_addr;
+		break;
+	case AF_INET6:
+		memcpy(call->peer_ip.ipv6_addr,
+		       conn->trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
+		       sizeof(call->peer_ip.ipv6_addr));
+		break;
+	default:
+		break;
+	}
+	call->epoch = conn->epoch;
+	call->service_id = conn->service_id;
+	call->in_clientflag = conn->in_clientflag;
+	/* Add the new call to the hashtable */
+	rxrpc_call_hash_add(call);
+
 	_net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id);
 
 	call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
@@ -673,6 +859,9 @@
 		rxrpc_put_connection(call->conn);
 	}
 
+	/* Remove the call from the hash */
+	rxrpc_call_hash_del(call);
+
 	if (call->acks_window) {
 		_debug("kill Tx window %d",
 		       CIRC_CNT(call->acks_head, call->acks_tail,