tls: rx: device: add input CoW helper

Wrap the remaining skb_cow_data() into a helper, so it's easier
to replace down the lane. The new version will change the skb
so make sure relevant pointers get reloaded after the call.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>

diff --git a/net/tls/tls.h b/net/tls/tls.h
index 78c5d69..154a377 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -127,6 +127,7 @@ int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info);
 
+int tls_strp_msg_cow(struct tls_sw_context_rx *ctx);
 struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx);
 int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
 		      struct sk_buff_head *dst);
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index b1fcd61..fc513c1 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -894,27 +894,26 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
 static int
 tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
 {
-	int err = 0, offset, copy, nsg, data_len, pos;
-	struct sk_buff *skb, *skb_iter, *unused;
+	int err, offset, copy, data_len, pos;
+	struct sk_buff *skb, *skb_iter;
 	struct scatterlist sg[1];
 	struct strp_msg *rxm;
 	char *orig_buf, *buf;
 
-	skb = tls_strp_msg(sw_ctx);
-	rxm = strp_msg(skb);
-	offset = rxm->offset;
-
+	rxm = strp_msg(tls_strp_msg(sw_ctx));
 	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
 			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
 	if (!orig_buf)
 		return -ENOMEM;
 	buf = orig_buf;
 
-	nsg = skb_cow_data(skb, 0, &unused);
-	if (unlikely(nsg < 0)) {
-		err = nsg;
+	err = tls_strp_msg_cow(sw_ctx);
+	if (unlikely(err))
 		goto free_buf;
-	}
+
+	skb = tls_strp_msg(sw_ctx);
+	rxm = strp_msg(skb);
+	offset = rxm->offset;
 
 	sg_init_table(sg, 1);
 	sg_set_buf(&sg[0], buf,
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 40b1773..d9bb4f23 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -13,6 +13,17 @@ struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
 	return skb;
 }
 
+int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
+{
+	struct sk_buff *unused;
+	int nsg;
+
+	nsg = skb_cow_data(ctx->recv_pkt, 0, &unused);
+	if (nsg < 0)
+		return nsg;
+	return 0;
+}
+
 int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
 		      struct sk_buff_head *dst)
 {