KVM: guest_memfd: use a dedicated rwlock to protect shared_offsets

Replace the filemap (inode->i_mapping) invalidate_lock as the lock
protecting the shared_offsets xarray with a per-inode rwlock_t,
offsets_lock. Shareability transitions take offsets_lock for write;
queries such as kvm_gmem_slot_is_guest_shared() take it for read. The
fault path still takes the filemap invalidate_lock shared and now also
takes offsets_lock for read, giving the lock ordering: invalidate_lock
first, then offsets_lock.
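
As a minimal sketch of the resulting pattern (illustrative only; it
assumes the kvm_gmem_private() accessor and the fields introduced by
this patch):

	struct kvm_gmem_inode_private *priv = kvm_gmem_private(inode);

	/* Writer side: shareability transitions. */
	write_lock(&priv->offsets_lock);
	/* ... mutate priv->shared_offsets, e.g. via xa_store() ... */
	write_unlock(&priv->offsets_lock);

	/* Reader side: shareability queries. */
	read_lock(&priv->offsets_lock);
	/* ... xa_to_value(xa_load(&priv->shared_offsets, index)) ... */
	read_unlock(&priv->offsets_lock);

Unlike the invalidate_lock rwsem, rwlock_t critical sections must not
sleep, so the xa_store() done under offsets_lock switches to
GFP_ATOMIC.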
diff --git a/virt/kvm/guest_memfd.c b/virt/kvm/guest_memfd.c
index d7a2973..ae49be7 100644
--- a/virt/kvm/guest_memfd.c
+++ b/virt/kvm/guest_memfd.c
@@ -20,6 +20,7 @@ struct kvm_gmem {
 struct kvm_gmem_inode_private {
 #ifdef CONFIG_KVM_GMEM_SHARED_MEM
 	struct xarray shared_offsets;
+	rwlock_t offsets_lock;
 #endif
 };
 
@@ -414,15 +415,16 @@ enum folio_shareability {
  * and removes the folio type, thereby removing the callback. Now the folio can
- * be freed normaly once all actual references have been dropped.
+ * be freed normally once all actual references have been dropped.
  *
- * Must be called with the filemap (inode->i_mapping) invalidate_lock held, and
- * the folio must be locked.
+ * Must be called with the folio locked and offsets_lock held for write.
  */
-static void kvm_gmem_restore_pending_folio(struct folio *folio, const struct inode *inode)
+static void kvm_gmem_restore_pending_folio(struct folio *folio, struct inode *inode)
 {
-	rwsem_assert_held_write_nolockdep(&inode->i_mapping->invalidate_lock);
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
+
+	lockdep_assert_held_write(offsets_lock);
 	WARN_ON_ONCE(!folio_test_locked(folio));
 
-	if (WARN_ON_ONCE(folio_mapped(folio) || !folio_test_guestmem(folio)))
+	if (WARN_ON_ONCE(!folio_test_guestmem(folio)))
 		return;
 
 	__folio_clear_guestmem(folio);
@@ -434,9 +436,10 @@ static void kvm_gmem_restore_pending_folio(struct folio *folio, const struct ino
 static int kvm_gmem_offset_set_shared(struct inode *inode, pgoff_t index)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	void *xval = xa_mk_value(KVM_GMEM_ALL_SHARED);
 
-	rwsem_assert_held_write_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held_write(offsets_lock);
 
 	/*
 	 * If the folio is NONE_SHARED, it indicates that it is transitioning to
@@ -466,16 +469,17 @@ static int kvm_gmem_offset_set_shared(struct inode *inode, pgoff_t index)
 static int kvm_gmem_offset_range_set_shared(struct inode *inode,
 					    pgoff_t start, pgoff_t end)
 {
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	pgoff_t i;
 	int r = 0;
 
-	filemap_invalidate_lock(inode->i_mapping);
+	write_lock(offsets_lock);
 	for (i = start; i < end; i++) {
 		r = kvm_gmem_offset_set_shared(inode, i);
 		if (WARN_ON_ONCE(r))
 			break;
 	}
-	filemap_invalidate_unlock(inode->i_mapping);
+	write_unlock(offsets_lock);
 
 	return r;
 }
@@ -483,13 +487,14 @@ static int kvm_gmem_offset_range_set_shared(struct inode *inode,
 static int kvm_gmem_offset_clear_shared(struct inode *inode, pgoff_t index)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	void *xval_guest = xa_mk_value(KVM_GMEM_GUEST_SHARED);
 	void *xval_none = xa_mk_value(KVM_GMEM_NONE_SHARED);
 	struct folio *folio;
 	int refcount;
 	int r;
 
-	rwsem_assert_held_write_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held_write(offsets_lock);
 
 	folio = filemap_lock_folio(inode->i_mapping, index);
 	if (!IS_ERR(folio)) {
@@ -530,24 +535,25 @@ static int kvm_gmem_offset_clear_shared(struct inode *inode, pgoff_t index)
 /*
  * Callback when invalidating memory that is potentially shared.
  *
- * Must be called with the filemap (inode->i_mapping) invalidate_lock held.
+ * Must be called with offsets_lock held for write.
  */
 static void kvm_gmem_offset_range_invalidate_shared(struct inode *inode,
 						    pgoff_t start, pgoff_t end)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	pgoff_t i;
 
-	rwsem_assert_held_write_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held_write(offsets_lock);
 
 	for (i = start; i < end; i++) {
 		/*
-		 * If the folio is NONE_SHARED, it indicates that it is
+		 * If the folio is NONE_SHARED, it indicates that it's
 		 * transitioning to private (GUEST_SHARED). Transition it to
 		 * shared (ALL_SHARED) and remove the callback.
 		 */
 		if (xa_to_value(xa_load(shared_offsets, i)) == KVM_GMEM_NONE_SHARED) {
-			struct folio *folio = folio = filemap_lock_folio(inode->i_mapping, i);
+			struct folio *folio = filemap_lock_folio(inode->i_mapping, i);
 
 			if (!WARN_ON_ONCE(IS_ERR(folio))) {
 				if (folio_test_guestmem(folio))
@@ -577,16 +583,17 @@ static void kvm_gmem_offset_range_invalidate_shared(struct inode *inode,
 static int kvm_gmem_offset_range_clear_shared(struct inode *inode,
 					      pgoff_t start, pgoff_t end)
 {
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	pgoff_t i;
 	int r = 0;
 
-	filemap_invalidate_lock(inode->i_mapping);
+	write_lock(offsets_lock);
 	for (i = start; i < end; i++) {
 		r = kvm_gmem_offset_clear_shared(inode, i);
 		if (WARN_ON_ONCE(r))
 			break;
 	}
-	filemap_invalidate_unlock(inode->i_mapping);
+	write_unlock(offsets_lock);
 
 	return r;
 }
@@ -599,17 +606,17 @@ static int kvm_gmem_offset_range_clear_shared(struct inode *inode,
  * Returns 0 if a callback was registered or already has been registered, or
  * -EAGAIN if the host has references, indicating a callback wasn't registered.
  *
- * Must be called with the filemap (inode->i_mapping) invalidate_lock held, and
- * the folio must be locked.
+ * Must be called with the folio locked and offsets_lock held for write.
  */
 static int kvm_gmem_register_callback(struct folio *folio, struct inode *inode, pgoff_t index)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	void *xval_guest = xa_mk_value(KVM_GMEM_GUEST_SHARED);
 	int refcount;
 	int r = 0;
 
-	rwsem_assert_held_write_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held_write(offsets_lock);
 	WARN_ON_ONCE(!folio_test_locked(folio));
 
 	if (folio_test_guestmem(folio)) {
@@ -651,23 +658,23 @@ int kvm_gmem_slot_register_callback(struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	unsigned long pgoff = slot->gmem.pgoff + gfn - slot->base_gfn;
 	struct inode *inode = file_inode(READ_ONCE(slot->gmem.file));
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	struct folio *folio;
 	int r;
 
-	filemap_invalidate_lock(inode->i_mapping);
+	write_lock(offsets_lock);
 
 	folio = filemap_lock_folio(inode->i_mapping, pgoff);
 	if (WARN_ON_ONCE(IS_ERR(folio))) {
-		r = PTR_ERR(folio);
-		goto out;
+		write_unlock(offsets_lock);
+		return PTR_ERR(folio);
 	}
 
 	r = kvm_gmem_register_callback(folio, inode, pgoff);
 
 	folio_unlock(folio);
 	folio_put(folio);
-out:
-	filemap_invalidate_unlock(inode->i_mapping);
+	write_unlock(offsets_lock);
 
 	return r;
 }
@@ -684,6 +691,7 @@ void kvm_gmem_handle_folio_put(struct folio *folio)
 {
 	struct address_space *mapping;
 	struct xarray *shared_offsets;
+	rwlock_t *offsets_lock;
 	struct inode *inode;
 	pgoff_t index;
 	void *xval;
@@ -695,37 +703,50 @@ void kvm_gmem_handle_folio_put(struct folio *folio)
 	inode = mapping->host;
 	index = folio->index;
 	shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	xval = xa_mk_value(KVM_GMEM_GUEST_SHARED);
 
-	filemap_invalidate_lock(inode->i_mapping);
+	write_lock(offsets_lock);
 	folio_lock(folio);
 	kvm_gmem_restore_pending_folio(folio, inode);
 	folio_unlock(folio);
-	WARN_ON_ONCE(xa_err(xa_store(shared_offsets, index, xval, GFP_KERNEL)));
-	filemap_invalidate_unlock(inode->i_mapping);
+	WARN_ON_ONCE(xa_err(xa_store(shared_offsets, index, xval, GFP_ATOMIC)));
+	write_unlock(offsets_lock);
 
 	pr_info("%s: done\n", __func__);
 }
 EXPORT_SYMBOL_GPL(kvm_gmem_handle_folio_put);
 
+/*
+ * Returns true if the offset is shared with both the host and the guest.
+ *
+ * Must be called with offsets_lock held (read or write).
+ */
 static bool kvm_gmem_offset_is_shared(struct inode *inode, pgoff_t index)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	unsigned long r;
 
-	rwsem_assert_held_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held(offsets_lock);
 
 	r = xa_to_value(xa_load(shared_offsets, index));
 
 	return r == KVM_GMEM_ALL_SHARED;
 }
 
+/*
+ * Returns true if the offset is shared with the guest (not transitioning).
+ *
+ * Must be called with offsets_lock held (read or write).
+ */
 static bool kvm_gmem_offset_is_guest_shared(struct inode *inode, pgoff_t index)
 {
 	struct xarray *shared_offsets = &kvm_gmem_private(inode)->shared_offsets;
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	unsigned long r;
 
-	rwsem_assert_held_nolockdep(&inode->i_mapping->invalidate_lock);
+	lockdep_assert_held(offsets_lock);
 
 	r = xa_to_value(xa_load(shared_offsets, index));
 
@@ -753,12 +774,13 @@ int kvm_gmem_slot_clear_shared(struct kvm_memory_slot *slot, gfn_t start, gfn_t
 bool kvm_gmem_slot_is_guest_shared(struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	struct inode *inode = file_inode(READ_ONCE(slot->gmem.file));
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	unsigned long pgoff = slot->gmem.pgoff + gfn - slot->base_gfn;
 	bool r;
 
-	filemap_invalidate_lock_shared(inode->i_mapping);
+	read_lock(offsets_lock);
 	r = kvm_gmem_offset_is_guest_shared(inode, pgoff);
-	filemap_invalidate_unlock_shared(inode->i_mapping);
+	read_unlock(offsets_lock);
 
 	return r;
 }
@@ -766,10 +788,12 @@ bool kvm_gmem_slot_is_guest_shared(struct kvm_memory_slot *slot, gfn_t gfn)
 static vm_fault_t kvm_gmem_fault(struct vm_fault *vmf)
 {
 	struct inode *inode = file_inode(vmf->vma->vm_file);
+	rwlock_t *offsets_lock = &kvm_gmem_private(inode)->offsets_lock;
 	struct folio *folio;
 	vm_fault_t ret = VM_FAULT_LOCKED;
 
 	filemap_invalidate_lock_shared(inode->i_mapping);
+	read_lock(offsets_lock);
 
 	folio = kvm_gmem_get_folio(inode, vmf->pgoff);
 	if (IS_ERR(folio)) {
@@ -822,6 +846,7 @@ static vm_fault_t kvm_gmem_fault(struct vm_fault *vmf)
 	}
 
 out_filemap:
+	read_unlock(offsets_lock);
 	filemap_invalidate_unlock_shared(inode->i_mapping);
 
 	return ret;
@@ -969,6 +994,7 @@ static struct inode *kvm_gmem_inode_make_secure_inode(const char *name,
 
 #ifdef CONFIG_KVM_GMEM_SHARED_MEM
 	xa_init(&private->shared_offsets);
+	rwlock_init(&private->offsets_lock);
 #endif
 
 	inode->i_mapping->i_private_data = private;