x86/boot/64: Add support for an additional page table level during early boot

This patch adds support for 5-level paging during early boot.
It generalizes early boot to handle both 4- and 5-level paging on 64-bit
systems, with a compile-time switch between them.

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Cc: Andrew Morton <akpm@linux-foundation.org>
Cc: Andy Lutomirski <luto@amacapital.net>
Cc: Andy Lutomirski <luto@kernel.org>
Cc: Borislav Petkov <bp@alien8.de>
Cc: Brian Gerst <brgerst@gmail.com>
Cc: Dave Hansen <dave.hansen@intel.com>
Cc: Denys Vlasenko <dvlasenk@redhat.com>
Cc: H. Peter Anvin <hpa@zytor.com>
Cc: Josh Poimboeuf <jpoimboe@redhat.com>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: linux-arch@vger.kernel.org
Cc: linux-mm@kvack.org
Link: http://lkml.kernel.org/r/20170606113133.22974-10-kirill.shutemov@linux.intel.com
Signed-off-by: Ingo Molnar <mingo@kernel.org>
diff --git a/arch/x86/kernel/head64.c b/arch/x86/kernel/head64.c
index 71ca01b..2b2ac38 100644
--- a/arch/x86/kernel/head64.c
+++ b/arch/x86/kernel/head64.c
@@ -47,6 +47,7 @@ void __init __startup_64(unsigned long physaddr)
 {
 	unsigned long load_delta, *p;
 	pgdval_t *pgd;
+	p4dval_t *p4d;
 	pudval_t *pud;
 	pmdval_t *pmd, pmd_entry;
 	int i;
@@ -70,6 +71,11 @@ void __init __startup_64(unsigned long physaddr)
 	pgd = fixup_pointer(&early_top_pgt, physaddr);
 	pgd[pgd_index(__START_KERNEL_map)] += load_delta;
 
+	if (IS_ENABLED(CONFIG_X86_5LEVEL)) {
+		p4d = fixup_pointer(&level4_kernel_pgt, physaddr);
+		p4d[511] += load_delta;
+	}
+
 	pud = fixup_pointer(&level3_kernel_pgt, physaddr);
 	pud[510] += load_delta;
 	pud[511] += load_delta;
@@ -87,9 +93,21 @@ void __init __startup_64(unsigned long physaddr)
 	pud = fixup_pointer(early_dynamic_pgts[next_early_pgt++], physaddr);
 	pmd = fixup_pointer(early_dynamic_pgts[next_early_pgt++], physaddr);
 
-	i = (physaddr >> PGDIR_SHIFT) % PTRS_PER_PGD;
-	pgd[i + 0] = (pgdval_t)pud + _KERNPG_TABLE;
-	pgd[i + 1] = (pgdval_t)pud + _KERNPG_TABLE;
+	if (IS_ENABLED(CONFIG_X86_5LEVEL)) {
+		p4d = fixup_pointer(early_dynamic_pgts[next_early_pgt++], physaddr);
+
+		i = (physaddr >> PGDIR_SHIFT) % PTRS_PER_PGD;
+		pgd[i + 0] = (pgdval_t)p4d + _KERNPG_TABLE;
+		pgd[i + 1] = (pgdval_t)p4d + _KERNPG_TABLE;
+
+		i = (physaddr >> P4D_SHIFT) % PTRS_PER_P4D;
+		p4d[i + 0] = (pgdval_t)pud + _KERNPG_TABLE;
+		p4d[i + 1] = (pgdval_t)pud + _KERNPG_TABLE;
+	} else {
+		i = (physaddr >> PGDIR_SHIFT) % PTRS_PER_PGD;
+		pgd[i + 0] = (pgdval_t)pud + _KERNPG_TABLE;
+		pgd[i + 1] = (pgdval_t)pud + _KERNPG_TABLE;
+	}
 
 	i = (physaddr >> PUD_SHIFT) % PTRS_PER_PUD;
 	pud[i + 0] = (pudval_t)pmd + _KERNPG_TABLE;
@@ -134,6 +152,7 @@ int __init early_make_pgtable(unsigned long address)
 {
 	unsigned long physaddr = address - __PAGE_OFFSET;
 	pgdval_t pgd, *pgd_p;
+	p4dval_t p4d, *p4d_p;
 	pudval_t pud, *pud_p;
 	pmdval_t pmd, *pmd_p;
 
@@ -150,8 +169,25 @@ int __init early_make_pgtable(unsigned long address)
 	 * critical -- __PAGE_OFFSET would point us back into the dynamic
 	 * range and we might end up looping forever...
 	 */
-	if (pgd)
-		pud_p = (pudval_t *)((pgd & PTE_PFN_MASK) + __START_KERNEL_map - phys_base);
+	if (!IS_ENABLED(CONFIG_X86_5LEVEL))
+		p4d_p = pgd_p;
+	else if (pgd)
+		p4d_p = (p4dval_t *)((pgd & PTE_PFN_MASK) + __START_KERNEL_map - phys_base);
+	else {
+		if (next_early_pgt >= EARLY_DYNAMIC_PAGE_TABLES) {
+			reset_early_page_tables();
+			goto again;
+		}
+
+		p4d_p = (p4dval_t *)early_dynamic_pgts[next_early_pgt++];
+		memset(p4d_p, 0, sizeof(*p4d_p) * PTRS_PER_P4D);
+		*pgd_p = (pgdval_t)p4d_p - __START_KERNEL_map + phys_base + _KERNPG_TABLE;
+	}
+	p4d_p += p4d_index(address);
+	p4d = *p4d_p;
+
+	if (p4d)
+		pud_p = (pudval_t *)((p4d & PTE_PFN_MASK) + __START_KERNEL_map - phys_base);
 	else {
 		if (next_early_pgt >= EARLY_DYNAMIC_PAGE_TABLES) {
 			reset_early_page_tables();
@@ -160,7 +196,7 @@ int __init early_make_pgtable(unsigned long address)
 
 		pud_p = (pudval_t *)early_dynamic_pgts[next_early_pgt++];
 		memset(pud_p, 0, sizeof(*pud_p) * PTRS_PER_PUD);
-		*pgd_p = (pgdval_t)pud_p - __START_KERNEL_map + phys_base + _KERNPG_TABLE;
+		*p4d_p = (p4dval_t)pud_p - __START_KERNEL_map + phys_base + _KERNPG_TABLE;
 	}
 	pud_p += pud_index(address);
 	pud = *pud_p;