devlink: convert occ_get op to separate registration

This resolves a race during initialization, where resources with ops
are registered before the driver and the structures used by the occ_get
op are initialized. Keep occ_get callbacks registered only while all of
those structures are initialized.

Example flows, as implemented in mlxsw:
1) driver load/asic probe:
   mlxsw_core
      -> mlxsw_sp_resources_register
        -> mlxsw_sp_kvdl_resources_register
          -> devlink_resource_register IDX
   mlxsw_spectrum
      -> mlxsw_sp_kvdl_init
        -> mlxsw_sp_kvdl_parts_init
          -> mlxsw_sp_kvdl_part_init
            -> devlink_resource_size_get IDX (to get the current setup
                                              size from devlink)
        -> devlink_resource_occ_get_register IDX (register current
                                                  occupancy getter)
2) reload triggered by devlink command:
  -> mlxsw_devlink_core_bus_device_reload
    -> mlxsw_sp_fini
      -> mlxsw_sp_kvdl_fini
	-> devlink_resource_occ_get_unregister IDX
    (struct mlxsw_sp *mlxsw_sp is freed at this point; a call to occ_get,
     which uses mlxsw_sp, would cause a use-after-free)
    -> mlxsw_sp_init
      -> mlxsw_sp_kvdl_init
        -> mlxsw_sp_kvdl_parts_init
          -> mlxsw_sp_kvdl_part_init
            -> devlink_resource_size_get IDX (to get the current setup
                                              size from devlink)
        -> devlink_resource_occ_get_register IDX (register current
                                                  occupancy getter)

Fixes: d9f9b9a4d05f ("devlink: Add support for resource abstraction")
Signed-off-by: Jiri Pirko <jiri@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/core/devlink.c b/net/core/devlink.c
index 9236e42..ad13173 100644
--- a/net/core/devlink.c
+++ b/net/core/devlink.c
@@ -2405,6 +2405,16 @@ devlink_resource_size_params_put(struct devlink_resource *resource,
 	return 0;
 }
 
+static int devlink_resource_occ_put(struct devlink_resource *resource,
+				    struct sk_buff *skb)
+{
+	devlink_resource_occ_get_t *getter = resource->occ_get;
+
+	return getter ? nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_OCC,
+					  getter(resource->occ_get_priv),
+					  DEVLINK_ATTR_PAD) : 0;
+}
+
 static int devlink_resource_put(struct devlink *devlink, struct sk_buff *skb,
 				struct devlink_resource *resource)
 {
@@ -2425,11 +2435,8 @@ static int devlink_resource_put(struct devlink *devlink, struct sk_buff *skb,
 	if (resource->size != resource->size_new)
 		nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_NEW,
 				  resource->size_new, DEVLINK_ATTR_PAD);
-	if (resource->resource_ops && resource->resource_ops->occ_get)
-		if (nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_OCC,
-				      resource->resource_ops->occ_get(devlink),
-				      DEVLINK_ATTR_PAD))
-			goto nla_put_failure;
+	if (devlink_resource_occ_put(resource, skb))
+		goto nla_put_failure;
 	if (devlink_resource_size_params_put(resource, skb))
 		goto nla_put_failure;
 	if (list_empty(&resource->resource_list))
@@ -3162,15 +3169,13 @@ EXPORT_SYMBOL_GPL(devlink_dpipe_table_unregister);
  *	@resource_id: resource's id
  *	@parent_reosurce_id: resource's parent id
  *	@size params: size parameters
- *	@resource_ops: resource ops
  */
 int devlink_resource_register(struct devlink *devlink,
 			      const char *resource_name,
 			      u64 resource_size,
 			      u64 resource_id,
 			      u64 parent_resource_id,
-			      const struct devlink_resource_size_params *size_params,
-			      const struct devlink_resource_ops *resource_ops)
+			      const struct devlink_resource_size_params *size_params)
 {
 	struct devlink_resource *resource;
 	struct list_head *resource_list;
@@ -3213,7 +3218,6 @@ int devlink_resource_register(struct devlink *devlink,
 	resource->size = resource_size;
 	resource->size_new = resource_size;
 	resource->id = resource_id;
-	resource->resource_ops = resource_ops;
 	resource->size_valid = true;
 	memcpy(&resource->size_params, size_params,
 	       sizeof(resource->size_params));
@@ -3315,6 +3319,58 @@ int devlink_dpipe_table_resource_set(struct devlink *devlink,
 }
 EXPORT_SYMBOL_GPL(devlink_dpipe_table_resource_set);
 
+/**
+ *	devlink_resource_occ_get_register - register occupancy getter
+ *
+ *	@devlink: devlink
+ *	@resource_id: id of a resource already registered on @devlink
+ *	@occ_get: callback that reports the resource's current occupancy
+ *	@occ_get_priv: opaque pointer passed back to @occ_get
+ *
+ *	Unregister before freeing whatever @occ_get_priv points to.
+ */
+void devlink_resource_occ_get_register(struct devlink *devlink,
+				       u64 resource_id,
+				       devlink_resource_occ_get_t *occ_get,
+				       void *occ_get_priv)
+{
+	struct devlink_resource *res;
+
+	mutex_lock(&devlink->lock);
+	res = devlink_resource_find(devlink, NULL, resource_id);
+	if (!WARN_ON(!res)) {
+		WARN_ON(res->occ_get);	/* replaced below, not rejected */
+		res->occ_get = occ_get;
+		res->occ_get_priv = occ_get_priv;
+	}
+	mutex_unlock(&devlink->lock);
+}
+EXPORT_SYMBOL_GPL(devlink_resource_occ_get_register);
+
+/**
+ *	devlink_resource_occ_get_unregister - unregister occupancy getter
+ *
+ *	@devlink: devlink
+ *	@resource_id: id of the resource whose getter is being removed
+ *
+ *	Clears both the callback and its private pointer under devlink->lock.
+ */
+void devlink_resource_occ_get_unregister(struct devlink *devlink,
+					 u64 resource_id)
+{
+	struct devlink_resource *res;
+
+	mutex_lock(&devlink->lock);
+	res = devlink_resource_find(devlink, NULL, resource_id);
+	if (!WARN_ON(!res)) {
+		WARN_ON(!res->occ_get);	/* nothing was registered */
+		res->occ_get = NULL;
+		res->occ_get_priv = NULL;
+	}
+	mutex_unlock(&devlink->lock);
+}
+EXPORT_SYMBOL_GPL(devlink_resource_occ_get_unregister);
+
 static int __init devlink_module_init(void)
 {
 	return genl_register_family(&devlink_nl_family);