Merge tag 'v5.2-rc7' into rdma.git hmm

Required for dependencies in the next patches.
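
For readers of the hmm.rst hunk on default_flags and pfn_flags_mask, the
following user-space sketch models how the two fields select a fault
policy per page. It is not kernel code: the names range_policy,
fault_decision, model_need_fault, DRV_PFN_VALID and DRV_PFN_WRITE are
hypothetical, and the "(entry & mask) | default" combination is an
assumption drawn from the documented behaviour, not a copy of mm/hmm.c.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical driver flag encoding, mirroring the hmm.rst example. */
#define DRV_PFN_VALID (UINT64_C(1) << 63)
#define DRV_PFN_WRITE (UINT64_C(1) << 62)

/* Illustrative subset of a range description; not struct hmm_range. */
struct range_policy {
	uint64_t default_flags;  /* requested for every page in the range */
	uint64_t pfn_flags_mask; /* per-entry bits that are honoured */
};

struct fault_decision {
	bool fault;       /* a page fault would be needed */
	bool write_fault; /* that fault must request write access */
};

/*
 * Decide whether one page needs a fault.
 * pfn_entry: per-page request placed in the pfns array by the driver
 * cpu_flags: what the CPU page table currently grants, in driver encoding
 */
struct fault_decision model_need_fault(const struct range_policy *p,
				       uint64_t pfn_entry, uint64_t cpu_flags)
{
	struct fault_decision d = { false, false };
	uint64_t req;

	assert(p != NULL);

	/* Assumed combination: masked per-entry bits plus range defaults. */
	req = (pfn_entry & p->pfn_flags_mask) | p->default_flags;

	/* Nothing was requested for this page. */
	if (!(req & DRV_PFN_VALID))
		return d;

	/* Fault if the CPU page table has no valid entry. */
	d.fault = !(cpu_flags & DRV_PFN_VALID);

	/* Fault for write if write is requested but not granted. */
	if ((req & DRV_PFN_WRITE) && !(cpu_flags & DRV_PFN_WRITE)) {
		d.write_fault = true;
		d.fault = true;
	}
	return d;
}
```

With default_flags = DRV_PFN_VALID and pfn_flags_mask = 0, every page is
requested read-only and per-entry bits are ignored. Setting
pfn_flags_mask = DRV_PFN_WRITE lets a single pfns entry ask for write
access, which matches the second example in the documentation hunk.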
diff --git a/Documentation/vm/hmm.rst b/Documentation/vm/hmm.rst
index 7cdf728..7b6eeda 100644
--- a/Documentation/vm/hmm.rst
+++ b/Documentation/vm/hmm.rst
@@ -10,7 +10,7 @@
 this document).
 
 HMM also provides optional helpers for SVM (Share Virtual Memory), i.e.,
-allowing a device to transparently access program address coherently with
+allowing a device to transparently access program addresses coherently with
 the CPU meaning that any valid pointer on the CPU is also a valid pointer
 for the device. This is becoming mandatory to simplify the use of advanced
 heterogeneous computing where GPU, DSP, or FPGA are used to perform various
@@ -22,8 +22,8 @@
 section gives an overview of the HMM design. The fourth section explains how
 CPU page-table mirroring works and the purpose of HMM in this context. The
 fifth section deals with how device memory is represented inside the kernel.
-Finally, the last section presents a new migration helper that allows lever-
-aging the device DMA engine.
+Finally, the last section presents a new migration helper that allows
+leveraging the device DMA engine.
 
 .. contents:: :local:
 
@@ -39,20 +39,20 @@
 i.e., one in which any application memory region can be used by a device
 transparently.
 
-Split address space happens because device can only access memory allocated
-through device specific API. This implies that all memory objects in a program
+Split address space happens because devices can only access memory allocated
+through a device specific API. This implies that all memory objects in a program
 are not equal from the device point of view which complicates large programs
 that rely on a wide set of libraries.
 
-Concretely this means that code that wants to leverage devices like GPUs needs
-to copy object between generically allocated memory (malloc, mmap private, mmap
+Concretely, this means that code that wants to leverage devices like GPUs needs
+to copy objects between generically allocated memory (malloc, mmap private, mmap
 share) and memory allocated through the device driver API (this still ends up
 with an mmap but of the device file).
 
 For flat data sets (array, grid, image, ...) this isn't too hard to achieve but
-complex data sets (list, tree, ...) are hard to get right. Duplicating a
+for complex data sets (list, tree, ...) it's hard to get right. Duplicating a
 complex data set needs to re-map all the pointer relations between each of its
-elements. This is error prone and program gets harder to debug because of the
+elements. This is error prone and programs get harder to debug because of the
 duplicate data set and addresses.
 
 Split address space also means that libraries cannot transparently use data
@@ -77,12 +77,12 @@
 
 I/O buses cripple shared address spaces due to a few limitations. Most I/O
 buses only allow basic memory access from device to main memory; even cache
-coherency is often optional. Access to device memory from CPU is even more
+coherency is often optional. Access to device memory from a CPU is even more
 limited. More often than not, it is not cache coherent.
 
 If we only consider the PCIE bus, then a device can access main memory (often
 through an IOMMU) and be cache coherent with the CPUs. However, it only allows
-a limited set of atomic operations from device on main memory. This is worse
+a limited set of atomic operations from the device on main memory. This is worse
 in the other direction: the CPU can only access a limited range of the device
 memory and cannot perform atomic operations on it. Thus device memory cannot
 be considered the same as regular memory from the kernel point of view.
@@ -93,20 +93,20 @@
 order of magnitude higher latency than when the device accesses its own memory.
 
 Some platforms are developing new I/O buses or additions/modifications to PCIE
-to address some of these limitations (OpenCAPI, CCIX). They mainly allow two-
-way cache coherency between CPU and device and allow all atomic operations the
+to address some of these limitations (OpenCAPI, CCIX). They mainly allow
+two-way cache coherency between CPU and device and allow all atomic operations the
 architecture supports. Sadly, not all platforms are following this trend and
 some major architectures are left without hardware solutions to these problems.
 
 So for shared address space to make sense, not only must we allow devices to
 access any memory but we must also permit any memory to be migrated to device
-memory while device is using it (blocking CPU access while it happens).
+memory while the device is using it (blocking CPU access while it happens).
 
 
 Shared address space and migration
 ==================================
 
-HMM intends to provide two main features. First one is to share the address
+HMM intends to provide two main features. The first one is to share the address
 space by duplicating the CPU page table in the device page table so the same
 address points to the same physical memory for any valid main memory address in
 the process address space.
@@ -121,14 +121,14 @@
 hardware specific details to the device driver.
 
 The second mechanism HMM provides is a new kind of ZONE_DEVICE memory that
-allows allocating a struct page for each page of the device memory. Those pages
+allows allocating a struct page for each page of device memory. Those pages
 are special because the CPU cannot map them. However, they allow migrating
 main memory to device memory using existing migration mechanisms and everything
-looks like a page is swapped out to disk from the CPU point of view. Using a
-struct page gives the easiest and cleanest integration with existing mm mech-
-anisms. Here again, HMM only provides helpers, first to hotplug new ZONE_DEVICE
+looks like a page that is swapped out to disk from the CPU point of view. Using a
+struct page gives the easiest and cleanest integration with existing mm
+mechanisms. Here again, HMM only provides helpers, first to hotplug new ZONE_DEVICE
 memory for the device memory and second to perform migration. Policy decisions
-of what and when to migrate things is left to the device driver.
+of what and when to migrate are left to the device driver.
 
 Note that any CPU access to a device page triggers a page fault and a migration
 back to main memory. For example, when a page backing a given CPU address A is
@@ -136,8 +136,8 @@
 address A triggers a page fault and initiates a migration back to main memory.
 
 With these two features, HMM not only allows a device to mirror process address
-space and keeping both CPU and device page table synchronized, but also lever-
-ages device memory by migrating the part of the data set that is actively being
+space and keeps both CPU and device page tables synchronized, but also
+leverages device memory by migrating the part of the data set that is actively being
 used by the device.
 
 
@@ -151,21 +151,28 @@
 
  int hmm_mirror_register(struct hmm_mirror *mirror,
                          struct mm_struct *mm);
- int hmm_mirror_register_locked(struct hmm_mirror *mirror,
-                                struct mm_struct *mm);
 
-
-The locked variant is to be used when the driver is already holding mmap_sem
-of the mm in write mode. The mirror struct has a set of callbacks that are used
+The mirror struct has a set of callbacks that are used
 to propagate CPU page tables::
 
  struct hmm_mirror_ops {
+     /* release() - release hmm_mirror
+      *
+      * @mirror: pointer to struct hmm_mirror
+      *
+      * This is called when the mm_struct is being released.  The callback
+      * must ensure that all access to any pages obtained from this mirror
+      * is halted before the callback returns. All future access should
+      * fault.
+      */
+     void (*release)(struct hmm_mirror *mirror);
+
      /* sync_cpu_device_pagetables() - synchronize page tables
       *
       * @mirror: pointer to struct hmm_mirror
-      * @update_type: type of update that occurred to the CPU page table
-      * @start: virtual start address of the range to update
-      * @end: virtual end address of the range to update
+      * @update: update information (see struct hmm_update)
+      * Return: -EAGAIN if update.blockable is false and the callback needs
+      *         to block, 0 otherwise.
       *
       * This callback ultimately originates from mmu_notifiers when the CPU
       * page table is updated. The device driver must update its page table
@@ -176,14 +183,12 @@
       * page tables are completely updated (TLBs flushed, etc); this is a
       * synchronous call.
       */
-      void (*update)(struct hmm_mirror *mirror,
-                     enum hmm_update action,
-                     unsigned long start,
-                     unsigned long end);
+     int (*sync_cpu_device_pagetables)(struct hmm_mirror *mirror,
+                                       const struct hmm_update *update);
  };
 
 The device driver must perform the update action to the range (mark range
-read only, or fully unmap, ...). The device must be done with the update before
+read only, or fully unmap, etc.). The device must complete the update before
 the driver callback returns.
 
 When the device driver wants to populate a range of virtual addresses, it can
@@ -194,17 +199,18 @@
 
 The first one (hmm_range_snapshot()) will only fetch present CPU page table
 entries and will not trigger a page fault on missing or non-present entries.
-The second one does trigger a page fault on missing or read-only entry if the
-write parameter is true. Page faults use the generic mm page fault code path
-just like a CPU page fault.
+The second one does trigger a page fault on missing or read-only entries if
+write access is requested (see below). Page faults use the generic mm page
+fault code path just like a CPU page fault.
 
 Both functions copy CPU page table entries into their pfns array argument. Each
 entry in that array corresponds to an address in the virtual range. HMM
 provides a set of flags to help the driver identify special CPU page table
 entries.
 
-Locking with the update() callback is the most important aspect the driver must
-respect in order to keep things properly synchronized. The usage pattern is::
+Locking within the sync_cpu_device_pagetables() callback is the most important
+aspect the driver must respect in order to keep things properly synchronized.
+The usage pattern is::
 
  int driver_populate_range(...)
  {
@@ -239,11 +245,11 @@
             hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
             goto again;
           }
-          hmm_mirror_unregister(&range);
+          hmm_range_unregister(&range);
           return ret;
       }
       take_lock(driver->update);
-      if (!range.valid) {
+      if (!hmm_range_valid(&range)) {
           release_lock(driver->update);
           up_read(&mm->mmap_sem);
           goto again;
@@ -251,15 +257,15 @@
 
       // Use pfns array content to update device page table
 
-      hmm_mirror_unregister(&range);
+      hmm_range_unregister(&range);
       release_lock(driver->update);
       up_read(&mm->mmap_sem);
       return 0;
  }
 
 The driver->update lock is the same lock that the driver takes inside its
-update() callback. That lock must be held before checking the range.valid
-field to avoid any race with a concurrent CPU page table update.
+sync_cpu_device_pagetables() callback. That lock must be held before calling
+hmm_range_valid() to avoid any race with a concurrent CPU page table update.
 
 HMM implements all this on top of the mmu_notifier API because we wanted a
 simpler API and also to be able to perform optimizations latter on like doing
@@ -279,46 +285,47 @@
 Leverage default_flags and pfn_flags_mask
 =========================================
 
-The hmm_range struct has 2 fields default_flags and pfn_flags_mask that allows
-to set fault or snapshot policy for a whole range instead of having to set them
-for each entries in the range.
+The hmm_range struct has 2 fields, default_flags and pfn_flags_mask, that specify
+fault or snapshot policy for the whole range instead of having to set them
+for each entry in the pfns array.
 
-For instance if the device flags for device entries are:
-    VALID (1 << 63)
-    WRITE (1 << 62)
+For instance, if the device flags for range.flags are::
 
-Now let say that device driver wants to fault with at least read a range then
-it does set::
+    range.flags[HMM_PFN_VALID] = (1 << 63);
+    range.flags[HMM_PFN_WRITE] = (1 << 62);
+
+and the device driver wants pages for a range with at least read permission,
+it sets::
 
     range->default_flags = (1 << 63);
     range->pfn_flags_mask = 0;
 
-and calls hmm_range_fault() as described above. This will fill fault all page
+and calls hmm_range_fault() as described above. This will fault in all pages
 in the range with at least read permission.
 
-Now let say driver wants to do the same except for one page in the range for
-which its want to have write. Now driver set::
+Now let's say the driver wants to do the same except for one page in the range for
+which it wants to have write permission. The driver now sets::
 
     range->default_flags = (1 << 63);
     range->pfn_flags_mask = (1 << 62);
     range->pfns[index_of_write] = (1 << 62);
 
-With this HMM will fault in all page with at least read (ie valid) and for the
+With this, HMM will fault in all pages with at least read (i.e., valid) and for the
 address == range->start + (index_of_write << PAGE_SHIFT) it will fault with
-write permission ie if the CPU pte does not have write permission set then HMM
+write permission, i.e., if the CPU pte does not have write permission set then HMM
 will call handle_mm_fault().
 
-Note that HMM will populate the pfns array with write permission for any entry
-that have write permission within the CPU pte no matter what are the values set
+Note that HMM will populate the pfns array with write permission for any page
+that is mapped with CPU write permission no matter what values are set
 in default_flags or pfn_flags_mask.
 
 
 Represent and manage device memory from core kernel point of view
 =================================================================
 
-Several different designs were tried to support device memory. First one used
-a device specific data structure to keep information about migrated memory and
-HMM hooked itself in various places of mm code to handle any access to
+Several different designs were tried to support device memory. The first one
+used a device specific data structure to keep information about migrated memory
+and HMM hooked itself in various places of mm code to handle any access to
 addresses that were backed by device memory. It turns out that this ended up
 replicating most of the fields of struct page and also needed many kernel code
 paths to be updated to understand this new kind of memory.
@@ -341,7 +348,7 @@
 
  struct hmm_devmem_ops {
      void (*free)(struct hmm_devmem *devmem, struct page *page);
-     int (*fault)(struct hmm_devmem *devmem,
+     vm_fault_t (*fault)(struct hmm_devmem *devmem,
                   struct vm_area_struct *vma,
                   unsigned long addr,
                   struct page *page,
@@ -417,9 +424,9 @@
 Memory cgroup (memcg) and rss accounting
 ========================================
 
-For now device memory is accounted as any regular page in rss counters (either
+For now, device memory is accounted as any regular page in rss counters (either
 anonymous if device page is used for anonymous, file if device page is used for
-file backed page or shmem if device page is used for shared memory). This is a
+file backed page, or shmem if device page is used for shared memory). This is a
 deliberate choice to keep existing applications, that might start using device
 memory without knowing about it, running unimpacted.
 
@@ -439,6 +446,6 @@
 resource control.
 
 
-Note that device memory can never be pinned by device driver nor through GUP
+Note that device memory can never be pinned by a device driver nor through GUP
 and thus such memory is always free upon process exit. Or when last reference
 is dropped in case of shared memory or file backed memory.
diff --git a/drivers/gpu/drm/nouveau/nouveau_svm.c b/drivers/gpu/drm/nouveau/nouveau_svm.c
index 93ed43c..8c92374 100644
--- a/drivers/gpu/drm/nouveau/nouveau_svm.c
+++ b/drivers/gpu/drm/nouveau/nouveau_svm.c
@@ -649,7 +649,7 @@ nouveau_svm_fault(struct nvif_notify *notify)
 		range.values = nouveau_svm_pfn_values;
 		range.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT;
 again:
-		ret = hmm_vma_fault(&range, true);
+		ret = hmm_vma_fault(&svmm->mirror, &range, true);
 		if (ret == 0) {
 			mutex_lock(&svmm->mutex);
 			if (!hmm_vma_range_done(&range)) {
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index 044a36d..0fa8ea3 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -21,8 +21,8 @@
  *
  * HMM address space mirroring API:
  *
- * Use HMM address space mirroring if you want to mirror range of the CPU page
- * table of a process into a device page table. Here, "mirror" means "keep
+ * Use HMM address space mirroring if you want to mirror a range of the CPU
+ * page tables of a process into a device page table. Here, "mirror" means "keep
  * synchronized". Prerequisites: the device must provide the ability to write-
  * protect its page tables (at PAGE_SIZE granularity), and must be able to
  * recover from the resulting potential page faults.
@@ -82,19 +82,18 @@
  * @mirrors_sem: read/write semaphore protecting the mirrors list
  * @wq: wait queue for user waiting on a range invalidation
  * @notifiers: count of active mmu notifiers
- * @dead: is the mm dead ?
  */
 struct hmm {
 	struct mm_struct	*mm;
 	struct kref		kref;
-	struct mutex		lock;
+	spinlock_t		ranges_lock;
 	struct list_head	ranges;
 	struct list_head	mirrors;
 	struct mmu_notifier	mmu_notifier;
 	struct rw_semaphore	mirrors_sem;
 	wait_queue_head_t	wq;
+	struct rcu_head		rcu;
 	long			notifiers;
-	bool			dead;
 };
 
 /*
@@ -105,10 +104,11 @@ struct hmm {
  * HMM_PFN_WRITE: CPU page table has write permission set
  * HMM_PFN_DEVICE_PRIVATE: private device memory (ZONE_DEVICE)
  *
- * The driver provide a flags array, if driver valid bit for an entry is bit
- * 3 ie (entry & (1 << 3)) is true if entry is valid then driver must provide
+ * The driver provides a flags array for mapping page protections to device
+ * PTE bits. If the driver valid bit for an entry is bit 3,
+ * i.e., (entry & (1 << 3)), then the driver must provide
  * an array in hmm_range.flags with hmm_range.flags[HMM_PFN_VALID] == 1 << 3.
- * Same logic apply to all flags. This is same idea as vm_page_prot in vma
+ * Same logic applies to all flags. This is the same idea as vm_page_prot in vma
  * except that this is per device driver rather than per architecture.
  */
 enum hmm_pfn_flag_e {
@@ -129,13 +129,13 @@ enum hmm_pfn_flag_e {
  *      be mirrored by a device, because the entry will never have HMM_PFN_VALID
  *      set and the pfn value is undefined.
  *
- * Driver provide entry value for none entry, error entry and special entry,
- * driver can alias (ie use same value for error and special for instance). It
- * should not alias none and error or special.
+ * Driver provides values for none entry, error entry, and special entry.
+ * Driver can alias (i.e., use same value) error and special, but
+ * it should not alias none with error or special.
  *
  * HMM pfn value returned by hmm_vma_get_pfns() or hmm_vma_fault() will be:
  * hmm_range.values[HMM_PFN_ERROR] if CPU page table entry is poisonous,
- * hmm_range.values[HMM_PFN_NONE] if there is no CPU page table
+ * hmm_range.values[HMM_PFN_NONE] if there is no CPU page table entry,
  * hmm_range.values[HMM_PFN_SPECIAL] if CPU page table entry is a special one
  */
 enum hmm_pfn_value_e {
@@ -158,6 +158,7 @@ enum hmm_pfn_value_e {
  * @values: pfn value for some special case (none, special, error, ...)
  * @default_flags: default flags for the range (write, read, ... see hmm doc)
  * @pfn_flags_mask: allows to mask pfn flags so that only default_flags matter
+ * @page_shift: device virtual address shift value (should be >= PAGE_SHIFT)
  * @pfn_shifts: pfn shift value (should be <= PAGE_SHIFT)
  * @valid: pfns array did not change since it has been fill by an HMM function
  */
@@ -180,7 +181,7 @@ struct hmm_range {
 /*
  * hmm_range_page_shift() - return the page shift for the range
  * @range: range being queried
- * Returns: page shift (page size = 1 << page shift) for the range
+ * Return: page shift (page size = 1 << page shift) for the range
  */
 static inline unsigned hmm_range_page_shift(const struct hmm_range *range)
 {
@@ -190,7 +191,7 @@ static inline unsigned hmm_range_page_shift(const struct hmm_range *range)
 /*
  * hmm_range_page_size() - return the page size for the range
  * @range: range being queried
- * Returns: page size for the range in bytes
+ * Return: page size for the range in bytes
  */
 static inline unsigned long hmm_range_page_size(const struct hmm_range *range)
 {
@@ -201,28 +202,19 @@ static inline unsigned long hmm_range_page_size(const struct hmm_range *range)
  * hmm_range_wait_until_valid() - wait for range to be valid
  * @range: range affected by invalidation to wait on
  * @timeout: time out for wait in ms (ie abort wait after that period of time)
- * Returns: true if the range is valid, false otherwise.
+ * Return: true if the range is valid, false otherwise.
  */
 static inline bool hmm_range_wait_until_valid(struct hmm_range *range,
 					      unsigned long timeout)
 {
-	/* Check if mm is dead ? */
-	if (range->hmm == NULL || range->hmm->dead || range->hmm->mm == NULL) {
-		range->valid = false;
-		return false;
-	}
-	if (range->valid)
-		return true;
-	wait_event_timeout(range->hmm->wq, range->valid || range->hmm->dead,
-			   msecs_to_jiffies(timeout));
-	/* Return current valid status just in case we get lucky */
-	return range->valid;
+	return wait_event_timeout(range->hmm->wq, range->valid,
+				  msecs_to_jiffies(timeout)) != 0;
 }
 
 /*
  * hmm_range_valid() - test if a range is valid or not
  * @range: range
- * Returns: true if the range is valid, false otherwise.
+ * Return: true if the range is valid, false otherwise.
  */
 static inline bool hmm_range_valid(struct hmm_range *range)
 {
@@ -233,7 +225,7 @@ static inline bool hmm_range_valid(struct hmm_range *range)
  * hmm_device_entry_to_page() - return struct page pointed to by a device entry
  * @range: range use to decode device entry value
  * @entry: device entry value to get corresponding struct page from
- * Returns: struct page pointer if entry is a valid, NULL otherwise
+ * Return: struct page pointer if entry is a valid, NULL otherwise
  *
  * If the device entry is valid (ie valid flag set) then return the struct page
  * matching the entry value. Otherwise return NULL.
@@ -256,7 +248,7 @@ static inline struct page *hmm_device_entry_to_page(const struct hmm_range *rang
  * hmm_device_entry_to_pfn() - return pfn value store in a device entry
  * @range: range use to decode device entry value
  * @entry: device entry to extract pfn from
- * Returns: pfn value if device entry is valid, -1UL otherwise
+ * Return: pfn value if device entry is valid, -1UL otherwise
  */
 static inline unsigned long
 hmm_device_entry_to_pfn(const struct hmm_range *range, uint64_t pfn)
@@ -276,7 +268,7 @@ hmm_device_entry_to_pfn(const struct hmm_range *range, uint64_t pfn)
  * hmm_device_entry_from_page() - create a valid device entry for a page
  * @range: range use to encode HMM pfn value
  * @page: page for which to create the device entry
- * Returns: valid device entry for the page
+ * Return: valid device entry for the page
  */
 static inline uint64_t hmm_device_entry_from_page(const struct hmm_range *range,
 						  struct page *page)
@@ -289,7 +281,7 @@ static inline uint64_t hmm_device_entry_from_page(const struct hmm_range *range,
  * hmm_device_entry_from_pfn() - create a valid device entry value from pfn
  * @range: range use to encode HMM pfn value
  * @pfn: pfn value for which to create the device entry
- * Returns: valid device entry for the pfn
+ * Return: valid device entry for the pfn
  */
 static inline uint64_t hmm_device_entry_from_pfn(const struct hmm_range *range,
 						 unsigned long pfn)
@@ -394,7 +386,7 @@ enum hmm_update_event {
 };
 
 /*
- * struct hmm_update - HMM update informations for callback
+ * struct hmm_update - HMM update information for callback
  *
  * @start: virtual start address of the range to update
  * @end: virtual end address of the range to update
@@ -418,17 +410,18 @@ struct hmm_mirror_ops {
 	 *
 	 * @mirror: pointer to struct hmm_mirror
 	 *
-	 * This is called when the mm_struct is being released.
-	 * The callback should make sure no references to the mirror occur
-	 * after the callback returns.
+	 * This is called when the mm_struct is being released.  The callback
+	 * must ensure that all access to any pages obtained from this mirror
+	 * is halted before the callback returns. All future access should
+	 * fault.
 	 */
 	void (*release)(struct hmm_mirror *mirror);
 
 	/* sync_cpu_device_pagetables() - synchronize page tables
 	 *
 	 * @mirror: pointer to struct hmm_mirror
-	 * @update: update informations (see struct hmm_update)
-	 * Returns: -EAGAIN if update.blockable false and callback need to
+	 * @update: update information (see struct hmm_update)
+	 * Return: -EAGAIN if update.blockable is false and callback needs to
 	 *          block, 0 otherwise.
 	 *
 	 * This callback ultimately originates from mmu_notifiers when the CPU
@@ -465,35 +458,10 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm);
 void hmm_mirror_unregister(struct hmm_mirror *mirror);
 
 /*
- * hmm_mirror_mm_is_alive() - test if mm is still alive
- * @mirror: the HMM mm mirror for which we want to lock the mmap_sem
- * Returns: false if the mm is dead, true otherwise
- *
- * This is an optimization it will not accurately always return -EINVAL if the
- * mm is dead ie there can be false negative (process is being kill but HMM is
- * not yet inform of that). It is only intented to be use to optimize out case
- * where driver is about to do something time consuming and it would be better
- * to skip it if the mm is dead.
- */
-static inline bool hmm_mirror_mm_is_alive(struct hmm_mirror *mirror)
-{
-	struct mm_struct *mm;
-
-	if (!mirror || !mirror->hmm)
-		return false;
-	mm = READ_ONCE(mirror->hmm->mm);
-	if (mirror->hmm->dead || !mm)
-		return false;
-
-	return true;
-}
-
-
-/*
  * Please see Documentation/vm/hmm.rst for how to use the range API.
  */
 int hmm_range_register(struct hmm_range *range,
-		       struct mm_struct *mm,
+		       struct hmm_mirror *mirror,
 		       unsigned long start,
 		       unsigned long end,
 		       unsigned page_shift);
@@ -529,7 +497,8 @@ static inline bool hmm_vma_range_done(struct hmm_range *range)
 }
 
 /* This is a temporary helper to avoid merge conflict between trees. */
-static inline int hmm_vma_fault(struct hmm_range *range, bool block)
+static inline int hmm_vma_fault(struct hmm_mirror *mirror,
+				struct hmm_range *range, bool block)
 {
 	long ret;
 
@@ -542,7 +511,7 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
 	range->default_flags = 0;
 	range->pfn_flags_mask = -1UL;
 
-	ret = hmm_range_register(range, range->vma->vm_mm,
+	ret = hmm_range_register(range, mirror,
 				 range->start, range->end,
 				 PAGE_SHIFT);
 	if (ret)
@@ -561,7 +530,7 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
 	ret = hmm_range_fault(range, block);
 	if (ret <= 0) {
 		if (ret == -EBUSY || !ret) {
-			/* Same as above  drop mmap_sem to match old API. */
+			/* Same as above, drop mmap_sem to match old API. */
 			up_read(&range->vma->vm_mm->mmap_sem);
 			ret = -EBUSY;
 		} else if (ret == -EAGAIN)
@@ -573,14 +542,11 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
 }
 
 /* Below are for HMM internal use only! Not to be used by device driver! */
-void hmm_mm_destroy(struct mm_struct *mm);
-
 static inline void hmm_mm_init(struct mm_struct *mm)
 {
 	mm->hmm = NULL;
 }
 #else /* IS_ENABLED(CONFIG_HMM_MIRROR) */
-static inline void hmm_mm_destroy(struct mm_struct *mm) {}
 static inline void hmm_mm_init(struct mm_struct *mm) {}
 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
 
@@ -628,7 +594,7 @@ struct hmm_devmem_ops {
 	 * @page: pointer to struct page backing virtual address (unreliable)
 	 * @flags: FAULT_FLAG_* (see include/linux/mm.h)
 	 * @pmdp: page middle directory
-	 * Returns: VM_FAULT_MINOR/MAJOR on success or one of VM_FAULT_ERROR
+	 * Return: VM_FAULT_MINOR/MAJOR on success or one of VM_FAULT_ERROR
 	 *   on error
 	 *
 	 * The callback occurs whenever there is a CPU page fault or GUP on a
@@ -636,14 +602,14 @@ struct hmm_devmem_ops {
 	 * page back to regular memory (CPU accessible).
 	 *
 	 * The device driver is free to migrate more than one page from the
-	 * fault() callback as an optimization. However if device decide to
-	 * migrate more than one page it must always priotirize the faulting
+	 * fault() callback as an optimization. However if the device decides
+	 * to migrate more than one page it must always prioritize the faulting
 	 * address over the others.
 	 *
-	 * The struct page pointer is only given as an hint to allow quick
+	 * The struct page pointer is only given as a hint to allow quick
 	 * lookup of internal device driver data. A concurrent migration
-	 * might have already free that page and the virtual address might
-	 * not longer be back by it. So it should not be modified by the
+	 * might have already freed that page and the virtual address might
+	 * no longer be backed by it. So it should not be modified by the
 	 * callback.
 	 *
 	 * Note that mmap semaphore is held in read mode at least when this
@@ -670,7 +636,7 @@ struct hmm_devmem_ops {
  * @ref: per CPU refcount
  * @page_fault: callback when CPU fault on an unaddressable device page
  *
- * This an helper structure for device drivers that do not wish to implement
+ * This is a helper structure for device drivers that do not wish to implement
  * the gory details related to hotplugging new memoy and allocating struct
  * pages.
  *
diff --git a/kernel/fork.c b/kernel/fork.c
index 6166790..129de58 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -677,7 +677,6 @@ void __mmdrop(struct mm_struct *mm)
 	WARN_ON_ONCE(mm == current->active_mm);
 	mm_free_pgd(mm);
 	destroy_context(mm);
-	hmm_mm_destroy(mm);
 	mmu_notifier_mm_destroy(mm);
 	check_mm(mm);
 	put_user_ns(mm->user_ns);
diff --git a/mm/hmm.c b/mm/hmm.c
index f702a38..c1bdcef 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -20,6 +20,7 @@
 #include <linux/swapops.h>
 #include <linux/hugetlb.h>
 #include <linux/memremap.h>
+#include <linux/sched/mm.h>
 #include <linux/jump_label.h>
 #include <linux/dma-mapping.h>
 #include <linux/mmu_notifier.h>
@@ -30,16 +31,6 @@
 #if IS_ENABLED(CONFIG_HMM_MIRROR)
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
 
-static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
-{
-	struct hmm *hmm = READ_ONCE(mm->hmm);
-
-	if (hmm && kref_get_unless_zero(&hmm->kref))
-		return hmm;
-
-	return NULL;
-}
-
 /**
  * hmm_get_or_create - register HMM against an mm (HMM internal)
  *
@@ -54,11 +45,16 @@ static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
  */
 static struct hmm *hmm_get_or_create(struct mm_struct *mm)
 {
-	struct hmm *hmm = mm_get_hmm(mm);
-	bool cleanup = false;
+	struct hmm *hmm;
 
-	if (hmm)
-		return hmm;
+	lockdep_assert_held_exclusive(&mm->mmap_sem);
+
+	/* Abuse the page_table_lock to also protect mm->hmm. */
+	spin_lock(&mm->page_table_lock);
+	hmm = mm->hmm;
+	if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
+		goto out_unlock;
+	spin_unlock(&mm->page_table_lock);
 
 	hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
 	if (!hmm)
@@ -68,55 +64,50 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm)
 	init_rwsem(&hmm->mirrors_sem);
 	hmm->mmu_notifier.ops = NULL;
 	INIT_LIST_HEAD(&hmm->ranges);
-	mutex_init(&hmm->lock);
+	spin_lock_init(&hmm->ranges_lock);
 	kref_init(&hmm->kref);
 	hmm->notifiers = 0;
-	hmm->dead = false;
 	hmm->mm = mm;
 
-	spin_lock(&mm->page_table_lock);
-	if (!mm->hmm)
-		mm->hmm = hmm;
-	else
-		cleanup = true;
-	spin_unlock(&mm->page_table_lock);
+	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
+	if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
+		kfree(hmm);
+		return NULL;
+	}
 
-	if (cleanup)
-		goto error;
+	mmgrab(hmm->mm);
 
 	/*
-	 * We should only get here if hold the mmap_sem in write mode ie on
-	 * registration of first mirror through hmm_mirror_register()
+	 * We hold the exclusive mmap_sem here so we know that mm->hmm is
+	 * still NULL or 0 kref, and is safe to update.
 	 */
-	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
-	if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
-		goto error_mm;
-
-	return hmm;
-
-error_mm:
 	spin_lock(&mm->page_table_lock);
-	if (mm->hmm == hmm)
-		mm->hmm = NULL;
+	mm->hmm = hmm;
+
+out_unlock:
 	spin_unlock(&mm->page_table_lock);
-error:
+	return hmm;
+}
+
+static void hmm_free_rcu(struct rcu_head *rcu)
+{
+	struct hmm *hmm = container_of(rcu, struct hmm, rcu);
+
+	mmdrop(hmm->mm);
 	kfree(hmm);
-	return NULL;
 }
 
 static void hmm_free(struct kref *kref)
 {
 	struct hmm *hmm = container_of(kref, struct hmm, kref);
-	struct mm_struct *mm = hmm->mm;
 
-	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
+	spin_lock(&hmm->mm->page_table_lock);
+	if (hmm->mm->hmm == hmm)
+		hmm->mm->hmm = NULL;
+	spin_unlock(&hmm->mm->page_table_lock);
 
-	spin_lock(&mm->page_table_lock);
-	if (mm->hmm == hmm)
-		mm->hmm = NULL;
-	spin_unlock(&mm->page_table_lock);
-
-	kfree(hmm);
+	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
+	mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
 }
 
 static inline void hmm_put(struct hmm *hmm)
@@ -124,126 +115,40 @@ static inline void hmm_put(struct hmm *hmm)
 	kref_put(&hmm->kref, hmm_free);
 }
 
-void hmm_mm_destroy(struct mm_struct *mm)
-{
-	struct hmm *hmm;
-
-	spin_lock(&mm->page_table_lock);
-	hmm = mm_get_hmm(mm);
-	mm->hmm = NULL;
-	if (hmm) {
-		hmm->mm = NULL;
-		hmm->dead = true;
-		spin_unlock(&mm->page_table_lock);
-		hmm_put(hmm);
-		return;
-	}
-
-	spin_unlock(&mm->page_table_lock);
-}
-
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	struct hmm *hmm = mm_get_hmm(mm);
+	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 	struct hmm_mirror *mirror;
-	struct hmm_range *range;
 
-	/* Report this HMM as dying. */
-	hmm->dead = true;
+	/* Bail out if hmm is in the process of being freed */
+	if (!kref_get_unless_zero(&hmm->kref))
+		return;
 
-	/* Wake-up everyone waiting on any range. */
-	mutex_lock(&hmm->lock);
-	list_for_each_entry(range, &hmm->ranges, list) {
-		range->valid = false;
-	}
-	wake_up_all(&hmm->wq);
-	mutex_unlock(&hmm->lock);
+	/*
+	 * Since hmm_range_register() holds an mmget() reference, hmm_release()
+	 * cannot run as long as a range exists.
+	 */
+	WARN_ON(!list_empty_careful(&hmm->ranges));
 
-	down_write(&hmm->mirrors_sem);
-	mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
-					  list);
-	while (mirror) {
-		list_del_init(&mirror->list);
-		if (mirror->ops->release) {
-			/*
-			 * Drop mirrors_sem so callback can wait on any pending
-			 * work that might itself trigger mmu_notifier callback
-			 * and thus would deadlock with us.
-			 */
-			up_write(&hmm->mirrors_sem);
-			mirror->ops->release(mirror);
-			down_write(&hmm->mirrors_sem);
-		}
-		mirror = list_first_entry_or_null(&hmm->mirrors,
-						  struct hmm_mirror, list);
-	}
-	up_write(&hmm->mirrors_sem);
-
-	hmm_put(hmm);
-}
-
-static int hmm_invalidate_range_start(struct mmu_notifier *mn,
-			const struct mmu_notifier_range *nrange)
-{
-	struct hmm *hmm = mm_get_hmm(nrange->mm);
-	struct hmm_mirror *mirror;
-	struct hmm_update update;
-	struct hmm_range *range;
-	int ret = 0;
-
-	VM_BUG_ON(!hmm);
-
-	update.start = nrange->start;
-	update.end = nrange->end;
-	update.event = HMM_UPDATE_INVALIDATE;
-	update.blockable = mmu_notifier_range_blockable(nrange);
-
-	if (mmu_notifier_range_blockable(nrange))
-		mutex_lock(&hmm->lock);
-	else if (!mutex_trylock(&hmm->lock)) {
-		ret = -EAGAIN;
-		goto out;
-	}
-	hmm->notifiers++;
-	list_for_each_entry(range, &hmm->ranges, list) {
-		if (update.end < range->start || update.start >= range->end)
-			continue;
-
-		range->valid = false;
-	}
-	mutex_unlock(&hmm->lock);
-
-	if (mmu_notifier_range_blockable(nrange))
-		down_read(&hmm->mirrors_sem);
-	else if (!down_read_trylock(&hmm->mirrors_sem)) {
-		ret = -EAGAIN;
-		goto out;
-	}
+	down_read(&hmm->mirrors_sem);
 	list_for_each_entry(mirror, &hmm->mirrors, list) {
-		int ret;
-
-		ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
-		if (!update.blockable && ret == -EAGAIN) {
-			up_read(&hmm->mirrors_sem);
-			ret = -EAGAIN;
-			goto out;
-		}
+		/*
+		 * Note: The driver is not allowed to trigger
+		 * hmm_mirror_unregister() from this thread.
+		 */
+		if (mirror->ops->release)
+			mirror->ops->release(mirror);
 	}
 	up_read(&hmm->mirrors_sem);
 
-out:
 	hmm_put(hmm);
-	return ret;
 }
 
-static void hmm_invalidate_range_end(struct mmu_notifier *mn,
-			const struct mmu_notifier_range *nrange)
+static void notifiers_decrement(struct hmm *hmm)
 {
-	struct hmm *hmm = mm_get_hmm(nrange->mm);
+	unsigned long flags;
 
-	VM_BUG_ON(!hmm);
-
-	mutex_lock(&hmm->lock);
+	spin_lock_irqsave(&hmm->ranges_lock, flags);
 	hmm->notifiers--;
 	if (!hmm->notifiers) {
 		struct hmm_range *range;
@@ -255,8 +160,73 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
 		}
 		wake_up_all(&hmm->wq);
 	}
-	mutex_unlock(&hmm->lock);
+	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
+}
 
+static int hmm_invalidate_range_start(struct mmu_notifier *mn,
+			const struct mmu_notifier_range *nrange)
+{
+	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
+	struct hmm_mirror *mirror;
+	struct hmm_update update;
+	struct hmm_range *range;
+	unsigned long flags;
+	int ret = 0;
+
+	if (!kref_get_unless_zero(&hmm->kref))
+		return 0;
+
+	update.start = nrange->start;
+	update.end = nrange->end;
+	update.event = HMM_UPDATE_INVALIDATE;
+	update.blockable = mmu_notifier_range_blockable(nrange);
+
+	spin_lock_irqsave(&hmm->ranges_lock, flags);
+	hmm->notifiers++;
+	list_for_each_entry(range, &hmm->ranges, list) {
+		if (update.end < range->start || update.start >= range->end)
+			continue;
+
+		range->valid = false;
+	}
+	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
+
+	if (mmu_notifier_range_blockable(nrange))
+		down_read(&hmm->mirrors_sem);
+	else if (!down_read_trylock(&hmm->mirrors_sem)) {
+		ret = -EAGAIN;
+		goto out;
+	}
+
+	list_for_each_entry(mirror, &hmm->mirrors, list) {
+		int rc;
+
+		rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
+		if (rc) {
+			if (WARN_ON(update.blockable || rc != -EAGAIN))
+				continue;
+			ret = -EAGAIN;
+			break;
+		}
+	}
+	up_read(&hmm->mirrors_sem);
+
+out:
+	if (ret)
+		notifiers_decrement(hmm);
+	hmm_put(hmm);
+	return ret;
+}
+
+static void hmm_invalidate_range_end(struct mmu_notifier *mn,
+			const struct mmu_notifier_range *nrange)
+{
+	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
+
+	if (!kref_get_unless_zero(&hmm->kref))
+		return;
+
+	notifiers_decrement(hmm);
 	hmm_put(hmm);
 }
 
@@ -271,14 +241,15 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
  *
  * @mirror: new mirror struct to register
  * @mm: mm to register against
+ * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments
  *
  * To start mirroring a process address space, the device driver must register
  * an HMM mirror struct.
- *
- * THE mm->mmap_sem MUST BE HELD IN WRITE MODE !
  */
 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
 {
+	lockdep_assert_held_exclusive(&mm->mmap_sem);
+
 	/* Sanity check */
 	if (!mm || !mirror || !mirror->ops)
 		return -EINVAL;
@@ -298,23 +269,17 @@ EXPORT_SYMBOL(hmm_mirror_register);
 /*
  * hmm_mirror_unregister() - unregister a mirror
  *
- * @mirror: new mirror struct to register
+ * @mirror: mirror struct to unregister
  *
  * Stop mirroring a process address space, and cleanup.
  */
 void hmm_mirror_unregister(struct hmm_mirror *mirror)
 {
-	struct hmm *hmm = READ_ONCE(mirror->hmm);
-
-	if (hmm == NULL)
-		return;
+	struct hmm *hmm = mirror->hmm;
 
 	down_write(&hmm->mirrors_sem);
-	list_del_init(&mirror->list);
-	/* To protect us against double unregister ... */
-	mirror->hmm = NULL;
+	list_del(&mirror->list);
 	up_write(&hmm->mirrors_sem);
-
 	hmm_put(hmm);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
@@ -330,7 +295,7 @@ struct hmm_vma_walk {
 static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 			    bool write_fault, uint64_t *pfn)
 {
-	unsigned int flags = FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_REMOTE;
+	unsigned int flags = FAULT_FLAG_REMOTE;
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
 	struct vm_area_struct *vma = walk->vma;
@@ -372,7 +337,7 @@ static int hmm_pfns_bad(unsigned long addr,
  * @fault: should we fault or not ?
  * @write_fault: write fault ?
  * @walk: mm_walk structure
- * Returns: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: 0 on success, -EBUSY after page fault, or page fault error
  *
  * This function will be called whenever pmd_none() or pte_none() returns true,
  * or whenever there is no page directory covering the virtual address range.
@@ -550,7 +515,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk,
 
 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
 {
-	if (pte_none(pte) || !pte_present(pte))
+	if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
 		return 0;
 	return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
 				range->flags[HMM_PFN_WRITE] :
@@ -788,7 +753,6 @@ static int hmm_vma_walk_pud(pud_t *pudp,
 			return hmm_vma_walk_hole_(addr, end, fault,
 						write_fault, walk);
 
-#ifdef CONFIG_HUGETLB_PAGE
 		pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
 		for (i = 0; i < npages; ++i, ++pfn) {
 			hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
@@ -804,9 +768,6 @@ static int hmm_vma_walk_pud(pud_t *pudp,
 		}
 		hmm_vma_walk->last = end;
 		return 0;
-#else
-		return -EINVAL;
-#endif
 	}
 
 	split_huge_pud(walk->vma, pudp, addr);
@@ -909,12 +870,14 @@ static void hmm_pfns_clear(struct hmm_range *range,
  * Track updates to the CPU page table see include/linux/hmm.h
  */
 int hmm_range_register(struct hmm_range *range,
-		       struct mm_struct *mm,
+		       struct hmm_mirror *mirror,
 		       unsigned long start,
 		       unsigned long end,
 		       unsigned page_shift)
 {
 	unsigned long mask = ((1UL << page_shift) - 1UL);
+	struct hmm *hmm = mirror->hmm;
+	unsigned long flags;
 
 	range->valid = false;
 	range->hmm = NULL;
@@ -928,28 +891,24 @@ int hmm_range_register(struct hmm_range *range,
 	range->start = start;
 	range->end = end;
 
-	range->hmm = hmm_get_or_create(mm);
-	if (!range->hmm)
+	/* Prevent hmm_release() from running while the range is valid */
+	if (!mmget_not_zero(hmm->mm))
 		return -EFAULT;
 
-	/* Check if hmm_mm_destroy() was call. */
-	if (range->hmm->mm == NULL || range->hmm->dead) {
-		hmm_put(range->hmm);
-		return -EFAULT;
-	}
+	/* Initialize range to track CPU page table updates. */
+	spin_lock_irqsave(&hmm->ranges_lock, flags);
 
-	/* Initialize range to track CPU page table update */
-	mutex_lock(&range->hmm->lock);
-
-	list_add_rcu(&range->list, &range->hmm->ranges);
+	range->hmm = hmm;
+	kref_get(&hmm->kref);
+	list_add(&range->list, &hmm->ranges);
 
 	/*
 	 * If there are any concurrent notifiers we have to wait for them for
 	 * the range to be valid (see hmm_range_wait_until_valid()).
 	 */
-	if (!range->hmm->notifiers)
+	if (!hmm->notifiers)
 		range->valid = true;
-	mutex_unlock(&range->hmm->lock);
+	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
 
 	return 0;
 }
@@ -964,25 +923,31 @@ EXPORT_SYMBOL(hmm_range_register);
  */
 void hmm_range_unregister(struct hmm_range *range)
 {
-	/* Sanity check this really should not happen. */
-	if (range->hmm == NULL || range->end <= range->start)
-		return;
+	struct hmm *hmm = range->hmm;
+	unsigned long flags;
 
-	mutex_lock(&range->hmm->lock);
-	list_del_rcu(&range->list);
-	mutex_unlock(&range->hmm->lock);
+	spin_lock_irqsave(&hmm->ranges_lock, flags);
+	list_del_init(&range->list);
+	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
 
 	/* Drop reference taken by hmm_range_register() */
+	mmput(hmm->mm);
+	hmm_put(hmm);
+
+	/*
+	 * The range is now invalid and the ref on the hmm is dropped, so
+	 * poison the pointer.  Leave other fields in place, for the caller's
+	 * use.
+	 */
 	range->valid = false;
-	hmm_put(range->hmm);
-	range->hmm = NULL;
+	memset(&range->hmm, POISON_INUSE, sizeof(range->hmm));
 }
 EXPORT_SYMBOL(hmm_range_unregister);
 
 /*
  * hmm_range_snapshot() - snapshot CPU page table for a range
  * @range: range
- * Returns: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
+ * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
  *          permission (for instance asking for write and range is read only),
  *          -EAGAIN if you need to retry, -EFAULT invalid (ie either no valid
  *          vma or it is illegal to access that range), number of valid pages
@@ -1001,10 +966,7 @@ long hmm_range_snapshot(struct hmm_range *range)
 	struct vm_area_struct *vma;
 	struct mm_walk mm_walk;
 
-	/* Check if hmm_mm_destroy() was call. */
-	if (hmm->mm == NULL || hmm->dead)
-		return -EFAULT;
-
+	lockdep_assert_held(&hmm->mm->mmap_sem);
 	do {
 		/* If range is no longer valid force retry. */
 		if (!range->valid)
@@ -1015,9 +977,8 @@ long hmm_range_snapshot(struct hmm_range *range)
 			return -EFAULT;
 
 		if (is_vm_hugetlb_page(vma)) {
-			struct hstate *h = hstate_vma(vma);
-
-			if (huge_page_shift(h) != range->page_shift &&
+			if (huge_page_shift(hstate_vma(vma)) !=
+				    range->page_shift &&
 			    range->page_shift != PAGE_SHIFT)
 				return -EINVAL;
 		} else {
@@ -1066,7 +1027,7 @@ EXPORT_SYMBOL(hmm_range_snapshot);
  * hmm_range_fault() - try to fault some address in a virtual address range
  * @range: range being faulted
  * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
- * Returns: number of valid pages in range->pfns[] (from range start
+ * Return: number of valid pages in range->pfns[] (from range start
  *          address). This may be zero. If the return value is negative,
  *          then one of the following values may be returned:
  *
@@ -1100,9 +1061,7 @@ long hmm_range_fault(struct hmm_range *range, bool block)
 	struct mm_walk mm_walk;
 	int ret;
 
-	/* Check if hmm_mm_destroy() was call. */
-	if (hmm->mm == NULL || hmm->dead)
-		return -EFAULT;
+	lockdep_assert_held(&hmm->mm->mmap_sem);
 
 	do {
 		/* If range is no longer valid force retry. */
@@ -1184,7 +1143,7 @@ EXPORT_SYMBOL(hmm_range_fault);
  * @device: device against to dma map page to
  * @daddrs: dma address of mapped pages
  * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
- * Returns: number of pages mapped on success, -EAGAIN if mmap_sem have been
+ * Return: number of pages mapped on success, -EAGAIN if mmap_sem have been
  *          drop and you need to try again, some other error value otherwise
  *
  * Note same usage pattern as hmm_range_fault().
@@ -1272,7 +1231,7 @@ EXPORT_SYMBOL(hmm_range_dma_map);
  * @device: device against which dma map was done
  * @daddrs: dma address of mapped pages
  * @dirty: dirty page if it had the write flag set
- * Returns: number of page unmapped on success, -EINVAL otherwise
+ * Return: number of pages unmapped on success, -EINVAL otherwise
  *
  * Note that caller MUST abide by mmu notifier or use HMM mirror and abide
  * to the sync_cpu_device_pagetables() callback so that it is safe here to
@@ -1394,7 +1353,7 @@ static void hmm_devmem_free(struct page *page, void *data)
  * @ops: memory event device driver callback (see struct hmm_devmem_ops)
  * @device: device struct to bind the resource too
  * @size: size in bytes of the device memory to add
- * Returns: pointer to new hmm_devmem struct ERR_PTR otherwise
+ * Return: pointer to new hmm_devmem struct, ERR_PTR otherwise
  *
  * This function first finds an empty range of physical address big enough to
  * contain the new resource, and then hotplugs it as ZONE_DEVICE memory, which
diff --git a/mm/swap.c b/mm/swap.c
index 7ede3ed..607c482 100644
--- a/mm/swap.c
+++ b/mm/swap.c
@@ -740,15 +740,20 @@ void release_pages(struct page **pages, int nr)
 		if (is_huge_zero_page(page))
 			continue;
 
-		/* Device public page can not be huge page */
-		if (is_device_public_page(page)) {
+		if (is_zone_device_page(page)) {
 			if (locked_pgdat) {
 				spin_unlock_irqrestore(&locked_pgdat->lru_lock,
 						       flags);
 				locked_pgdat = NULL;
 			}
-			put_devmap_managed_page(page);
-			continue;
+			/*
+			 * ZONE_DEVICE pages that return 'false' from
+			 * put_devmap_managed_page() do not require special
+			 * processing, and instead, expect a call to
+			 * put_page_testzero().
+			 */
+			if (put_devmap_managed_page(page))
+				continue;
 		}
 
 		page = compound_head(page);