VSOCK: sock_put wasn't safe to call in interrupt context

In the vsock vmci_transport driver, sock_put wasn't safe to call
in interrupt context, since that may call the vsock destructor
which in turn calls several functions that should only be called
from process context. This change defers the calling of these
functions to a worker thread. All of these functions deallocate
resources related to the transport itself.

Furthermore, an unused callback was removed to simplify the
cleanup.

Multiple customers have been hitting this issue when using
VMware tools on vSphere 2015.

Also added a version to the vmci transport module (starting from
1.0.2.0-k, since up until now it appears that this module was
sharing its version with vsock, which is currently at 1.0.1.0-k).

Reviewed-by: Aditya Asarwade <asarwade@vmware.com>
Reviewed-by: Thomas Hellstrom <thellstrom@vmware.com>
Signed-off-by: Jorgen Hansen <jhansen@vmware.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index 1f63daf..5243ce2 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -40,13 +40,11 @@
 
 static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg);
 static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg);
-static void vmci_transport_peer_attach_cb(u32 sub_id,
-					  const struct vmci_event_data *ed,
-					  void *client_data);
 static void vmci_transport_peer_detach_cb(u32 sub_id,
 					  const struct vmci_event_data *ed,
 					  void *client_data);
 static void vmci_transport_recv_pkt_work(struct work_struct *work);
+static void vmci_transport_cleanup(struct work_struct *work);
 static int vmci_transport_recv_listen(struct sock *sk,
 				      struct vmci_transport_packet *pkt);
 static int vmci_transport_recv_connecting_server(
@@ -75,6 +73,10 @@
 	struct vmci_transport_packet pkt;
 };
 
+static LIST_HEAD(vmci_transport_cleanup_list);
+static DEFINE_SPINLOCK(vmci_transport_cleanup_lock);
+static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup);
+
 static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID,
 							   VMCI_INVALID_ID };
 static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
@@ -791,44 +793,6 @@
 	return err;
 }
 
-static void vmci_transport_peer_attach_cb(u32 sub_id,
-					  const struct vmci_event_data *e_data,
-					  void *client_data)
-{
-	struct sock *sk = client_data;
-	const struct vmci_event_payload_qp *e_payload;
-	struct vsock_sock *vsk;
-
-	e_payload = vmci_event_data_const_payload(e_data);
-
-	vsk = vsock_sk(sk);
-
-	/* We don't ask for delayed CBs when we subscribe to this event (we
-	 * pass 0 as flags to vmci_event_subscribe()).  VMCI makes no
-	 * guarantees in that case about what context we might be running in,
-	 * so it could be BH or process, blockable or non-blockable.  So we
-	 * need to account for all possible contexts here.
-	 */
-	local_bh_disable();
-	bh_lock_sock(sk);
-
-	/* XXX This is lame, we should provide a way to lookup sockets by
-	 * qp_handle.
-	 */
-	if (vmci_handle_is_equal(vmci_trans(vsk)->qp_handle,
-				 e_payload->handle)) {
-		/* XXX This doesn't do anything, but in the future we may want
-		 * to set a flag here to verify the attach really did occur and
-		 * we weren't just sent a datagram claiming it was.
-		 */
-		goto out;
-	}
-
-out:
-	bh_unlock_sock(sk);
-	local_bh_enable();
-}
-
 static void vmci_transport_handle_detach(struct sock *sk)
 {
 	struct vsock_sock *vsk;
@@ -871,28 +835,38 @@
 					  const struct vmci_event_data *e_data,
 					  void *client_data)
 {
-	struct sock *sk = client_data;
+	struct vmci_transport *trans = client_data;
 	const struct vmci_event_payload_qp *e_payload;
-	struct vsock_sock *vsk;
 
 	e_payload = vmci_event_data_const_payload(e_data);
-	vsk = vsock_sk(sk);
-	if (vmci_handle_is_invalid(e_payload->handle))
-		return;
-
-	/* Same rules for locking as for peer_attach_cb(). */
-	local_bh_disable();
-	bh_lock_sock(sk);
 
 	/* XXX This is lame, we should provide a way to lookup sockets by
 	 * qp_handle.
 	 */
-	if (vmci_handle_is_equal(vmci_trans(vsk)->qp_handle,
-				 e_payload->handle))
-		vmci_transport_handle_detach(sk);
+	if (vmci_handle_is_invalid(e_payload->handle) ||
+	    !vmci_handle_is_equal(trans->qp_handle, e_payload->handle))
+		return;
 
-	bh_unlock_sock(sk);
-	local_bh_enable();
+	/* We don't ask for delayed CBs when we subscribe to this event (we
+	 * pass 0 as flags to vmci_event_subscribe()).  VMCI makes no
+	 * guarantees in that case about what context we might be running in,
+	 * so it could be BH or process, blockable or non-blockable.  So we
+	 * need to account for all possible contexts here.
+	 */
+	spin_lock_bh(&trans->lock);
+	if (!trans->sk)
+		goto out;
+
+	/* Apart from here, trans->lock is only grabbed as part of sk destruct,
+	 * where trans->sk isn't locked.
+	 */
+	bh_lock_sock(trans->sk);
+
+	vmci_transport_handle_detach(trans->sk);
+
+	bh_unlock_sock(trans->sk);
+ out:
+	spin_unlock_bh(&trans->lock);
 }
 
 static void vmci_transport_qp_resumed_cb(u32 sub_id,
@@ -1181,7 +1155,7 @@
 	 */
 	err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
 				   vmci_transport_peer_detach_cb,
-				   pending, &detach_sub_id);
+				   vmci_trans(vpending), &detach_sub_id);
 	if (err < VMCI_SUCCESS) {
 		vmci_transport_send_reset(pending, pkt);
 		err = vmci_transport_error_to_vsock_error(err);
@@ -1321,7 +1295,6 @@
 		    || vmci_trans(vsk)->qpair
 		    || vmci_trans(vsk)->produce_size != 0
 		    || vmci_trans(vsk)->consume_size != 0
-		    || vmci_trans(vsk)->attach_sub_id != VMCI_INVALID_ID
 		    || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
 			skerr = EPROTO;
 			err = -EINVAL;
@@ -1389,7 +1362,6 @@
 	struct vsock_sock *vsk;
 	struct vmci_handle handle;
 	struct vmci_qp *qpair;
-	u32 attach_sub_id;
 	u32 detach_sub_id;
 	bool is_local;
 	u32 flags;
@@ -1399,7 +1371,6 @@
 
 	vsk = vsock_sk(sk);
 	handle = VMCI_INVALID_HANDLE;
-	attach_sub_id = VMCI_INVALID_ID;
 	detach_sub_id = VMCI_INVALID_ID;
 
 	/* If we have gotten here then we should be past the point where old
@@ -1444,23 +1415,15 @@
 		goto destroy;
 	}
 
-	/* Subscribe to attach and detach events first.
+	/* Subscribe to detach events first.
 	 *
 	 * XXX We attach once for each queue pair created for now so it is easy
 	 * to find the socket (it's provided), but later we should only
 	 * subscribe once and add a way to lookup sockets by queue pair handle.
 	 */
-	err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_ATTACH,
-				   vmci_transport_peer_attach_cb,
-				   sk, &attach_sub_id);
-	if (err < VMCI_SUCCESS) {
-		err = vmci_transport_error_to_vsock_error(err);
-		goto destroy;
-	}
-
 	err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
 				   vmci_transport_peer_detach_cb,
-				   sk, &detach_sub_id);
+				   vmci_trans(vsk), &detach_sub_id);
 	if (err < VMCI_SUCCESS) {
 		err = vmci_transport_error_to_vsock_error(err);
 		goto destroy;
@@ -1496,7 +1459,6 @@
 	vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size =
 		pkt->u.size;
 
-	vmci_trans(vsk)->attach_sub_id = attach_sub_id;
 	vmci_trans(vsk)->detach_sub_id = detach_sub_id;
 
 	vmci_trans(vsk)->notify_ops->process_negotiate(sk);
@@ -1504,9 +1466,6 @@
 	return 0;
 
 destroy:
-	if (attach_sub_id != VMCI_INVALID_ID)
-		vmci_event_unsubscribe(attach_sub_id);
-
 	if (detach_sub_id != VMCI_INVALID_ID)
 		vmci_event_unsubscribe(detach_sub_id);
 
@@ -1607,9 +1566,11 @@
 	vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE;
 	vmci_trans(vsk)->qpair = NULL;
 	vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0;
-	vmci_trans(vsk)->attach_sub_id = vmci_trans(vsk)->detach_sub_id =
-		VMCI_INVALID_ID;
+	vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID;
 	vmci_trans(vsk)->notify_ops = NULL;
+	INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
+	vmci_trans(vsk)->sk = &vsk->sk;
+	spin_lock_init(&vmci_trans(vsk)->lock);
 	if (psk) {
 		vmci_trans(vsk)->queue_pair_size =
 			vmci_trans(psk)->queue_pair_size;
@@ -1629,29 +1590,57 @@
 	return 0;
 }
 
+static void vmci_transport_free_resources(struct list_head *transport_list)
+{
+	while (!list_empty(transport_list)) {
+		struct vmci_transport *transport =
+		    list_first_entry(transport_list, struct vmci_transport,
+				     elem);
+		list_del(&transport->elem);
+
+		if (transport->detach_sub_id != VMCI_INVALID_ID) {
+			vmci_event_unsubscribe(transport->detach_sub_id);
+			transport->detach_sub_id = VMCI_INVALID_ID;
+		}
+
+		if (!vmci_handle_is_invalid(transport->qp_handle)) {
+			vmci_qpair_detach(&transport->qpair);
+			transport->qp_handle = VMCI_INVALID_HANDLE;
+			transport->produce_size = 0;
+			transport->consume_size = 0;
+		}
+
+		kfree(transport);
+	}
+}
+
+static void vmci_transport_cleanup(struct work_struct *work)
+{
+	LIST_HEAD(pending);
+
+	spin_lock_bh(&vmci_transport_cleanup_lock);
+	list_replace_init(&vmci_transport_cleanup_list, &pending);
+	spin_unlock_bh(&vmci_transport_cleanup_lock);
+	vmci_transport_free_resources(&pending);
+}
+
 static void vmci_transport_destruct(struct vsock_sock *vsk)
 {
-	if (vmci_trans(vsk)->attach_sub_id != VMCI_INVALID_ID) {
-		vmci_event_unsubscribe(vmci_trans(vsk)->attach_sub_id);
-		vmci_trans(vsk)->attach_sub_id = VMCI_INVALID_ID;
-	}
-
-	if (vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
-		vmci_event_unsubscribe(vmci_trans(vsk)->detach_sub_id);
-		vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID;
-	}
-
-	if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) {
-		vmci_qpair_detach(&vmci_trans(vsk)->qpair);
-		vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE;
-		vmci_trans(vsk)->produce_size = 0;
-		vmci_trans(vsk)->consume_size = 0;
-	}
+	/* Ensure that the detach callback doesn't use the sk/vsk
+	 * we are about to destruct.
+	 */
+	spin_lock_bh(&vmci_trans(vsk)->lock);
+	vmci_trans(vsk)->sk = NULL;
+	spin_unlock_bh(&vmci_trans(vsk)->lock);
 
 	if (vmci_trans(vsk)->notify_ops)
 		vmci_trans(vsk)->notify_ops->socket_destruct(vsk);
 
-	kfree(vsk->trans);
+	spin_lock_bh(&vmci_transport_cleanup_lock);
+	list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list);
+	spin_unlock_bh(&vmci_transport_cleanup_lock);
+	schedule_work(&vmci_transport_cleanup_work);
+
 	vsk->trans = NULL;
 }
 
@@ -2146,6 +2135,9 @@
 
 static void __exit vmci_transport_exit(void)
 {
+	cancel_work_sync(&vmci_transport_cleanup_work);
+	vmci_transport_free_resources(&vmci_transport_cleanup_list);
+
 	if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) {
 		if (vmci_datagram_destroy_handle(
 			vmci_transport_stream_handle) != VMCI_SUCCESS)
@@ -2164,6 +2156,7 @@
 
 MODULE_AUTHOR("VMware, Inc.");
 MODULE_DESCRIPTION("VMCI transport for Virtual Sockets");
+MODULE_VERSION("1.0.2.0-k");
 MODULE_LICENSE("GPL v2");
 MODULE_ALIAS("vmware_vsock");
 MODULE_ALIAS_NETPROTO(PF_VSOCK);
diff --git a/net/vmw_vsock/vmci_transport.h b/net/vmw_vsock/vmci_transport.h
index ce6c962..2ad46f3 100644
--- a/net/vmw_vsock/vmci_transport.h
+++ b/net/vmw_vsock/vmci_transport.h
@@ -119,10 +119,12 @@
 	u64 queue_pair_size;
 	u64 queue_pair_min_size;
 	u64 queue_pair_max_size;
-	u32 attach_sub_id;
 	u32 detach_sub_id;
 	union vmci_transport_notify notify;
 	struct vmci_transport_notify_ops *notify_ops;
+	struct list_head elem;
+	struct sock *sk;
+	spinlock_t lock; /* protects sk. */
 };
 
 int vmci_transport_register(void);