mm/hmm: allow hmm_range to be used with a mmu_interval_notifier or hmm_mirror

hmm_mirror's handling of ranges does not use a sequence count, which
results in this bug:

         CPU0                                   CPU1
                                     hmm_range_wait_until_valid(range)
                                         valid == true
                                     hmm_range_fault(range)
hmm_invalidate_range_start()
   range->valid = false
hmm_invalidate_range_end()
   range->valid = true
                                     hmm_range_valid(range)
                                          valid == true

Here hmm_range_valid() should not have succeeded: the range was
invalidated, and then marked valid again, while hmm_range_fault() was
running.

Adding the required sequence count would make hmm_mirror nearly
identical to the new mmu_interval_notifier. Instead, replace the
hmm_mirror machinery with mmu_interval_notifier.

Allowing the two APIs to co-exist is the first step.

Link: https://lore.kernel.org/r/20191112202231.3856-4-jgg@ziepe.ca
Reviewed-by: Jérôme Glisse <jglisse@redhat.com>
Tested-by: Philip Yang <Philip.Yang@amd.com>
Tested-by: Ralph Campbell <rcampbell@nvidia.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
diff --git a/mm/hmm.c b/mm/hmm.c
index 6b01366..8d060c5 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -858,6 +858,14 @@ void hmm_range_unregister(struct hmm_range *range)
 }
 EXPORT_SYMBOL(hmm_range_unregister);
 
+/* Returns true if an invalidation requires the caller to retry range */
+static bool needs_retry(struct hmm_range *range)
+{
+	if (!range->notifier)
+		return !range->valid;
+	return mmu_interval_check_retry(range->notifier, range->notifier_seq);
+}
+
 static const struct mm_walk_ops hmm_walk_ops = {
 	.pud_entry	= hmm_vma_walk_pud,
 	.pmd_entry	= hmm_vma_walk_pmd,
@@ -898,18 +906,23 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
 	unsigned long start = range->start, end;
 	struct hmm_vma_walk hmm_vma_walk;
-	struct hmm *hmm = range->hmm;
+	struct mm_struct *mm;
 	struct vm_area_struct *vma;
 	int ret;
 
-	lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
+	if (range->notifier)
+		mm = range->notifier->mm;
+	else
+		mm = range->hmm->mmu_notifier.mm;
+
+	lockdep_assert_held(&mm->mmap_sem);
 
 	do {
 		/* If range is no longer valid force retry. */
-		if (!range->valid)
+		if (needs_retry(range))
 			return -EBUSY;
 
-		vma = find_vma(hmm->mmu_notifier.mm, start);
+		vma = find_vma(mm, start);
 		if (vma == NULL || (vma->vm_flags & device_vma))
 			return -EFAULT;
 
@@ -939,7 +952,7 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 			start = hmm_vma_walk.last;
 
 			/* Keep trying while the range is valid. */
-		} while (ret == -EBUSY && range->valid);
+		} while (ret == -EBUSY && !needs_retry(range));
 
 		if (ret) {
 			unsigned long i;
@@ -997,7 +1010,7 @@ long hmm_range_dma_map(struct hmm_range *range, struct device *device,
 			continue;
 
 		/* Check if range is being invalidated */
-		if (!range->valid) {
+		if (needs_retry(range)) {
 			ret = -EBUSY;
 			goto unmap;
 		}