KVM: arm64: Ensure that SME is trapped if not supported in pKVM

Do not rely on trapping being set only by the host for SME when
running in protected mode.

Signed-off-by: Fuad Tabba <tabba@google.com>
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-main.c b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
index 1fd419c..b34d726 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-main.c
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
@@ -1193,6 +1193,7 @@ void handle_trap(struct kvm_cpu_context *host_ctxt)
 		break;
 	case ESR_ELx_EC_FP_ASIMD:
 	case ESR_ELx_EC_SVE:
+	case ESR_ELx_EC_SME:
 		fpsimd_host_restore();
 		break;
 	case ESR_ELx_EC_IABT_LOW:
diff --git a/arch/arm64/kvm/hyp/nvhe/pkvm.c b/arch/arm64/kvm/hyp/nvhe/pkvm.c
index 199ad51..c2edb3e 100644
--- a/arch/arm64/kvm/hyp/nvhe/pkvm.c
+++ b/arch/arm64/kvm/hyp/nvhe/pkvm.c
@@ -102,6 +102,8 @@ static void pvm_init_traps_aa64pfr1(struct kvm_vcpu *vcpu)
 	const u64 feature_ids = pvm_read_id_reg(vcpu, SYS_ID_AA64PFR1_EL1);
 	u64 hcr_set = 0;
 	u64 hcr_clear = 0;
+	u64 cptr_set = 0;
+	u64 cptr_clear = 0;
 
 	/* Memory Tagging: Trap and Treat as Untagged if not supported. */
 	if (!FIELD_GET(ARM64_FEATURE_MASK(ID_AA64PFR1_EL1_MTE), feature_ids)) {
@@ -109,8 +111,17 @@ static void pvm_init_traps_aa64pfr1(struct kvm_vcpu *vcpu)
 		hcr_clear |= HCR_DCT | HCR_ATA;
 	}
 
+	/* No SME supprot in KVM. */
+	BUG_ON(FIELD_GET(ARM64_FEATURE_MASK(ID_AA64PFR1_EL1_SME), feature_ids));
+	if (has_hvhe())
+		cptr_clear |= CPACR_EL1_SMEN_EL1EN | CPACR_EL1_SMEN_EL0EN;
+	else
+		cptr_set |= CPTR_EL2_TSM;
+
 	vcpu->arch.hcr_el2 |= hcr_set;
 	vcpu->arch.hcr_el2 &= ~hcr_clear;
+	vcpu->arch.cptr_el2 |= cptr_set;
+	vcpu->arch.cptr_el2 &= ~cptr_clear;
 }
 
 /*
diff --git a/arch/arm64/kvm/hyp/nvhe/switch.c b/arch/arm64/kvm/hyp/nvhe/switch.c
index 31c4649..36b311f 100644
--- a/arch/arm64/kvm/hyp/nvhe/switch.c
+++ b/arch/arm64/kvm/hyp/nvhe/switch.c
@@ -200,6 +200,7 @@ static const exit_handler_fn pvm_exit_handlers[] = {
 	[ESR_ELx_EC_HVC64]		= kvm_handle_pvm_hvc64,
 	[ESR_ELx_EC_SYS64]		= kvm_handle_pvm_sys64,
 	[ESR_ELx_EC_SVE]		= kvm_handle_pvm_restricted,
+	[ESR_ELx_EC_SME]		= kvm_handle_pvm_restricted,
 	[ESR_ELx_EC_FP_ASIMD]		= kvm_hyp_handle_fpsimd,
 	[ESR_ELx_EC_IABT_LOW]		= kvm_hyp_handle_iabt_low,
 	[ESR_ELx_EC_DABT_LOW]		= kvm_hyp_handle_dabt_low,