KVM: MMU: Do not unconditionally read PDPTE from guest memory

Architecturally, PDPTEs are cached in the PDPTRs when CR3 is reloaded.
On SVM, it is not possible to implement this, but on VMX this is possible
and was indeed implemented until nested SVM changed this to unconditionally
read PDPTEs dynamically.  This has a noticeable impact when running PAE guests.

Fix by changing the MMU to read PDPTRs from the cache, falling back to
reading from memory for the nested MMU.

Signed-off-by: Avi Kivity <avi@redhat.com>
Tested-by: Joerg Roedel <joerg.roedel@amd.com>
Signed-off-by: Marcelo Tosatti <mtosatti@redhat.com>
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 307e3cf..b31a341 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -265,6 +265,7 @@
 	void (*new_cr3)(struct kvm_vcpu *vcpu);
 	void (*set_cr3)(struct kvm_vcpu *vcpu, unsigned long root);
 	unsigned long (*get_cr3)(struct kvm_vcpu *vcpu);
+	u64 (*get_pdptr)(struct kvm_vcpu *vcpu, int index);
 	int (*page_fault)(struct kvm_vcpu *vcpu, gva_t gva, u32 err,
 			  bool prefault);
 	void (*inject_page_fault)(struct kvm_vcpu *vcpu,
diff --git a/arch/x86/kvm/kvm_cache_regs.h b/arch/x86/kvm/kvm_cache_regs.h
index 3377d53..544076c 100644
--- a/arch/x86/kvm/kvm_cache_regs.h
+++ b/arch/x86/kvm/kvm_cache_regs.h
@@ -45,13 +45,6 @@
 	return vcpu->arch.walk_mmu->pdptrs[index];
 }
 
-static inline u64 kvm_pdptr_read_mmu(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu, int index)
-{
-	load_pdptrs(vcpu, mmu, mmu->get_cr3(vcpu));
-
-	return mmu->pdptrs[index];
-}
-
 static inline ulong kvm_read_cr0_bits(struct kvm_vcpu *vcpu, ulong mask)
 {
 	ulong tmask = mask & KVM_POSSIBLE_CR0_GUEST_BITS;
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 8e8da79..f1b36cf 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -2770,7 +2770,7 @@
 
 		ASSERT(!VALID_PAGE(root));
 		if (vcpu->arch.mmu.root_level == PT32E_ROOT_LEVEL) {
-			pdptr = kvm_pdptr_read_mmu(vcpu, &vcpu->arch.mmu, i);
+			pdptr = vcpu->arch.mmu.get_pdptr(vcpu, i);
 			if (!is_present_gpte(pdptr)) {
 				vcpu->arch.mmu.pae_root[i] = 0;
 				continue;
@@ -3318,6 +3318,7 @@
 	context->direct_map = true;
 	context->set_cr3 = kvm_x86_ops->set_tdp_cr3;
 	context->get_cr3 = get_cr3;
+	context->get_pdptr = kvm_pdptr_read;
 	context->inject_page_fault = kvm_inject_page_fault;
 	context->nx = is_nx(vcpu);
 
@@ -3376,6 +3377,7 @@
 
 	vcpu->arch.walk_mmu->set_cr3           = kvm_x86_ops->set_cr3;
 	vcpu->arch.walk_mmu->get_cr3           = get_cr3;
+	vcpu->arch.walk_mmu->get_pdptr         = kvm_pdptr_read;
 	vcpu->arch.walk_mmu->inject_page_fault = kvm_inject_page_fault;
 
 	return r;
@@ -3386,6 +3388,7 @@
 	struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
 
 	g_context->get_cr3           = get_cr3;
+	g_context->get_pdptr         = kvm_pdptr_read;
 	g_context->inject_page_fault = kvm_inject_page_fault;
 
 	/*
diff --git a/arch/x86/kvm/paging_tmpl.h b/arch/x86/kvm/paging_tmpl.h
index 507e2b8..f6dd9fe 100644
--- a/arch/x86/kvm/paging_tmpl.h
+++ b/arch/x86/kvm/paging_tmpl.h
@@ -163,7 +163,7 @@
 
 #if PTTYPE == 64
 	if (walker->level == PT32E_ROOT_LEVEL) {
-		pte = kvm_pdptr_read_mmu(vcpu, mmu, (addr >> 30) & 3);
+		pte = mmu->get_pdptr(vcpu, (addr >> 30) & 3);
 		trace_kvm_mmu_paging_element(pte, walker->level);
 		if (!is_present_gpte(pte))
 			goto error;
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index 2b24a88..f043168 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -1844,6 +1844,20 @@
 	return svm->nested.nested_cr3;
 }
 
+static u64 nested_svm_get_tdp_pdptr(struct kvm_vcpu *vcpu, int index)
+{	/* Fetch PDPTE number 'index' from guest memory, located via nested_cr3. */
+	struct vcpu_svm *svm = to_svm(vcpu);
+	u64 cr3 = svm->nested.nested_cr3;
+	u64 pdpte;
+	int ret;
+
+	ret = kvm_read_guest_page(vcpu->kvm, gpa_to_gfn(cr3), &pdpte,
+				  offset_in_page(cr3) + index * 8, 8); /* 8 bytes per entry */
+	if (ret)
+		return 0; /* non-zero ret: hand back 0; presumably seen as not-present by callers */
+	return pdpte;
+}
+
 static void nested_svm_set_tdp_cr3(struct kvm_vcpu *vcpu,
 				   unsigned long root)
 {
@@ -1875,6 +1889,7 @@
 
 	vcpu->arch.mmu.set_cr3           = nested_svm_set_tdp_cr3;
 	vcpu->arch.mmu.get_cr3           = nested_svm_get_tdp_cr3;
+	vcpu->arch.mmu.get_pdptr         = nested_svm_get_tdp_pdptr;
 	vcpu->arch.mmu.inject_page_fault = nested_svm_inject_npf_exit;
 	vcpu->arch.mmu.shadow_root_level = get_npt_level();
 	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;