net/smc: add pnet table namespace support

This patch adds namespace support to the pnet table code. Each network
namespace gets its own pnet table. InfiniBand and smcd device pnetids
can only be modified in the initial namespace. In other namespaces these
pnetids can still be used as if they were set by the underlying hardware.

Signed-off-by: Hans Wippel <hwippel@linux.ibm.com>
Signed-off-by: Ursula Braun <ubraun@linux.ibm.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index 46fa9f3..77ef535 100644
--- a/net/smc/af_smc.c
+++ b/net/smc/af_smc.c
@@ -30,6 +30,10 @@
 #include <net/smc.h>
 #include <asm/ioctls.h>
 
+#include <net/net_namespace.h>
+#include <net/netns/generic.h>
+#include "smc_netns.h"
+
 #include "smc.h"
 #include "smc_clc.h"
 #include "smc_llc.h"
@@ -1966,10 +1970,33 @@ static const struct net_proto_family smc_sock_family_ops = {
 	.create	= smc_create,
 };
 
+unsigned int smc_net_id;	/* net_generic() key for struct smc_net */
+
+static __net_init int smc_net_init(struct net *net)
+{
+	return smc_pnet_net_init(net);	/* set up the pnet table of @net */
+}
+
+static void __net_exit smc_net_exit(struct net *net)
+{
+	smc_pnet_net_exit(net);		/* flush the pnet table of @net */
+}
+
+static struct pernet_operations smc_net_ops = {
+	.init = smc_net_init,
+	.exit = smc_net_exit,
+	.id   = &smc_net_id,		/* id passed to net_generic() */
+	.size = sizeof(struct smc_net),	/* per-netns private data */
+};
+
 static int __init smc_init(void)
 {
 	int rc;
 
+	rc = register_pernet_subsys(&smc_net_ops);
+	if (rc)
+		return rc;
+
 	rc = smc_pnet_init();
 	if (rc)
 		return rc;
@@ -2035,6 +2062,7 @@ static void __exit smc_exit(void)
 	proto_unregister(&smc_proto6);
 	proto_unregister(&smc_proto);
 	smc_pnet_exit();
+	unregister_pernet_subsys(&smc_net_ops);
 }
 
 module_init(smc_init);
diff --git a/net/smc/smc_netns.h b/net/smc/smc_netns.h
new file mode 100644
index 0000000..e7a8fc4
--- /dev/null
+++ b/net/smc/smc_netns.h
@@ -0,0 +1,20 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/* Shared Memory Communications
+ *
+ * Network namespace definitions.
+ *
+ * Copyright IBM Corp. 2018
+ */
+
+#ifndef SMC_NETNS_H
+#define SMC_NETNS_H
+
+#include "smc_pnet.h"
+
+extern unsigned int smc_net_id;
+
+/* per-network namespace private data, looked up via net_generic() */
+struct smc_net {
+	struct smc_pnettable pnettable;	/* pnet table of this namespace */
+};
+#endif
diff --git a/net/smc/smc_pnet.c b/net/smc/smc_pnet.c
index 5497a8b..878f5c0 100644
--- a/net/smc/smc_pnet.c
+++ b/net/smc/smc_pnet.c
@@ -20,6 +20,9 @@
 
 #include <rdma/ib_verbs.h>
 
+#include <net/netns/generic.h>
+#include "smc_netns.h"
+
 #include "smc_pnet.h"
 #include "smc_ib.h"
 #include "smc_ism.h"
@@ -47,19 +50,6 @@ static struct nla_policy smc_pnet_policy[SMC_PNETID_MAX + 1] = {
 static struct genl_family smc_pnet_nl_family;
 
 /**
- * struct smc_pnettable - SMC PNET table anchor
- * @lock: Lock for list action
- * @pnetlist: List of PNETIDs
- */
-static struct smc_pnettable {
-	rwlock_t lock;
-	struct list_head pnetlist;
-} smc_pnettable = {
-	.pnetlist = LIST_HEAD_INIT(smc_pnettable.pnetlist),
-	.lock = __RW_LOCK_UNLOCKED(smc_pnettable.lock)
-};
-
-/**
  * struct smc_user_pnetentry - pnet identifier name entry for/from user
  * @list: List node.
  * @pnet_name: Pnet identifier name
@@ -101,17 +91,23 @@ static bool smc_pnet_match(u8 *pnetid1, u8 *pnetid2)
 
 /* Remove a pnetid from the pnet table.
  */
-static int smc_pnet_remove_by_pnetid(char *pnet_name)
+static int smc_pnet_remove_by_pnetid(struct net *net, char *pnet_name)
 {
 	struct smc_pnetentry *pnetelem, *tmp_pe;
+	struct smc_pnettable *pnettable;
 	struct smc_ib_device *ibdev;
 	struct smcd_dev *smcd_dev;
+	struct smc_net *sn;
 	int rc = -ENOENT;
 	int ibport;
 
+	/* get pnettable for namespace */
+	sn = net_generic(net, smc_net_id);
+	pnettable = &sn->pnettable;
+
 	/* remove netdevices */
-	write_lock(&smc_pnettable.lock);
-	list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist,
+	write_lock(&pnettable->lock);
+	list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist,
 				 list) {
 		if (!pnet_name ||
 		    smc_pnet_match(pnetelem->pnet_name, pnet_name)) {
@@ -121,7 +117,12 @@ static int smc_pnet_remove_by_pnetid(char *pnet_name)
 			rc = 0;
 		}
 	}
-	write_unlock(&smc_pnettable.lock);
+	write_unlock(&pnettable->lock);
+
+	/* if this is not the initial namespace, stop here */
+	if (net != &init_net)
+		return rc;
+
 	/* remove ib devices */
 	spin_lock(&smc_ib_devices.lock);
 	list_for_each_entry(ibdev, &smc_ib_devices.list, list) {
@@ -158,11 +159,17 @@ static int smc_pnet_remove_by_pnetid(char *pnet_name)
 static int smc_pnet_remove_by_ndev(struct net_device *ndev)
 {
 	struct smc_pnetentry *pnetelem, *tmp_pe;
+	struct smc_pnettable *pnettable;
+	struct net *net = dev_net(ndev);
+	struct smc_net *sn;
 	int rc = -ENOENT;
 
-	write_lock(&smc_pnettable.lock);
-	list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist,
-				 list) {
+	/* get pnettable for namespace */
+	sn = net_generic(net, smc_net_id);
+	pnettable = &sn->pnettable;
+
+	write_lock(&pnettable->lock);
+	list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) {
 		if (pnetelem->ndev == ndev) {
 			list_del(&pnetelem->list);
 			dev_put(pnetelem->ndev);
@@ -171,13 +178,14 @@ static int smc_pnet_remove_by_ndev(struct net_device *ndev)
 			break;
 		}
 	}
-	write_unlock(&smc_pnettable.lock);
+	write_unlock(&pnettable->lock);
 	return rc;
 }
 
 /* Append a pnetid to the end of the pnet table if not already on this list.
  */
-static int smc_pnet_enter(struct smc_user_pnetentry *new_pnetelem)
+static int smc_pnet_enter(struct smc_pnettable *pnettable,
+			  struct smc_user_pnetentry *new_pnetelem)
 {
 	u8 pnet_null[SMC_MAX_PNETID_LEN] = {0};
 	u8 ndev_pnetid[SMC_MAX_PNETID_LEN];
@@ -233,17 +241,17 @@ static int smc_pnet_enter(struct smc_user_pnetentry *new_pnetelem)
 	       SMC_MAX_PNETID_LEN);
 	tmp_pnetelem->ndev = new_pnetelem->ndev;
 
-	write_lock(&smc_pnettable.lock);
-	list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) {
+	write_lock(&pnettable->lock);
+	list_for_each_entry(pnetelem, &pnettable->pnetlist, list) {
 		if (pnetelem->ndev == new_pnetelem->ndev)
 			new_netdev = false;
 	}
 	if (new_netdev) {
 		dev_hold(tmp_pnetelem->ndev);
-		list_add_tail(&tmp_pnetelem->list, &smc_pnettable.pnetlist);
-		write_unlock(&smc_pnettable.lock);
+		list_add_tail(&tmp_pnetelem->list, &pnettable->pnetlist);
+		write_unlock(&pnettable->lock);
 	} else {
-		write_unlock(&smc_pnettable.lock);
+		write_unlock(&pnettable->lock);
 		kfree(tmp_pnetelem);
 	}
 
@@ -340,6 +348,10 @@ static int smc_pnet_fill_entry(struct net *net,
 			goto error;
 	}
 
+	/* if this is not the initial namespace, stop here */
+	if (net != &init_net)
+		return 0;
+
 	rc = -EINVAL;
 	if (tb[SMC_PNETID_IBNAME]) {
 		ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]);
@@ -403,11 +415,17 @@ static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info)
 {
 	struct net *net = genl_info_net(info);
 	struct smc_user_pnetentry pnetelem;
+	struct smc_pnettable *pnettable;
+	struct smc_net *sn;
 	int rc;
 
+	/* get pnettable for namespace */
+	sn = net_generic(net, smc_net_id);
+	pnettable = &sn->pnettable;
+
 	rc = smc_pnet_fill_entry(net, &pnetelem, info->attrs);
 	if (!rc)
-		rc = smc_pnet_enter(&pnetelem);
+		rc = smc_pnet_enter(pnettable, &pnetelem);
 	if (pnetelem.ndev)
 		dev_put(pnetelem.ndev);
 	return rc;
@@ -415,9 +433,11 @@ static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info)
 
 static int smc_pnet_del(struct sk_buff *skb, struct genl_info *info)
 {
+	struct net *net = genl_info_net(info);	/* netns from genl info */
+
 	if (!info->attrs[SMC_PNETID_NAME])
 		return -EINVAL;
-	return smc_pnet_remove_by_pnetid(
+	return smc_pnet_remove_by_pnetid(net,
 				(char *)nla_data(info->attrs[SMC_PNETID_NAME]));
 }
 
@@ -445,19 +465,25 @@ static int smc_pnet_dumpinfo(struct sk_buff *skb,
 	return 0;
 }
 
-static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid,
-			  int start_idx)
+static int _smc_pnet_dump(struct net *net, struct sk_buff *skb, u32 portid,
+			  u32 seq, u8 *pnetid, int start_idx)
 {
 	struct smc_user_pnetentry tmp_entry;
+	struct smc_pnettable *pnettable;
 	struct smc_pnetentry *pnetelem;
 	struct smc_ib_device *ibdev;
 	struct smcd_dev *smcd_dev;
+	struct smc_net *sn;
 	int idx = 0;
 	int ibport;
 
+	/* get pnettable for namespace */
+	sn = net_generic(net, smc_net_id);
+	pnettable = &sn->pnettable;
+
 	/* dump netdevices */
-	read_lock(&smc_pnettable.lock);
-	list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) {
+	read_lock(&pnettable->lock);
+	list_for_each_entry(pnetelem, &pnettable->pnetlist, list) {
 		if (pnetid && !smc_pnet_match(pnetelem->pnet_name, pnetid))
 			continue;
 		if (idx++ < start_idx)
@@ -472,7 +498,11 @@ static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid,
 			break;
 		}
 	}
-	read_unlock(&smc_pnettable.lock);
+	read_unlock(&pnettable->lock);
+
+	/* if this is not the initial namespace, stop here */
+	if (net != &init_net)
+		return idx;
 
 	/* dump ib devices */
 	spin_lock(&smc_ib_devices.lock);
@@ -528,9 +558,10 @@ static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid,
 
 static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb)
 {
+	struct net *net = sock_net(skb->sk);
 	int idx;
 
-	idx = _smc_pnet_dump(skb, NETLINK_CB(cb->skb).portid,
+	idx = _smc_pnet_dump(net, skb, NETLINK_CB(cb->skb).portid,
 			     cb->nlh->nlmsg_seq, NULL, cb->args[0]);
 
 	cb->args[0] = idx;
@@ -540,6 +571,7 @@ static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb)
 /* Retrieve one PNETID entry */
 static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info)
 {
+	struct net *net = genl_info_net(info);
 	struct sk_buff *msg;
 	void *hdr;
 
@@ -550,7 +582,7 @@ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info)
 	if (!msg)
 		return -ENOMEM;
 
-	_smc_pnet_dump(msg, info->snd_portid, info->snd_seq,
+	_smc_pnet_dump(net, msg, info->snd_portid, info->snd_seq,
 		       nla_data(info->attrs[SMC_PNETID_NAME]), 0);
 
 	/* finish multi part message and send it */
@@ -567,7 +599,9 @@ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info)
  */
 static int smc_pnet_flush(struct sk_buff *skb, struct genl_info *info)
 {
-	return smc_pnet_remove_by_pnetid(NULL);
+	struct net *net = genl_info_net(info);
+
+	return smc_pnet_remove_by_pnetid(net, NULL);	/* NULL: every entry */
 }
 
 /* SMC_PNETID generic netlink operation definition */
@@ -631,6 +665,18 @@ static struct notifier_block smc_netdev_notifier = {
 	.notifier_call = smc_pnet_netdev_event
 };
 
+/* init network namespace: empty pnet table, fresh lock; returns 0 */
+int smc_pnet_net_init(struct net *net)
+{
+	struct smc_net *sn = net_generic(net, smc_net_id);
+	struct smc_pnettable *pnettable = &sn->pnettable;
+
+	INIT_LIST_HEAD(&pnettable->pnetlist);
+	rwlock_init(&pnettable->lock);
+
+	return 0;
+}
+
 int __init smc_pnet_init(void)
 {
 	int rc;
@@ -644,9 +690,15 @@ int __init smc_pnet_init(void)
 	return rc;
 }
 
+/* exit network namespace: drop the entries of its pnet table */
+void smc_pnet_net_exit(struct net *net)
+{
+	/* flush pnet table (a NULL pnet name matches every entry) */
+	smc_pnet_remove_by_pnetid(net, NULL);
+}
+
 void smc_pnet_exit(void)
 {
-	smc_pnet_flush(NULL, NULL);
 	unregister_netdevice_notifier(&smc_netdev_notifier);
 	genl_unregister_family(&smc_pnet_nl_family);
 }
@@ -674,22 +726,29 @@ static struct net_device *pnet_find_base_ndev(struct net_device *ndev)
 	return ndev;
 }
 
-static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *netdev,
+static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *ndev,
 					      u8 *pnetid)
 {
+	struct smc_pnettable *pnettable;
+	struct net *net = dev_net(ndev);
 	struct smc_pnetentry *pnetelem;
+	struct smc_net *sn;
 	int rc = -ENOENT;
 
-	read_lock(&smc_pnettable.lock);
-	list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) {
-		if (netdev == pnetelem->ndev) {
+	/* get pnettable of the namespace that ndev belongs to */
+	sn = net_generic(net, smc_net_id);
+	pnettable = &sn->pnettable;
+
+	read_lock(&pnettable->lock);
+	list_for_each_entry(pnetelem, &pnettable->pnetlist, list) {
+		if (ndev == pnetelem->ndev) {
 			/* get pnetid of netdev device */
 			memcpy(pnetid, pnetelem->pnet_name, SMC_MAX_PNETID_LEN);
 			rc = 0;
 			break;
 		}
 	}
-	read_unlock(&smc_pnettable.lock);
+	read_unlock(&pnettable->lock);
 	return rc;
 }
 
diff --git a/net/smc/smc_pnet.h b/net/smc/smc_pnet.h
index 37044e4..5eac42f 100644
--- a/net/smc/smc_pnet.h
+++ b/net/smc/smc_pnet.h
@@ -19,6 +19,16 @@
 struct smc_ib_device;
 struct smcd_dev;
 
+/**
+ * struct smc_pnettable - SMC PNET table anchor, one per network namespace
+ * @lock: rwlock protecting @pnetlist
+ * @pnetlist: List of struct smc_pnetentry elements
+ */
+struct smc_pnettable {
+	rwlock_t lock;
+	struct list_head pnetlist;
+};
+
 static inline int smc_pnetid_by_dev_port(struct device *dev,
 					 unsigned short port, u8 *pnetid)
 {
@@ -30,7 +40,9 @@ static inline int smc_pnetid_by_dev_port(struct device *dev,
 }
 
 int smc_pnet_init(void) __init;
+int smc_pnet_net_init(struct net *net);
 void smc_pnet_exit(void);
+void smc_pnet_net_exit(struct net *net);
 void smc_pnet_find_roce_resource(struct sock *sk,
 				 struct smc_ib_device **smcibdev, u8 *ibport,
 				 unsigned short vlan_id, u8 gid[]);