PCI: Consolidate "next-function" functions

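There are several "next function" helpers (no_next_fn(), next_trad_fn(),
next_ari_fn()) that pci_scan_slot() selects between through a function
pointer.  Consolidate them into a single next_fn() that takes the bus and
makes that choice itself, so the function pointer can be removed.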

[bhelgaas: make next_fn() static, rework control flow]
Signed-off-by: Yijing Wang <wangyijing@huawei.com>
Signed-off-by: Jiang Liu <jiang.liu@huawei.com>
Signed-off-by: Bjorn Helgaas <bhelgaas@google.com>
diff --git a/drivers/pci/probe.c b/drivers/pci/probe.c
index 7b9e691..2a9958c 100644
--- a/drivers/pci/probe.c
+++ b/drivers/pci/probe.c
@@ -1349,31 +1349,31 @@
 }
 EXPORT_SYMBOL(pci_scan_single_device);
 
-static unsigned next_ari_fn(struct pci_dev *dev, unsigned fn)
+static unsigned next_fn(struct pci_bus *bus, struct pci_dev *dev, unsigned fn)
 {
-	u16 cap;
-	unsigned pos, next_fn;
+	int pos;
+	u16 cap = 0;
+	unsigned next_fn;
 
-	if (!dev)
-		return 0;
+	if (pci_ari_enabled(bus)) {
+		if (!dev)
+			return 0;
+		pos = pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI);
+		if (!pos)
+			return 0;
 
-	pos = pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI);
-	if (!pos)
-		return 0;
-	pci_read_config_word(dev, pos + 4, &cap);
-	next_fn = cap >> 8;
-	if (next_fn <= fn)
-		return 0;
-	return next_fn;
-}
+		pci_read_config_word(dev, pos + PCI_ARI_CAP, &cap);
+		next_fn = PCI_ARI_CAP_NFN(cap);
+		if (next_fn <= fn)
+			return 0;	/* protect against malformed list */
 
-static unsigned next_trad_fn(struct pci_dev *dev, unsigned fn)
-{
-	return (fn + 1) % 8;
-}
+		return next_fn;
+	}
 
-static unsigned no_next_fn(struct pci_dev *dev, unsigned fn)
-{
+	/* dev may be NULL for non-contiguous multifunction devices */
+	if (!dev || dev->multifunction)
+		return (fn + 1) % 8;
+
 	return 0;
 }
 
@@ -1406,7 +1406,6 @@
 {
 	unsigned fn, nr = 0;
 	struct pci_dev *dev;
-	unsigned (*next_fn)(struct pci_dev *, unsigned) = no_next_fn;
 
 	if (only_one_child(bus) && (devfn > 0))
 		return 0; /* Already scanned the entire slot */
@@ -1417,12 +1416,7 @@
 	if (!dev->is_added)
 		nr++;
 
-	if (pci_ari_enabled(bus))
-		next_fn = next_ari_fn;
-	else if (dev->multifunction)
-		next_fn = next_trad_fn;
-
-	for (fn = next_fn(dev, 0); fn > 0; fn = next_fn(dev, fn)) {
+	for (fn = next_fn(bus, dev, 0); fn > 0; fn = next_fn(bus, dev, fn)) {
 		dev = pci_scan_single_device(bus, devfn + fn);
 		if (dev) {
 			if (!dev->is_added)