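x86: cpa: fix the self-test

Rework __change_page_attr() and change_page_attr_addr() to take a pair
of set/clear masks instead of a complete pgprot_t. The masks are now
applied to the current protection bits of the 4k pte itself (after a
large page has been split), static_protections() is applied to the
result, and we BUG() if the pfn of the new pte differs from the old one.
__change_page_attr_set_clr() simply passes the masks through.

Also:
 - compute the pfn only after a kernel-text address has been translated
   to its direct-mapping address, and always update the direct mapping,
 - include phys_base when computing the high kernel-text alias, always
   clear NX there, and do not propagate failures on that (imprecise)
   alias back to the caller,
 - clear NX in the reference protections used to populate the new page
   table when splitting a large page,
 - add __GFP_NOWARN to the CONFIG_DEBUG_PAGEALLOC allocation,
 - use unsigned loop counters.
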
Signed-off-by: Ingo Molnar <mingo@elte.hu>
Signed-off-by: Thomas Gleixner <tglx@linutronix.de>
diff --git a/arch/x86/mm/pageattr.c b/arch/x86/mm/pageattr.c
index 97ec9e7..532a40b 100644
--- a/arch/x86/mm/pageattr.c
+++ b/arch/x86/mm/pageattr.c
@@ -197,10 +197,11 @@
 	unsigned long addr;
 	pte_t *pbase, *tmp;
 	struct page *base;
-	int i, level;
+	unsigned int i, level;
 
 #ifdef CONFIG_DEBUG_PAGEALLOC
-	gfp_flags = GFP_ATOMIC;
+	/* Allocation failure is checked below, so don't warn about it: */
+	gfp_flags = GFP_ATOMIC | __GFP_NOWARN;
 #endif
 	base = alloc_pages(gfp_flags, 0);
 	if (!base)
@@ -224,6 +225,7 @@
 	paravirt_alloc_pt(&init_mm, page_to_pfn(base));
 #endif
 
+	pgprot_val(ref_prot) &= ~_PAGE_NX;
 	for (i = 0; i < PTRS_PER_PTE; i++, addr += PAGE_SIZE)
 		set_pte(&pbase[i], pfn_pte(addr >> PAGE_SHIFT, ref_prot));
 
@@ -248,7 +250,8 @@
 }
 
 static int
-__change_page_attr(unsigned long address, unsigned long pfn, pgprot_t prot)
+__change_page_attr(unsigned long address, unsigned long pfn,
+		   pgprot_t mask_set, pgprot_t mask_clr)
 {
 	struct page *kpte_page;
 	int level, err = 0;
@@ -267,15 +270,20 @@
 	BUG_ON(PageLRU(kpte_page));
 	BUG_ON(PageCompound(kpte_page));
 
-	prot = static_protections(prot, address);
-
 	if (level == PG_LEVEL_4K) {
-		WARN_ON_ONCE(pgprot_val(prot) & _PAGE_PSE);
-		set_pte_atomic(kpte, pfn_pte(pfn, canon_pgprot(prot)));
-	} else {
-		/* Clear the PSE bit for the 4k level pages ! */
-		pgprot_val(prot) = pgprot_val(prot) & ~_PAGE_PSE;
+		pgprot_t new_prot = pte_pgprot(*kpte);
+		pte_t new_pte, old_pte = *kpte;
 
+		pgprot_val(new_prot) &= ~pgprot_val(mask_clr);
+		pgprot_val(new_prot) |= pgprot_val(mask_set);
+
+		new_prot = static_protections(new_prot, address);
+
+		new_pte = pfn_pte(pfn, canon_pgprot(new_prot));
+		BUG_ON(pte_pfn(new_pte) != pte_pfn(old_pte));
+
+		set_pte_atomic(kpte, new_pte);
+	} else {
 		err = split_large_page(kpte, address);
 		if (!err)
 			goto repeat;
@@ -297,22 +305,26 @@
  * Modules and drivers should use the set_memory_* APIs instead.
  */
 
-static int change_page_attr_addr(unsigned long address, pgprot_t prot)
+static int
+change_page_attr_addr(unsigned long address, pgprot_t mask_set,
+							pgprot_t mask_clr)
 {
 	int err = 0, kernel_map = 0;
-	unsigned long pfn = __pa(address) >> PAGE_SHIFT;
+	unsigned long pfn;
 
 #ifdef CONFIG_X86_64
 	if (address >= __START_KERNEL_map &&
 			address < __START_KERNEL_map + KERNEL_TEXT_SIZE) {
 
-		address = (unsigned long)__va(__pa(address));
+		address = (unsigned long)__va(__pa((void *)address));
 		kernel_map = 1;
 	}
 #endif
 
-	if (!kernel_map || pte_present(pfn_pte(0, prot))) {
-		err = __change_page_attr(address, pfn, prot);
+	pfn = __pa(address) >> PAGE_SHIFT;
+
+	if (!kernel_map || 1) {
+		err = __change_page_attr(address, pfn, mask_set, mask_clr);
 		if (err)
 			return err;
 	}
@@ -324,12 +336,15 @@
 	 */
 	if (__pa(address) < KERNEL_TEXT_SIZE) {
 		unsigned long addr2;
-		pgprot_t prot2;
 
-		addr2 = __START_KERNEL_map + __pa(address);
+		addr2 = __pa(address) + __START_KERNEL_map - phys_base;
 		/* Make sure the kernel mappings stay executable */
-		prot2 = pte_pgprot(pte_mkexec(pfn_pte(0, prot)));
-		err = __change_page_attr(addr2, pfn, prot2);
+		pgprot_val(mask_clr) |= _PAGE_NX;
+		/*
+		 * Our high aliases are imprecise, so do not propagate
+		 * failures back to users:
+		 */
+		__change_page_attr(addr2, pfn, mask_set, mask_clr);
 	}
 #endif
 
@@ -339,26 +354,13 @@
 static int __change_page_attr_set_clr(unsigned long addr, int numpages,
 				      pgprot_t mask_set, pgprot_t mask_clr)
 {
-	pgprot_t new_prot;
-	int level;
-	pte_t *pte;
-	int i, ret;
+	unsigned int i;
+	int ret;
 
-	for (i = 0; i < numpages ; i++) {
-
-		pte = lookup_address(addr, &level);
-		if (!pte)
-			return -EINVAL;
-
-		new_prot = pte_pgprot(*pte);
-
-		pgprot_val(new_prot) &= ~pgprot_val(mask_clr);
-		pgprot_val(new_prot) |= pgprot_val(mask_set);
-
-		ret = change_page_attr_addr(addr, new_prot);
+	for (i = 0; i < numpages ; i++, addr += PAGE_SIZE) {
+		ret = change_page_attr_addr(addr, mask_set, mask_clr);
 		if (ret)
 			return ret;
-		addr += PAGE_SIZE;
 	}
 
 	return 0;