diff --git a/Documentation/virtual/kvm/cpuid.txt b/Documentation/virtual/kvm/cpuid.txt
index dcab6dc11e3b08117456f10903ad3ac2fa29eb99..87a7506f31c2b7f9b03be131ba3b3ef1c85f360b 100644
--- a/Documentation/virtual/kvm/cpuid.txt
+++ b/Documentation/virtual/kvm/cpuid.txt
@@ -58,6 +58,10 @@ KVM_FEATURE_PV_TLB_FLUSH           ||     9 || guest checks this feature bit
                                    ||       || before enabling paravirtualized
                                    ||       || tlb flush.
 ------------------------------------------------------------------------------
+KVM_FEATURE_ASYNC_PF_VMEXIT        ||    10 || paravirtualized async PF VM exit
+                                   ||       || can be enabled by setting bit 2
+                                   ||       || when writing to msr 0x4b564d02
+------------------------------------------------------------------------------
 KVM_FEATURE_CLOCKSOURCE_STABLE_BIT ||    24 || host will warn if no guest-side
                                    ||       || per-cpu warps are expected in
                                    ||       || kvmclock.
diff --git a/Documentation/virtual/kvm/msr.txt b/Documentation/virtual/kvm/msr.txt
index 1ebecc115dc6efdbbfa76eb4ab329f5b1d8b765e..f3f0d57ced8e1827fe8e8e6dc49808b18c9253fb 100644
--- a/Documentation/virtual/kvm/msr.txt
+++ b/Documentation/virtual/kvm/msr.txt
@@ -170,7 +170,8 @@ MSR_KVM_ASYNC_PF_EN: 0x4b564d02
 	when asynchronous page faults are enabled on the vcpu 0 when
 	disabled. Bit 1 is 1 if asynchronous page faults can be injected
 	when vcpu is in cpl == 0. Bit 2 is 1 if asynchronous page faults
-	are delivered to L1 as #PF vmexits.
+	are delivered to L1 as #PF vmexits.  Bit 2 can be set only if
+	KVM_FEATURE_ASYNC_PF_VMEXIT is present in CPUID.
 
 	First 4 byte of 64 byte memory location will be written to by
 	the hypervisor at the time of asynchronous page fault (APF)
diff --git a/arch/arm/kvm/hyp/Makefile b/arch/arm/kvm/hyp/Makefile
index 5638ce0c95241f7f3afd3e994aaa9993347c9be7..63d6b404d88e39bf581c0434aff5d6b0e951279f 100644
--- a/arch/arm/kvm/hyp/Makefile
+++ b/arch/arm/kvm/hyp/Makefile
@@ -7,6 +7,8 @@ ccflags-y += -fno-stack-protector -DDISABLE_BRANCH_PROFILING
 
 KVM=../../../../virt/kvm
 
+CFLAGS_ARMV7VE		   :=$(call cc-option, -march=armv7ve)
+
 obj-$(CONFIG_KVM_ARM_HOST) += $(KVM)/arm/hyp/vgic-v2-sr.o
 obj-$(CONFIG_KVM_ARM_HOST) += $(KVM)/arm/hyp/vgic-v3-sr.o
 obj-$(CONFIG_KVM_ARM_HOST) += $(KVM)/arm/hyp/timer-sr.o
@@ -15,7 +17,10 @@ obj-$(CONFIG_KVM_ARM_HOST) += tlb.o
 obj-$(CONFIG_KVM_ARM_HOST) += cp15-sr.o
 obj-$(CONFIG_KVM_ARM_HOST) += vfp.o
 obj-$(CONFIG_KVM_ARM_HOST) += banked-sr.o
+CFLAGS_banked-sr.o	   += $(CFLAGS_ARMV7VE)
+
 obj-$(CONFIG_KVM_ARM_HOST) += entry.o
 obj-$(CONFIG_KVM_ARM_HOST) += hyp-entry.o
 obj-$(CONFIG_KVM_ARM_HOST) += switch.o
+CFLAGS_switch.o		   += $(CFLAGS_ARMV7VE)
 obj-$(CONFIG_KVM_ARM_HOST) += s2-setup.o
diff --git a/arch/arm/kvm/hyp/banked-sr.c b/arch/arm/kvm/hyp/banked-sr.c
index 111bda8cdebdc7e59789103838087920aedf0efe..be4b8b0a40ade5c8e412bbc75f2b311c389aa979 100644
--- a/arch/arm/kvm/hyp/banked-sr.c
+++ b/arch/arm/kvm/hyp/banked-sr.c
@@ -20,6 +20,10 @@
 
 #include <asm/kvm_hyp.h>
 
+/*
+ * gcc before 4.9 doesn't understand -march=armv7ve, so we have to
+ * trick the assembler.
+ */
 __asm__(".arch_extension     virt");
 
 void __hyp_text __banked_save_state(struct kvm_cpu_context *ctxt)
diff --git a/arch/s390/kvm/intercept.c b/arch/s390/kvm/intercept.c
index 9c7d707158622e7f0743570db60599fa95a9de3b..07c6e81163bf5e248c1b744f69b75548aeae5d44 100644
--- a/arch/s390/kvm/intercept.c
+++ b/arch/s390/kvm/intercept.c
@@ -22,22 +22,6 @@
 #include "trace.h"
 #include "trace-s390.h"
 
-
-static const intercept_handler_t instruction_handlers[256] = {
-	[0x01] = kvm_s390_handle_01,
-	[0x82] = kvm_s390_handle_lpsw,
-	[0x83] = kvm_s390_handle_diag,
-	[0xaa] = kvm_s390_handle_aa,
-	[0xae] = kvm_s390_handle_sigp,
-	[0xb2] = kvm_s390_handle_b2,
-	[0xb6] = kvm_s390_handle_stctl,
-	[0xb7] = kvm_s390_handle_lctl,
-	[0xb9] = kvm_s390_handle_b9,
-	[0xe3] = kvm_s390_handle_e3,
-	[0xe5] = kvm_s390_handle_e5,
-	[0xeb] = kvm_s390_handle_eb,
-};
-
 u8 kvm_s390_get_ilen(struct kvm_vcpu *vcpu)
 {
 	struct kvm_s390_sie_block *sie_block = vcpu->arch.sie_block;
@@ -129,16 +113,39 @@ static int handle_validity(struct kvm_vcpu *vcpu)
 
 static int handle_instruction(struct kvm_vcpu *vcpu)
 {
-	intercept_handler_t handler;
-
 	vcpu->stat.exit_instruction++;
 	trace_kvm_s390_intercept_instruction(vcpu,
 					     vcpu->arch.sie_block->ipa,
 					     vcpu->arch.sie_block->ipb);
-	handler = instruction_handlers[vcpu->arch.sie_block->ipa >> 8];
-	if (handler)
-		return handler(vcpu);
-	return -EOPNOTSUPP;
+
+	switch (vcpu->arch.sie_block->ipa >> 8) {
+	case 0x01:
+		return kvm_s390_handle_01(vcpu);
+	case 0x82:
+		return kvm_s390_handle_lpsw(vcpu);
+	case 0x83:
+		return kvm_s390_handle_diag(vcpu);
+	case 0xaa:
+		return kvm_s390_handle_aa(vcpu);
+	case 0xae:
+		return kvm_s390_handle_sigp(vcpu);
+	case 0xb2:
+		return kvm_s390_handle_b2(vcpu);
+	case 0xb6:
+		return kvm_s390_handle_stctl(vcpu);
+	case 0xb7:
+		return kvm_s390_handle_lctl(vcpu);
+	case 0xb9:
+		return kvm_s390_handle_b9(vcpu);
+	case 0xe3:
+		return kvm_s390_handle_e3(vcpu);
+	case 0xe5:
+		return kvm_s390_handle_e5(vcpu);
+	case 0xeb:
+		return kvm_s390_handle_eb(vcpu);
+	default:
+		return -EOPNOTSUPP;
+	}
 }
 
 static int inject_prog_on_prog_intercept(struct kvm_vcpu *vcpu)
diff --git a/arch/s390/kvm/interrupt.c b/arch/s390/kvm/interrupt.c
index aabf46f5f883d44d71cddc88b5ef28ec677ff200..b04616b57a94713a24ed19b142adb94b8c96e748 100644
--- a/arch/s390/kvm/interrupt.c
+++ b/arch/s390/kvm/interrupt.c
@@ -169,8 +169,15 @@ static int ckc_interrupts_enabled(struct kvm_vcpu *vcpu)
 
 static int ckc_irq_pending(struct kvm_vcpu *vcpu)
 {
-	if (vcpu->arch.sie_block->ckc >= kvm_s390_get_tod_clock_fast(vcpu->kvm))
+	const u64 now = kvm_s390_get_tod_clock_fast(vcpu->kvm);
+	const u64 ckc = vcpu->arch.sie_block->ckc;
+
+	if (vcpu->arch.sie_block->gcr[0] & 0x0020000000000000ul) {
+		if ((s64)ckc >= (s64)now)
+			return 0;
+	} else if (ckc >= now) {
 		return 0;
+	}
 	return ckc_interrupts_enabled(vcpu);
 }
 
@@ -187,12 +194,6 @@ static int cpu_timer_irq_pending(struct kvm_vcpu *vcpu)
 	return kvm_s390_get_cpu_timer(vcpu) >> 63;
 }
 
-static inline int is_ioirq(unsigned long irq_type)
-{
-	return ((irq_type >= IRQ_PEND_IO_ISC_7) &&
-		(irq_type <= IRQ_PEND_IO_ISC_0));
-}
-
 static uint64_t isc_to_isc_bits(int isc)
 {
 	return (0x80 >> isc) << 24;
@@ -236,10 +237,15 @@ static inline int kvm_s390_gisa_tac_ipm_gisc(struct kvm_s390_gisa *gisa, u32 gis
 	return test_and_clear_bit_inv(IPM_BIT_OFFSET + gisc, (unsigned long *) gisa);
 }
 
-static inline unsigned long pending_irqs(struct kvm_vcpu *vcpu)
+static inline unsigned long pending_irqs_no_gisa(struct kvm_vcpu *vcpu)
 {
 	return vcpu->kvm->arch.float_int.pending_irqs |
-		vcpu->arch.local_int.pending_irqs |
+		vcpu->arch.local_int.pending_irqs;
+}
+
+static inline unsigned long pending_irqs(struct kvm_vcpu *vcpu)
+{
+	return pending_irqs_no_gisa(vcpu) |
 		kvm_s390_gisa_get_ipm(vcpu->kvm->arch.gisa) << IRQ_PEND_IO_ISC_7;
 }
 
@@ -337,7 +343,7 @@ static void __reset_intercept_indicators(struct kvm_vcpu *vcpu)
 
 static void set_intercept_indicators_io(struct kvm_vcpu *vcpu)
 {
-	if (!(pending_irqs(vcpu) & IRQ_PEND_IO_MASK))
+	if (!(pending_irqs_no_gisa(vcpu) & IRQ_PEND_IO_MASK))
 		return;
 	else if (psw_ioint_disabled(vcpu))
 		kvm_s390_set_cpuflags(vcpu, CPUSTAT_IO_INT);
@@ -1011,24 +1017,6 @@ static int __must_check __deliver_io(struct kvm_vcpu *vcpu,
 	return rc;
 }
 
-typedef int (*deliver_irq_t)(struct kvm_vcpu *vcpu);
-
-static const deliver_irq_t deliver_irq_funcs[] = {
-	[IRQ_PEND_MCHK_EX]        = __deliver_machine_check,
-	[IRQ_PEND_MCHK_REP]       = __deliver_machine_check,
-	[IRQ_PEND_PROG]           = __deliver_prog,
-	[IRQ_PEND_EXT_EMERGENCY]  = __deliver_emergency_signal,
-	[IRQ_PEND_EXT_EXTERNAL]   = __deliver_external_call,
-	[IRQ_PEND_EXT_CLOCK_COMP] = __deliver_ckc,
-	[IRQ_PEND_EXT_CPU_TIMER]  = __deliver_cpu_timer,
-	[IRQ_PEND_RESTART]        = __deliver_restart,
-	[IRQ_PEND_SET_PREFIX]     = __deliver_set_prefix,
-	[IRQ_PEND_PFAULT_INIT]    = __deliver_pfault_init,
-	[IRQ_PEND_EXT_SERVICE]    = __deliver_service,
-	[IRQ_PEND_PFAULT_DONE]    = __deliver_pfault_done,
-	[IRQ_PEND_VIRTIO]         = __deliver_virtio,
-};
-
 /* Check whether an external call is pending (deliverable or not) */
 int kvm_s390_ext_call_pending(struct kvm_vcpu *vcpu)
 {
@@ -1066,13 +1054,19 @@ int kvm_cpu_has_pending_timer(struct kvm_vcpu *vcpu)
 
 static u64 __calculate_sltime(struct kvm_vcpu *vcpu)
 {
-	u64 now, cputm, sltime = 0;
+	const u64 now = kvm_s390_get_tod_clock_fast(vcpu->kvm);
+	const u64 ckc = vcpu->arch.sie_block->ckc;
+	u64 cputm, sltime = 0;
 
 	if (ckc_interrupts_enabled(vcpu)) {
-		now = kvm_s390_get_tod_clock_fast(vcpu->kvm);
-		sltime = tod_to_ns(vcpu->arch.sie_block->ckc - now);
-		/* already expired or overflow? */
-		if (!sltime || vcpu->arch.sie_block->ckc <= now)
+		if (vcpu->arch.sie_block->gcr[0] & 0x0020000000000000ul) {
+			if ((s64)now < (s64)ckc)
+				sltime = tod_to_ns((s64)ckc - (s64)now);
+		} else if (now < ckc) {
+			sltime = tod_to_ns(ckc - now);
+		}
+		/* already expired */
+		if (!sltime)
 			return 0;
 		if (cpu_timer_interrupts_enabled(vcpu)) {
 			cputm = kvm_s390_get_cpu_timer(vcpu);
@@ -1192,7 +1186,6 @@ void kvm_s390_clear_local_irqs(struct kvm_vcpu *vcpu)
 int __must_check kvm_s390_deliver_pending_interrupts(struct kvm_vcpu *vcpu)
 {
 	struct kvm_s390_local_interrupt *li = &vcpu->arch.local_int;
-	deliver_irq_t func;
 	int rc = 0;
 	unsigned long irq_type;
 	unsigned long irqs;
@@ -1212,16 +1205,57 @@ int __must_check kvm_s390_deliver_pending_interrupts(struct kvm_vcpu *vcpu)
 	while ((irqs = deliverable_irqs(vcpu)) && !rc) {
 		/* bits are in the reverse order of interrupt priority */
 		irq_type = find_last_bit(&irqs, IRQ_PEND_COUNT);
-		if (is_ioirq(irq_type)) {
+		switch (irq_type) {
+		case IRQ_PEND_IO_ISC_0:
+		case IRQ_PEND_IO_ISC_1:
+		case IRQ_PEND_IO_ISC_2:
+		case IRQ_PEND_IO_ISC_3:
+		case IRQ_PEND_IO_ISC_4:
+		case IRQ_PEND_IO_ISC_5:
+		case IRQ_PEND_IO_ISC_6:
+		case IRQ_PEND_IO_ISC_7:
 			rc = __deliver_io(vcpu, irq_type);
-		} else {
-			func = deliver_irq_funcs[irq_type];
-			if (!func) {
-				WARN_ON_ONCE(func == NULL);
-				clear_bit(irq_type, &li->pending_irqs);
-				continue;
-			}
-			rc = func(vcpu);
+			break;
+		case IRQ_PEND_MCHK_EX:
+		case IRQ_PEND_MCHK_REP:
+			rc = __deliver_machine_check(vcpu);
+			break;
+		case IRQ_PEND_PROG:
+			rc = __deliver_prog(vcpu);
+			break;
+		case IRQ_PEND_EXT_EMERGENCY:
+			rc = __deliver_emergency_signal(vcpu);
+			break;
+		case IRQ_PEND_EXT_EXTERNAL:
+			rc = __deliver_external_call(vcpu);
+			break;
+		case IRQ_PEND_EXT_CLOCK_COMP:
+			rc = __deliver_ckc(vcpu);
+			break;
+		case IRQ_PEND_EXT_CPU_TIMER:
+			rc = __deliver_cpu_timer(vcpu);
+			break;
+		case IRQ_PEND_RESTART:
+			rc = __deliver_restart(vcpu);
+			break;
+		case IRQ_PEND_SET_PREFIX:
+			rc = __deliver_set_prefix(vcpu);
+			break;
+		case IRQ_PEND_PFAULT_INIT:
+			rc = __deliver_pfault_init(vcpu);
+			break;
+		case IRQ_PEND_EXT_SERVICE:
+			rc = __deliver_service(vcpu);
+			break;
+		case IRQ_PEND_PFAULT_DONE:
+			rc = __deliver_pfault_done(vcpu);
+			break;
+		case IRQ_PEND_VIRTIO:
+			rc = __deliver_virtio(vcpu);
+			break;
+		default:
+			WARN_ONCE(1, "Unknown pending irq type %ld", irq_type);
+			clear_bit(irq_type, &li->pending_irqs);
 		}
 	}
 
@@ -1701,7 +1735,8 @@ static void __floating_irq_kick(struct kvm *kvm, u64 type)
 		kvm_s390_set_cpuflags(dst_vcpu, CPUSTAT_STOP_INT);
 		break;
 	case KVM_S390_INT_IO_MIN...KVM_S390_INT_IO_MAX:
-		kvm_s390_set_cpuflags(dst_vcpu, CPUSTAT_IO_INT);
+		if (!(type & KVM_S390_INT_IO_AI_MASK && kvm->arch.gisa))
+			kvm_s390_set_cpuflags(dst_vcpu, CPUSTAT_IO_INT);
 		break;
 	default:
 		kvm_s390_set_cpuflags(dst_vcpu, CPUSTAT_EXT_INT);
diff --git a/arch/s390/kvm/kvm-s390.c b/arch/s390/kvm/kvm-s390.c
index ba4c7092335ad254fe0d385b99aa231dbecd2d53..77d7818130db49cb0108acf68a3d737e064cb918 100644
--- a/arch/s390/kvm/kvm-s390.c
+++ b/arch/s390/kvm/kvm-s390.c
@@ -179,6 +179,28 @@ int kvm_arch_hardware_enable(void)
 static void kvm_gmap_notifier(struct gmap *gmap, unsigned long start,
 			      unsigned long end);
 
+static void kvm_clock_sync_scb(struct kvm_s390_sie_block *scb, u64 delta)
+{
+	u8 delta_idx = 0;
+
+	/*
+	 * The TOD jumps by delta, we have to compensate this by adding
+	 * -delta to the epoch.
+	 */
+	delta = -delta;
+
+	/* sign-extension - we're adding to signed values below */
+	if ((s64)delta < 0)
+		delta_idx = -1;
+
+	scb->epoch += delta;
+	if (scb->ecd & ECD_MEF) {
+		scb->epdx += delta_idx;
+		if (scb->epoch < delta)
+			scb->epdx += 1;
+	}
+}
+
 /*
  * This callback is executed during stop_machine(). All CPUs are therefore
  * temporarily stopped. In order not to change guest behavior, we have to
@@ -194,13 +216,17 @@ static int kvm_clock_sync(struct notifier_block *notifier, unsigned long val,
 	unsigned long long *delta = v;
 
 	list_for_each_entry(kvm, &vm_list, vm_list) {
-		kvm->arch.epoch -= *delta;
 		kvm_for_each_vcpu(i, vcpu, kvm) {
-			vcpu->arch.sie_block->epoch -= *delta;
+			kvm_clock_sync_scb(vcpu->arch.sie_block, *delta);
+			if (i == 0) {
+				kvm->arch.epoch = vcpu->arch.sie_block->epoch;
+				kvm->arch.epdx = vcpu->arch.sie_block->epdx;
+			}
 			if (vcpu->arch.cputm_enabled)
 				vcpu->arch.cputm_start += *delta;
 			if (vcpu->arch.vsie_block)
-				vcpu->arch.vsie_block->epoch -= *delta;
+				kvm_clock_sync_scb(vcpu->arch.vsie_block,
+						   *delta);
 		}
 	}
 	return NOTIFY_OK;
@@ -902,12 +928,9 @@ static int kvm_s390_set_tod_ext(struct kvm *kvm, struct kvm_device_attr *attr)
 	if (copy_from_user(&gtod, (void __user *)attr->addr, sizeof(gtod)))
 		return -EFAULT;
 
-	if (test_kvm_facility(kvm, 139))
-		kvm_s390_set_tod_clock_ext(kvm, &gtod);
-	else if (gtod.epoch_idx == 0)
-		kvm_s390_set_tod_clock(kvm, gtod.tod);
-	else
+	if (!test_kvm_facility(kvm, 139) && gtod.epoch_idx)
 		return -EINVAL;
+	kvm_s390_set_tod_clock(kvm, &gtod);
 
 	VM_EVENT(kvm, 3, "SET: TOD extension: 0x%x, TOD base: 0x%llx",
 		gtod.epoch_idx, gtod.tod);
@@ -932,13 +955,14 @@ static int kvm_s390_set_tod_high(struct kvm *kvm, struct kvm_device_attr *attr)
 
 static int kvm_s390_set_tod_low(struct kvm *kvm, struct kvm_device_attr *attr)
 {
-	u64 gtod;
+	struct kvm_s390_vm_tod_clock gtod = { 0 };
 
-	if (copy_from_user(&gtod, (void __user *)attr->addr, sizeof(gtod)))
+	if (copy_from_user(&gtod.tod, (void __user *)attr->addr,
+			   sizeof(gtod.tod)))
 		return -EFAULT;
 
-	kvm_s390_set_tod_clock(kvm, gtod);
-	VM_EVENT(kvm, 3, "SET: TOD base: 0x%llx", gtod);
+	kvm_s390_set_tod_clock(kvm, &gtod);
+	VM_EVENT(kvm, 3, "SET: TOD base: 0x%llx", gtod.tod);
 	return 0;
 }
 
@@ -2389,6 +2413,7 @@ void kvm_arch_vcpu_postcreate(struct kvm_vcpu *vcpu)
 	mutex_lock(&vcpu->kvm->lock);
 	preempt_disable();
 	vcpu->arch.sie_block->epoch = vcpu->kvm->arch.epoch;
+	vcpu->arch.sie_block->epdx = vcpu->kvm->arch.epdx;
 	preempt_enable();
 	mutex_unlock(&vcpu->kvm->lock);
 	if (!kvm_is_ucontrol(vcpu->kvm)) {
@@ -3021,8 +3046,8 @@ static int kvm_s390_handle_requests(struct kvm_vcpu *vcpu)
 	return 0;
 }
 
-void kvm_s390_set_tod_clock_ext(struct kvm *kvm,
-				 const struct kvm_s390_vm_tod_clock *gtod)
+void kvm_s390_set_tod_clock(struct kvm *kvm,
+			    const struct kvm_s390_vm_tod_clock *gtod)
 {
 	struct kvm_vcpu *vcpu;
 	struct kvm_s390_tod_clock_ext htod;
@@ -3034,10 +3059,12 @@ void kvm_s390_set_tod_clock_ext(struct kvm *kvm,
 	get_tod_clock_ext((char *)&htod);
 
 	kvm->arch.epoch = gtod->tod - htod.tod;
-	kvm->arch.epdx = gtod->epoch_idx - htod.epoch_idx;
-
-	if (kvm->arch.epoch > gtod->tod)
-		kvm->arch.epdx -= 1;
+	kvm->arch.epdx = 0;
+	if (test_kvm_facility(kvm, 139)) {
+		kvm->arch.epdx = gtod->epoch_idx - htod.epoch_idx;
+		if (kvm->arch.epoch > gtod->tod)
+			kvm->arch.epdx -= 1;
+	}
 
 	kvm_s390_vcpu_block_all(kvm);
 	kvm_for_each_vcpu(i, vcpu, kvm) {
@@ -3050,22 +3077,6 @@ void kvm_s390_set_tod_clock_ext(struct kvm *kvm,
 	mutex_unlock(&kvm->lock);
 }
 
-void kvm_s390_set_tod_clock(struct kvm *kvm, u64 tod)
-{
-	struct kvm_vcpu *vcpu;
-	int i;
-
-	mutex_lock(&kvm->lock);
-	preempt_disable();
-	kvm->arch.epoch = tod - get_tod_clock();
-	kvm_s390_vcpu_block_all(kvm);
-	kvm_for_each_vcpu(i, vcpu, kvm)
-		vcpu->arch.sie_block->epoch = kvm->arch.epoch;
-	kvm_s390_vcpu_unblock_all(kvm);
-	preempt_enable();
-	mutex_unlock(&kvm->lock);
-}
-
 /**
  * kvm_arch_fault_in_page - fault-in guest page if necessary
  * @vcpu: The corresponding virtual cpu
diff --git a/arch/s390/kvm/kvm-s390.h b/arch/s390/kvm/kvm-s390.h
index bd31b37b0e6f83905e7204b2eb439050aaeb1187..f55ac0ef99ea70bf3fb1d8f7eb7e303eb405b2b4 100644
--- a/arch/s390/kvm/kvm-s390.h
+++ b/arch/s390/kvm/kvm-s390.h
@@ -19,8 +19,6 @@
 #include <asm/processor.h>
 #include <asm/sclp.h>
 
-typedef int (*intercept_handler_t)(struct kvm_vcpu *vcpu);
-
 /* Transactional Memory Execution related macros */
 #define IS_TE_ENABLED(vcpu)	((vcpu->arch.sie_block->ecb & ECB_TE))
 #define TDB_FORMAT1		1
@@ -283,9 +281,8 @@ int kvm_s390_handle_sigp(struct kvm_vcpu *vcpu);
 int kvm_s390_handle_sigp_pei(struct kvm_vcpu *vcpu);
 
 /* implemented in kvm-s390.c */
-void kvm_s390_set_tod_clock_ext(struct kvm *kvm,
-				 const struct kvm_s390_vm_tod_clock *gtod);
-void kvm_s390_set_tod_clock(struct kvm *kvm, u64 tod);
+void kvm_s390_set_tod_clock(struct kvm *kvm,
+			    const struct kvm_s390_vm_tod_clock *gtod);
 long kvm_arch_fault_in_page(struct kvm_vcpu *vcpu, gpa_t gpa, int writable);
 int kvm_s390_store_status_unloaded(struct kvm_vcpu *vcpu, unsigned long addr);
 int kvm_s390_vcpu_store_status(struct kvm_vcpu *vcpu, unsigned long addr);
diff --git a/arch/s390/kvm/priv.c b/arch/s390/kvm/priv.c
index c4c4e157c03631a1d745ccafdfec111a6a08a49e..f0b4185158afcb229cac5b06f23f5664afffebec 100644
--- a/arch/s390/kvm/priv.c
+++ b/arch/s390/kvm/priv.c
@@ -85,9 +85,10 @@ int kvm_s390_handle_e3(struct kvm_vcpu *vcpu)
 /* Handle SCK (SET CLOCK) interception */
 static int handle_set_clock(struct kvm_vcpu *vcpu)
 {
+	struct kvm_s390_vm_tod_clock gtod = { 0 };
 	int rc;
 	u8 ar;
-	u64 op2, val;
+	u64 op2;
 
 	vcpu->stat.instruction_sck++;
 
@@ -97,12 +98,12 @@ static int handle_set_clock(struct kvm_vcpu *vcpu)
 	op2 = kvm_s390_get_base_disp_s(vcpu, &ar);
 	if (op2 & 7)	/* Operand must be on a doubleword boundary */
 		return kvm_s390_inject_program_int(vcpu, PGM_SPECIFICATION);
-	rc = read_guest(vcpu, op2, ar, &val, sizeof(val));
+	rc = read_guest(vcpu, op2, ar, &gtod.tod, sizeof(gtod.tod));
 	if (rc)
 		return kvm_s390_inject_prog_cond(vcpu, rc);
 
-	VCPU_EVENT(vcpu, 3, "SCK: setting guest TOD to 0x%llx", val);
-	kvm_s390_set_tod_clock(vcpu->kvm, val);
+	VCPU_EVENT(vcpu, 3, "SCK: setting guest TOD to 0x%llx", gtod.tod);
+	kvm_s390_set_tod_clock(vcpu->kvm, &gtod);
 
 	kvm_s390_set_psw_cc(vcpu, 0);
 	return 0;
@@ -795,55 +796,60 @@ static int handle_stsi(struct kvm_vcpu *vcpu)
 	return rc;
 }
 
-static const intercept_handler_t b2_handlers[256] = {
-	[0x02] = handle_stidp,
-	[0x04] = handle_set_clock,
-	[0x10] = handle_set_prefix,
-	[0x11] = handle_store_prefix,
-	[0x12] = handle_store_cpu_address,
-	[0x14] = kvm_s390_handle_vsie,
-	[0x21] = handle_ipte_interlock,
-	[0x29] = handle_iske,
-	[0x2a] = handle_rrbe,
-	[0x2b] = handle_sske,
-	[0x2c] = handle_test_block,
-	[0x30] = handle_io_inst,
-	[0x31] = handle_io_inst,
-	[0x32] = handle_io_inst,
-	[0x33] = handle_io_inst,
-	[0x34] = handle_io_inst,
-	[0x35] = handle_io_inst,
-	[0x36] = handle_io_inst,
-	[0x37] = handle_io_inst,
-	[0x38] = handle_io_inst,
-	[0x39] = handle_io_inst,
-	[0x3a] = handle_io_inst,
-	[0x3b] = handle_io_inst,
-	[0x3c] = handle_io_inst,
-	[0x50] = handle_ipte_interlock,
-	[0x56] = handle_sthyi,
-	[0x5f] = handle_io_inst,
-	[0x74] = handle_io_inst,
-	[0x76] = handle_io_inst,
-	[0x7d] = handle_stsi,
-	[0xb1] = handle_stfl,
-	[0xb2] = handle_lpswe,
-};
-
 int kvm_s390_handle_b2(struct kvm_vcpu *vcpu)
 {
-	intercept_handler_t handler;
-
-	/*
-	 * A lot of B2 instructions are priviledged. Here we check for
-	 * the privileged ones, that we can handle in the kernel.
-	 * Anything else goes to userspace.
-	 */
-	handler = b2_handlers[vcpu->arch.sie_block->ipa & 0x00ff];
-	if (handler)
-		return handler(vcpu);
-
-	return -EOPNOTSUPP;
+	switch (vcpu->arch.sie_block->ipa & 0x00ff) {
+	case 0x02:
+		return handle_stidp(vcpu);
+	case 0x04:
+		return handle_set_clock(vcpu);
+	case 0x10:
+		return handle_set_prefix(vcpu);
+	case 0x11:
+		return handle_store_prefix(vcpu);
+	case 0x12:
+		return handle_store_cpu_address(vcpu);
+	case 0x14:
+		return kvm_s390_handle_vsie(vcpu);
+	case 0x21:
+	case 0x50:
+		return handle_ipte_interlock(vcpu);
+	case 0x29:
+		return handle_iske(vcpu);
+	case 0x2a:
+		return handle_rrbe(vcpu);
+	case 0x2b:
+		return handle_sske(vcpu);
+	case 0x2c:
+		return handle_test_block(vcpu);
+	case 0x30:
+	case 0x31:
+	case 0x32:
+	case 0x33:
+	case 0x34:
+	case 0x35:
+	case 0x36:
+	case 0x37:
+	case 0x38:
+	case 0x39:
+	case 0x3a:
+	case 0x3b:
+	case 0x3c:
+	case 0x5f:
+	case 0x74:
+	case 0x76:
+		return handle_io_inst(vcpu);
+	case 0x56:
+		return handle_sthyi(vcpu);
+	case 0x7d:
+		return handle_stsi(vcpu);
+	case 0xb1:
+		return handle_stfl(vcpu);
+	case 0xb2:
+		return handle_lpswe(vcpu);
+	default:
+		return -EOPNOTSUPP;
+	}
 }
 
 static int handle_epsw(struct kvm_vcpu *vcpu)
@@ -1105,25 +1111,22 @@ static int handle_essa(struct kvm_vcpu *vcpu)
 	return 0;
 }
 
-static const intercept_handler_t b9_handlers[256] = {
-	[0x8a] = handle_ipte_interlock,
-	[0x8d] = handle_epsw,
-	[0x8e] = handle_ipte_interlock,
-	[0x8f] = handle_ipte_interlock,
-	[0xab] = handle_essa,
-	[0xaf] = handle_pfmf,
-};
-
 int kvm_s390_handle_b9(struct kvm_vcpu *vcpu)
 {
-	intercept_handler_t handler;
-
-	/* This is handled just as for the B2 instructions. */
-	handler = b9_handlers[vcpu->arch.sie_block->ipa & 0x00ff];
-	if (handler)
-		return handler(vcpu);
-
-	return -EOPNOTSUPP;
+	switch (vcpu->arch.sie_block->ipa & 0x00ff) {
+	case 0x8a:
+	case 0x8e:
+	case 0x8f:
+		return handle_ipte_interlock(vcpu);
+	case 0x8d:
+		return handle_epsw(vcpu);
+	case 0xab:
+		return handle_essa(vcpu);
+	case 0xaf:
+		return handle_pfmf(vcpu);
+	default:
+		return -EOPNOTSUPP;
+	}
 }
 
 int kvm_s390_handle_lctl(struct kvm_vcpu *vcpu)
@@ -1271,22 +1274,20 @@ static int handle_stctg(struct kvm_vcpu *vcpu)
 	return rc ? kvm_s390_inject_prog_cond(vcpu, rc) : 0;
 }
 
-static const intercept_handler_t eb_handlers[256] = {
-	[0x2f] = handle_lctlg,
-	[0x25] = handle_stctg,
-	[0x60] = handle_ri,
-	[0x61] = handle_ri,
-	[0x62] = handle_ri,
-};
-
 int kvm_s390_handle_eb(struct kvm_vcpu *vcpu)
 {
-	intercept_handler_t handler;
-
-	handler = eb_handlers[vcpu->arch.sie_block->ipb & 0xff];
-	if (handler)
-		return handler(vcpu);
-	return -EOPNOTSUPP;
+	switch (vcpu->arch.sie_block->ipb & 0x000000ff) {
+	case 0x25:
+		return handle_stctg(vcpu);
+	case 0x2f:
+		return handle_lctlg(vcpu);
+	case 0x60:
+	case 0x61:
+	case 0x62:
+		return handle_ri(vcpu);
+	default:
+		return -EOPNOTSUPP;
+	}
 }
 
 static int handle_tprot(struct kvm_vcpu *vcpu)
@@ -1346,10 +1347,12 @@ static int handle_tprot(struct kvm_vcpu *vcpu)
 
 int kvm_s390_handle_e5(struct kvm_vcpu *vcpu)
 {
-	/* For e5xx... instructions we only handle TPROT */
-	if ((vcpu->arch.sie_block->ipa & 0x00ff) == 0x01)
+	switch (vcpu->arch.sie_block->ipa & 0x00ff) {
+	case 0x01:
 		return handle_tprot(vcpu);
-	return -EOPNOTSUPP;
+	default:
+		return -EOPNOTSUPP;
+	}
 }
 
 static int handle_sckpf(struct kvm_vcpu *vcpu)
@@ -1380,17 +1383,14 @@ static int handle_ptff(struct kvm_vcpu *vcpu)
 	return 0;
 }
 
-static const intercept_handler_t x01_handlers[256] = {
-	[0x04] = handle_ptff,
-	[0x07] = handle_sckpf,
-};
-
 int kvm_s390_handle_01(struct kvm_vcpu *vcpu)
 {
-	intercept_handler_t handler;
-
-	handler = x01_handlers[vcpu->arch.sie_block->ipa & 0x00ff];
-	if (handler)
-		return handler(vcpu);
-	return -EOPNOTSUPP;
+	switch (vcpu->arch.sie_block->ipa & 0x00ff) {
+	case 0x04:
+		return handle_ptff(vcpu);
+	case 0x07:
+		return handle_sckpf(vcpu);
+	default:
+		return -EOPNOTSUPP;
+	}
 }
diff --git a/arch/s390/kvm/vsie.c b/arch/s390/kvm/vsie.c
index ec772700ff9659350543534272a92eb531bbd181..8961e3970901d4b06c87b20f147115b683ad5170 100644
--- a/arch/s390/kvm/vsie.c
+++ b/arch/s390/kvm/vsie.c
@@ -821,6 +821,7 @@ static int do_vsie_run(struct kvm_vcpu *vcpu, struct vsie_page *vsie_page)
 {
 	struct kvm_s390_sie_block *scb_s = &vsie_page->scb_s;
 	struct kvm_s390_sie_block *scb_o = vsie_page->scb_o;
+	int guest_bp_isolation;
 	int rc;
 
 	handle_last_fault(vcpu, vsie_page);
@@ -831,6 +832,20 @@ static int do_vsie_run(struct kvm_vcpu *vcpu, struct vsie_page *vsie_page)
 		s390_handle_mcck();
 
 	srcu_read_unlock(&vcpu->kvm->srcu, vcpu->srcu_idx);
+
+	/* save current guest state of bp isolation override */
+	guest_bp_isolation = test_thread_flag(TIF_ISOLATE_BP_GUEST);
+
+	/*
+	 * The guest is running with BPBC, so we have to force it on for our
+	 * nested guest. This is done by enabling BPBC globally, so the BPBC
+	 * control in the SCB (which the nested guest can modify) is simply
+	 * ignored.
+	 */
+	if (test_kvm_facility(vcpu->kvm, 82) &&
+	    vcpu->arch.sie_block->fpf & FPF_BPBC)
+		set_thread_flag(TIF_ISOLATE_BP_GUEST);
+
 	local_irq_disable();
 	guest_enter_irqoff();
 	local_irq_enable();
@@ -840,6 +855,11 @@ static int do_vsie_run(struct kvm_vcpu *vcpu, struct vsie_page *vsie_page)
 	local_irq_disable();
 	guest_exit_irqoff();
 	local_irq_enable();
+
+	/* restore guest state for bp isolation override */
+	if (!guest_bp_isolation)
+		clear_thread_flag(TIF_ISOLATE_BP_GUEST);
+
 	vcpu->srcu_idx = srcu_read_lock(&vcpu->kvm->srcu);
 
 	if (rc == -EINTR) {
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index dd6f57a54a2626c080c8505cd670ef8c54c7bf56..0a9e330b34f029fd780b444999b36da25a5d9c7f 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1464,7 +1464,4 @@ static inline int kvm_cpu_get_apicid(int mps_cpu)
 #define put_smstate(type, buf, offset, val)                      \
 	*(type *)((buf) + (offset) - 0x7e00) = val
 
-void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
-		unsigned long start, unsigned long end);
-
 #endif /* _ASM_X86_KVM_HOST_H */
diff --git a/arch/x86/include/uapi/asm/kvm_para.h b/arch/x86/include/uapi/asm/kvm_para.h
index 7a2ade4aa235380a8c28af6934d30566bb24de73..6cfa9c8cb7d650bef22010de0eca622a86364a73 100644
--- a/arch/x86/include/uapi/asm/kvm_para.h
+++ b/arch/x86/include/uapi/asm/kvm_para.h
@@ -26,6 +26,7 @@
 #define KVM_FEATURE_PV_EOI		6
 #define KVM_FEATURE_PV_UNHALT		7
 #define KVM_FEATURE_PV_TLB_FLUSH	9
+#define KVM_FEATURE_ASYNC_PF_VMEXIT	10
 
 /* The last 8 bits are used to indicate how to interpret the flags field
  * in pvclock structure. If no bits are set, all flags are ignored.
diff --git a/arch/x86/kernel/kvm.c b/arch/x86/kernel/kvm.c
index 4e37d1a851a62df3f9f841f3bbd66827af0c1920..bc1a27280c4bf77899afad4b85bf53212d385cab 100644
--- a/arch/x86/kernel/kvm.c
+++ b/arch/x86/kernel/kvm.c
@@ -49,7 +49,7 @@
 
 static int kvmapf = 1;
 
-static int parse_no_kvmapf(char *arg)
+static int __init parse_no_kvmapf(char *arg)
 {
         kvmapf = 0;
         return 0;
@@ -58,7 +58,7 @@ static int parse_no_kvmapf(char *arg)
 early_param("no-kvmapf", parse_no_kvmapf);
 
 static int steal_acc = 1;
-static int parse_no_stealacc(char *arg)
+static int __init parse_no_stealacc(char *arg)
 {
         steal_acc = 0;
         return 0;
@@ -67,7 +67,7 @@ static int parse_no_stealacc(char *arg)
 early_param("no-steal-acc", parse_no_stealacc);
 
 static int kvmclock_vsyscall = 1;
-static int parse_no_kvmclock_vsyscall(char *arg)
+static int __init parse_no_kvmclock_vsyscall(char *arg)
 {
         kvmclock_vsyscall = 0;
         return 0;
@@ -341,10 +341,10 @@ static void kvm_guest_cpu_init(void)
 #endif
 		pa |= KVM_ASYNC_PF_ENABLED;
 
-		/* Async page fault support for L1 hypervisor is optional */
-		if (wrmsr_safe(MSR_KVM_ASYNC_PF_EN,
-			(pa | KVM_ASYNC_PF_DELIVERY_AS_PF_VMEXIT) & 0xffffffff, pa >> 32) < 0)
-			wrmsrl(MSR_KVM_ASYNC_PF_EN, pa);
+		if (kvm_para_has_feature(KVM_FEATURE_ASYNC_PF_VMEXIT))
+			pa |= KVM_ASYNC_PF_DELIVERY_AS_PF_VMEXIT;
+
+		wrmsrl(MSR_KVM_ASYNC_PF_EN, pa);
 		__this_cpu_write(apf_reason.enabled, 1);
 		printk(KERN_INFO"KVM setup async PF for cpu %d\n",
 		       smp_processor_id());
@@ -545,7 +545,8 @@ static void __init kvm_guest_init(void)
 		pv_time_ops.steal_clock = kvm_steal_clock;
 	}
 
-	if (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH))
+	if (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH) &&
+	    !kvm_para_has_feature(KVM_FEATURE_STEAL_TIME))
 		pv_mmu_ops.flush_tlb_others = kvm_flush_tlb_others;
 
 	if (kvm_para_has_feature(KVM_FEATURE_PV_EOI))
@@ -633,7 +634,8 @@ static __init int kvm_setup_pv_tlb_flush(void)
 {
 	int cpu;
 
-	if (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH)) {
+	if (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH) &&
+	    !kvm_para_has_feature(KVM_FEATURE_STEAL_TIME)) {
 		for_each_possible_cpu(cpu) {
 			zalloc_cpumask_var_node(per_cpu_ptr(&__pv_tlb_mask, cpu),
 				GFP_KERNEL, cpu_to_node(cpu));
diff --git a/arch/x86/kvm/cpuid.c b/arch/x86/kvm/cpuid.c
index a0c5a69bc7c4a324078db14ad27753443d65aa85..b671fc2d0422717f06ca6291b2a76742ee45537a 100644
--- a/arch/x86/kvm/cpuid.c
+++ b/arch/x86/kvm/cpuid.c
@@ -607,7 +607,8 @@ static inline int __do_cpuid_ent(struct kvm_cpuid_entry2 *entry, u32 function,
 			     (1 << KVM_FEATURE_PV_EOI) |
 			     (1 << KVM_FEATURE_CLOCKSOURCE_STABLE_BIT) |
 			     (1 << KVM_FEATURE_PV_UNHALT) |
-			     (1 << KVM_FEATURE_PV_TLB_FLUSH);
+			     (1 << KVM_FEATURE_PV_TLB_FLUSH) |
+			     (1 << KVM_FEATURE_ASYNC_PF_VMEXIT);
 
 		if (sched_info_on())
 			entry->eax |= (1 << KVM_FEATURE_STEAL_TIME);
diff --git a/arch/x86/kvm/lapic.c b/arch/x86/kvm/lapic.c
index 924ac8ce9d5004f9db4126a81f178c2bf0e6ff40..cc5fe7a50dde2b13c2861d1616988d4319bd0af2 100644
--- a/arch/x86/kvm/lapic.c
+++ b/arch/x86/kvm/lapic.c
@@ -2165,7 +2165,6 @@ int kvm_create_lapic(struct kvm_vcpu *vcpu)
 	 */
 	vcpu->arch.apic_base = MSR_IA32_APICBASE_ENABLE;
 	static_key_slow_inc(&apic_sw_disabled.key); /* sw disabled at reset */
-	kvm_lapic_reset(vcpu, false);
 	kvm_iodevice_init(&apic->dev, &apic_mmio_ops);
 
 	return 0;
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 46ff304140c71fad1fa818324c0cda017257d2a5..f551962ac29488431b80d56aaabcec7a576a791f 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -3029,7 +3029,7 @@ static int kvm_handle_bad_page(struct kvm_vcpu *vcpu, gfn_t gfn, kvm_pfn_t pfn)
 		return RET_PF_RETRY;
 	}
 
-	return -EFAULT;
+	return RET_PF_EMULATE;
 }
 
 static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index b3e488a748281aa5f3d319861ae2ab4bfca3e65c..3d8377f75eda2d1140a546c2d444be1d368f97b8 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -300,6 +300,8 @@ module_param(vgif, int, 0444);
 static int sev = IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT_ACTIVE_BY_DEFAULT);
 module_param(sev, int, 0444);
 
+static u8 rsm_ins_bytes[] = "\x0f\xaa";
+
 static void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0);
 static void svm_flush_tlb(struct kvm_vcpu *vcpu, bool invalidate_gpa);
 static void svm_complete_interrupts(struct vcpu_svm *svm);
@@ -1383,6 +1385,7 @@ static void init_vmcb(struct vcpu_svm *svm)
 	set_intercept(svm, INTERCEPT_SKINIT);
 	set_intercept(svm, INTERCEPT_WBINVD);
 	set_intercept(svm, INTERCEPT_XSETBV);
+	set_intercept(svm, INTERCEPT_RSM);
 
 	if (!kvm_mwait_in_guest()) {
 		set_intercept(svm, INTERCEPT_MONITOR);
@@ -3699,6 +3702,12 @@ static int emulate_on_interception(struct vcpu_svm *svm)
 	return emulate_instruction(&svm->vcpu, 0) == EMULATE_DONE;
 }
 
+static int rsm_interception(struct vcpu_svm *svm)
+{
+	return x86_emulate_instruction(&svm->vcpu, 0, 0,
+				       rsm_ins_bytes, 2) == EMULATE_DONE;
+}
+
 static int rdpmc_interception(struct vcpu_svm *svm)
 {
 	int err;
@@ -4541,7 +4550,7 @@ static int (*const svm_exit_handlers[])(struct vcpu_svm *svm) = {
 	[SVM_EXIT_MWAIT]			= mwait_interception,
 	[SVM_EXIT_XSETBV]			= xsetbv_interception,
 	[SVM_EXIT_NPF]				= npf_interception,
-	[SVM_EXIT_RSM]                          = emulate_on_interception,
+	[SVM_EXIT_RSM]                          = rsm_interception,
 	[SVM_EXIT_AVIC_INCOMPLETE_IPI]		= avic_incomplete_ipi_interception,
 	[SVM_EXIT_AVIC_UNACCELERATED_ACCESS]	= avic_unaccelerated_access_interception,
 };
@@ -6236,16 +6245,18 @@ static int sev_launch_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
 
 static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
+	void __user *measure = (void __user *)(uintptr_t)argp->data;
 	struct kvm_sev_info *sev = &kvm->arch.sev_info;
 	struct sev_data_launch_measure *data;
 	struct kvm_sev_launch_measure params;
+	void __user *p = NULL;
 	void *blob = NULL;
 	int ret;
 
 	if (!sev_guest(kvm))
 		return -ENOTTY;
 
-	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data, sizeof(params)))
+	if (copy_from_user(&params, measure, sizeof(params)))
 		return -EFAULT;
 
 	data = kzalloc(sizeof(*data), GFP_KERNEL);
@@ -6256,17 +6267,13 @@ static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (!params.len)
 		goto cmd;
 
-	if (params.uaddr) {
+	p = (void __user *)(uintptr_t)params.uaddr;
+	if (p) {
 		if (params.len > SEV_FW_BLOB_MAX_SIZE) {
 			ret = -EINVAL;
 			goto e_free;
 		}
 
-		if (!access_ok(VERIFY_WRITE, params.uaddr, params.len)) {
-			ret = -EFAULT;
-			goto e_free;
-		}
-
 		ret = -ENOMEM;
 		blob = kmalloc(params.len, GFP_KERNEL);
 		if (!blob)
@@ -6290,13 +6297,13 @@ static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 		goto e_free_blob;
 
 	if (blob) {
-		if (copy_to_user((void __user *)(uintptr_t)params.uaddr, blob, params.len))
+		if (copy_to_user(p, blob, params.len))
 			ret = -EFAULT;
 	}
 
 done:
 	params.len = data->len;
-	if (copy_to_user((void __user *)(uintptr_t)argp->data, &params, sizeof(params)))
+	if (copy_to_user(measure, &params, sizeof(params)))
 		ret = -EFAULT;
 e_free_blob:
 	kfree(blob);
@@ -6597,7 +6604,7 @@ static int sev_launch_secret(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	struct page **pages;
 	void *blob, *hdr;
 	unsigned long n;
-	int ret;
+	int ret, offset;
 
 	if (!sev_guest(kvm))
 		return -ENOTTY;
@@ -6623,6 +6630,10 @@ static int sev_launch_secret(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (!data)
 		goto e_unpin_memory;
 
+	offset = params.guest_uaddr & (PAGE_SIZE - 1);
+	data->guest_address = __sme_page_pa(pages[0]) + offset;
+	data->guest_len = params.guest_len;
+
 	blob = psp_copy_user_blob(params.trans_uaddr, params.trans_len);
 	if (IS_ERR(blob)) {
 		ret = PTR_ERR(blob);
@@ -6637,8 +6648,8 @@ static int sev_launch_secret(struct kvm *kvm, struct kvm_sev_cmd *argp)
 		ret = PTR_ERR(hdr);
 		goto e_free_blob;
 	}
-	data->trans_address = __psp_pa(blob);
-	data->trans_len = params.trans_len;
+	data->hdr_address = __psp_pa(hdr);
+	data->hdr_len = params.hdr_len;
 
 	data->handle = sev->handle;
 	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_SECRET, data, &argp->error);
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index 3dec126aa3022eb11f49d5de695d3658183a48ee..ec14f2319a87d259bf9761498e73b2e88d5d4791 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -4485,7 +4485,8 @@ static int vmx_set_cr4(struct kvm_vcpu *vcpu, unsigned long cr4)
 		vmcs_set_bits(SECONDARY_VM_EXEC_CONTROL,
 			      SECONDARY_EXEC_DESC);
 		hw_cr4 &= ~X86_CR4_UMIP;
-	} else
+	} else if (!is_guest_mode(vcpu) ||
+	           !nested_cpu_has2(get_vmcs12(vcpu), SECONDARY_EXEC_DESC))
 		vmcs_clear_bits(SECONDARY_VM_EXEC_CONTROL,
 				SECONDARY_EXEC_DESC);
 
@@ -11199,7 +11200,12 @@ static int nested_vmx_run(struct kvm_vcpu *vcpu, bool launch)
 	if (ret)
 		return ret;
 
-	if (vmcs12->guest_activity_state == GUEST_ACTIVITY_HLT)
+	/*
+	 * If we're entering a halted L2 vcpu and the L2 vcpu won't be woken
+	 * by event injection, halt vcpu.
+	 */
+	if ((vmcs12->guest_activity_state == GUEST_ACTIVITY_HLT) &&
+	    !(vmcs12->vm_entry_intr_info_field & INTR_INFO_VALID_MASK))
 		return kvm_vcpu_halt(vcpu);
 
 	vmx->nested.nested_run_pending = 1;
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index c8a0b545ac20c71a464738a1dd0cd7e1c3df388e..96edda878dbf439297492531b45998ab55947113 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -7975,6 +7975,7 @@ int kvm_arch_vcpu_setup(struct kvm_vcpu *vcpu)
 	kvm_vcpu_mtrr_init(vcpu);
 	vcpu_load(vcpu);
 	kvm_vcpu_reset(vcpu, false);
+	kvm_lapic_reset(vcpu, false);
 	kvm_mmu_setup(vcpu);
 	vcpu_put(vcpu);
 	return 0;
@@ -8460,10 +8461,8 @@ int __x86_set_memory_region(struct kvm *kvm, int id, gpa_t gpa, u32 size)
 			return r;
 	}
 
-	if (!size) {
-		r = vm_munmap(old.userspace_addr, old.npages * PAGE_SIZE);
-		WARN_ON(r < 0);
-	}
+	if (!size)
+		vm_munmap(old.userspace_addr, old.npages * PAGE_SIZE);
 
 	return 0;
 }
diff --git a/drivers/crypto/ccp/psp-dev.c b/drivers/crypto/ccp/psp-dev.c
index fcfa5b1eae6169b44398b20442d95532d924821c..b3afb6cc9d72278b35eadb239e280b5aa311c3f9 100644
--- a/drivers/crypto/ccp/psp-dev.c
+++ b/drivers/crypto/ccp/psp-dev.c
@@ -211,7 +211,7 @@ static int __sev_platform_shutdown_locked(int *error)
 {
 	int ret;
 
-	ret = __sev_do_cmd_locked(SEV_CMD_SHUTDOWN, 0, error);
+	ret = __sev_do_cmd_locked(SEV_CMD_SHUTDOWN, NULL, error);
 	if (ret)
 		return ret;
 
@@ -271,7 +271,7 @@ static int sev_ioctl_do_reset(struct sev_issue_cmd *argp)
 			return rc;
 	}
 
-	return __sev_do_cmd_locked(SEV_CMD_FACTORY_RESET, 0, &argp->error);
+	return __sev_do_cmd_locked(SEV_CMD_FACTORY_RESET, NULL, &argp->error);
 }
 
 static int sev_ioctl_do_platform_status(struct sev_issue_cmd *argp)
@@ -299,7 +299,7 @@ static int sev_ioctl_do_pek_pdh_gen(int cmd, struct sev_issue_cmd *argp)
 			return rc;
 	}
 
-	return __sev_do_cmd_locked(cmd, 0, &argp->error);
+	return __sev_do_cmd_locked(cmd, NULL, &argp->error);
 }
 
 static int sev_ioctl_do_pek_csr(struct sev_issue_cmd *argp)
@@ -624,7 +624,7 @@ EXPORT_SYMBOL_GPL(sev_guest_decommission);
 
 int sev_guest_df_flush(int *error)
 {
-	return sev_do_cmd(SEV_CMD_DF_FLUSH, 0, error);
+	return sev_do_cmd(SEV_CMD_DF_FLUSH, NULL, error);
 }
 EXPORT_SYMBOL_GPL(sev_guest_df_flush);
 
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index ac0062b74aed048923c9f08b4a9d74d176ff6c96..6930c63126c78a9ef665b5b5653a60a8773b4d4c 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -1105,7 +1105,6 @@ static inline void kvm_irq_routing_update(struct kvm *kvm)
 {
 }
 #endif
-void kvm_arch_irq_routing_update(struct kvm *kvm);
 
 static inline int kvm_ioeventfd(struct kvm *kvm, struct kvm_ioeventfd *args)
 {
@@ -1114,6 +1113,8 @@ static inline int kvm_ioeventfd(struct kvm *kvm, struct kvm_ioeventfd *args)
 
 #endif /* CONFIG_HAVE_KVM_EVENTFD */
 
+void kvm_arch_irq_routing_update(struct kvm *kvm);
+
 static inline void kvm_make_request(int req, struct kvm_vcpu *vcpu)
 {
 	/*
@@ -1272,4 +1273,7 @@ static inline long kvm_arch_vcpu_async_ioctl(struct file *filp,
 }
 #endif /* CONFIG_HAVE_KVM_VCPU_ASYNC_IOCTL */
 
+void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
+		unsigned long start, unsigned long end);
+
 #endif
diff --git a/include/uapi/linux/psp-sev.h b/include/uapi/linux/psp-sev.h
index 3d77fe91239a802367634d3cc29fd3903c1cee19..9008f31c7eb65c90a487d99d3f371a00ba526f3f 100644
--- a/include/uapi/linux/psp-sev.h
+++ b/include/uapi/linux/psp-sev.h
@@ -42,7 +42,7 @@ typedef enum {
 	SEV_RET_INVALID_PLATFORM_STATE,
 	SEV_RET_INVALID_GUEST_STATE,
 	SEV_RET_INAVLID_CONFIG,
-	SEV_RET_INVALID_len,
+	SEV_RET_INVALID_LEN,
 	SEV_RET_ALREADY_OWNED,
 	SEV_RET_INVALID_CERTIFICATE,
 	SEV_RET_POLICY_FAILURE,
diff --git a/tools/kvm/kvm_stat/kvm_stat b/tools/kvm/kvm_stat/kvm_stat
index a5684d0968b4fd087905e659c0ce80bd170434c2..5898c22ba310bdf27f626a2dfb7bc45e5cf2a505 100755
--- a/tools/kvm/kvm_stat/kvm_stat
+++ b/tools/kvm/kvm_stat/kvm_stat
@@ -33,7 +33,7 @@ import resource
 import struct
 import re
 import subprocess
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 
 VMX_EXIT_REASONS = {
     'EXCEPTION_NMI':        0,
@@ -228,6 +228,7 @@ IOCTL_NUMBERS = {
 }
 
 ENCODING = locale.getpreferredencoding(False)
+TRACE_FILTER = re.compile(r'^[^\(]*$')
 
 
 class Arch(object):
@@ -260,6 +261,11 @@ class Arch(object):
                     return ArchX86(SVM_EXIT_REASONS)
                 return
 
+    def tracepoint_is_child(self, field):
+        if (TRACE_FILTER.match(field)):
+            return None
+        return field.split('(', 1)[0]
+
 
 class ArchX86(Arch):
     def __init__(self, exit_reasons):
@@ -267,6 +273,10 @@ class ArchX86(Arch):
         self.ioctl_numbers = IOCTL_NUMBERS
         self.exit_reasons = exit_reasons
 
+    def debugfs_is_child(self, field):
+        """ Returns name of parent if 'field' is a child, None otherwise """
+        return None
+
 
 class ArchPPC(Arch):
     def __init__(self):
@@ -282,6 +292,10 @@ class ArchPPC(Arch):
         self.ioctl_numbers['SET_FILTER'] = 0x80002406 | char_ptr_size << 16
         self.exit_reasons = {}
 
+    def debugfs_is_child(self, field):
+        """ Returns name of parent if 'field' is a child, None otherwise """
+        return None
+
 
 class ArchA64(Arch):
     def __init__(self):
@@ -289,6 +303,10 @@ class ArchA64(Arch):
         self.ioctl_numbers = IOCTL_NUMBERS
         self.exit_reasons = AARCH64_EXIT_REASONS
 
+    def debugfs_is_child(self, field):
+        """ Returns name of parent if 'field' is a child, None otherwise """
+        return None
+
 
 class ArchS390(Arch):
     def __init__(self):
@@ -296,6 +314,12 @@ class ArchS390(Arch):
         self.ioctl_numbers = IOCTL_NUMBERS
         self.exit_reasons = None
 
+    def debugfs_is_child(self, field):
+        """ Returns name of parent if 'field' is a child, None otherwise """
+        if field.startswith('instruction_'):
+            return 'exit_instruction'
+
+
 ARCH = Arch.get_arch()
 
 
@@ -331,9 +355,6 @@ class perf_event_attr(ctypes.Structure):
 PERF_TYPE_TRACEPOINT = 2
 PERF_FORMAT_GROUP = 1 << 3
 
-PATH_DEBUGFS_TRACING = '/sys/kernel/debug/tracing'
-PATH_DEBUGFS_KVM = '/sys/kernel/debug/kvm'
-
 
 class Group(object):
     """Represents a perf event group."""
@@ -376,8 +397,8 @@ class Event(object):
         self.syscall = self.libc.syscall
         self.name = name
         self.fd = None
-        self.setup_event(group, trace_cpu, trace_pid, trace_point,
-                         trace_filter, trace_set)
+        self._setup_event(group, trace_cpu, trace_pid, trace_point,
+                          trace_filter, trace_set)
 
     def __del__(self):
         """Closes the event's file descriptor.
@@ -390,7 +411,7 @@ class Event(object):
         if self.fd:
             os.close(self.fd)
 
-    def perf_event_open(self, attr, pid, cpu, group_fd, flags):
+    def _perf_event_open(self, attr, pid, cpu, group_fd, flags):
         """Wrapper for the sys_perf_evt_open() syscall.
 
         Used to set up performance events, returns a file descriptor or -1
@@ -409,7 +430,7 @@ class Event(object):
                             ctypes.c_int(pid), ctypes.c_int(cpu),
                             ctypes.c_int(group_fd), ctypes.c_long(flags))
 
-    def setup_event_attribute(self, trace_set, trace_point):
+    def _setup_event_attribute(self, trace_set, trace_point):
         """Returns an initialized ctype perf_event_attr struct."""
 
         id_path = os.path.join(PATH_DEBUGFS_TRACING, 'events', trace_set,
@@ -419,8 +440,8 @@ class Event(object):
         event_attr.config = int(open(id_path).read())
         return event_attr
 
-    def setup_event(self, group, trace_cpu, trace_pid, trace_point,
-                    trace_filter, trace_set):
+    def _setup_event(self, group, trace_cpu, trace_pid, trace_point,
+                     trace_filter, trace_set):
         """Sets up the perf event in Linux.
 
         Issues the syscall to register the event in the kernel and
@@ -428,7 +449,7 @@ class Event(object):
 
         """
 
-        event_attr = self.setup_event_attribute(trace_set, trace_point)
+        event_attr = self._setup_event_attribute(trace_set, trace_point)
 
         # First event will be group leader.
         group_leader = -1
@@ -437,8 +458,8 @@ class Event(object):
         if group.events:
             group_leader = group.events[0].fd
 
-        fd = self.perf_event_open(event_attr, trace_pid,
-                                  trace_cpu, group_leader, 0)
+        fd = self._perf_event_open(event_attr, trace_pid,
+                                   trace_cpu, group_leader, 0)
         if fd == -1:
             err = ctypes.get_errno()
             raise OSError(err, os.strerror(err),
@@ -475,6 +496,10 @@ class Event(object):
 
 class Provider(object):
     """Encapsulates functionalities used by all providers."""
+    def __init__(self, pid):
+        self.child_events = False
+        self.pid = pid
+
     @staticmethod
     def is_field_wanted(fields_filter, field):
         """Indicate whether field is valid according to fields_filter."""
@@ -500,12 +525,12 @@ class TracepointProvider(Provider):
     """
     def __init__(self, pid, fields_filter):
         self.group_leaders = []
-        self.filters = self.get_filters()
+        self.filters = self._get_filters()
         self.update_fields(fields_filter)
-        self.pid = pid
+        super(TracepointProvider, self).__init__(pid)
 
     @staticmethod
-    def get_filters():
+    def _get_filters():
         """Returns a dict of trace events, their filter ids and
         the values that can be filtered.
 
@@ -521,8 +546,8 @@ class TracepointProvider(Provider):
             filters['kvm_exit'] = ('exit_reason', ARCH.exit_reasons)
         return filters
 
-    def get_available_fields(self):
-        """Returns a list of available event's of format 'event name(filter
+    def _get_available_fields(self):
+        """Returns a list of available events of format 'event name(filter
         name)'.
 
         All available events have directories under
@@ -549,11 +574,12 @@ class TracepointProvider(Provider):
 
     def update_fields(self, fields_filter):
         """Refresh fields, applying fields_filter"""
-        self.fields = [field for field in self.get_available_fields()
-                       if self.is_field_wanted(fields_filter, field)]
+        self.fields = [field for field in self._get_available_fields()
+                       if self.is_field_wanted(fields_filter, field) or
+                       ARCH.tracepoint_is_child(field)]
 
     @staticmethod
-    def get_online_cpus():
+    def _get_online_cpus():
         """Returns a list of cpu id integers."""
         def parse_int_list(list_string):
             """Returns an int list from a string of comma separated integers and
@@ -575,17 +601,17 @@ class TracepointProvider(Provider):
             cpu_string = cpu_list.readline()
             return parse_int_list(cpu_string)
 
-    def setup_traces(self):
+    def _setup_traces(self):
         """Creates all event and group objects needed to be able to retrieve
         data."""
-        fields = self.get_available_fields()
+        fields = self._get_available_fields()
         if self._pid > 0:
             # Fetch list of all threads of the monitored pid, as qemu
             # starts a thread for each vcpu.
             path = os.path.join('/proc', str(self._pid), 'task')
             groupids = self.walkdir(path)[1]
         else:
-            groupids = self.get_online_cpus()
+            groupids = self._get_online_cpus()
 
         # The constant is needed as a buffer for python libs, std
         # streams and other files that the script opens.
@@ -663,7 +689,7 @@ class TracepointProvider(Provider):
         # The garbage collector will get rid of all Event/Group
         # objects and open files after removing the references.
         self.group_leaders = []
-        self.setup_traces()
+        self._setup_traces()
         self.fields = self._fields
 
     def read(self, by_guest=0):
@@ -671,8 +697,12 @@ class TracepointProvider(Provider):
         ret = defaultdict(int)
         for group in self.group_leaders:
             for name, val in group.read().items():
-                if name in self._fields:
-                    ret[name] += val
+                if name not in self._fields:
+                    continue
+                parent = ARCH.tracepoint_is_child(name)
+                if parent:
+                    name += ' ' + parent
+                ret[name] += val
         return ret
 
     def reset(self):
@@ -690,11 +720,11 @@ class DebugfsProvider(Provider):
         self._baseline = {}
         self.do_read = True
         self.paths = []
-        self.pid = pid
+        super(DebugfsProvider, self).__init__(pid)
         if include_past:
-            self.restore()
+            self._restore()
 
-    def get_available_fields(self):
+    def _get_available_fields(self):
         """"Returns a list of available fields.
 
         The fields are all available KVM debugfs files
@@ -704,8 +734,9 @@ class DebugfsProvider(Provider):
 
     def update_fields(self, fields_filter):
         """Refresh fields, applying fields_filter"""
-        self._fields = [field for field in self.get_available_fields()
-                        if self.is_field_wanted(fields_filter, field)]
+        self._fields = [field for field in self._get_available_fields()
+                        if self.is_field_wanted(fields_filter, field) or
+                        ARCH.debugfs_is_child(field)]
 
     @property
     def fields(self):
@@ -758,7 +789,7 @@ class DebugfsProvider(Provider):
                     paths.append(dir)
         for path in paths:
             for field in self._fields:
-                value = self.read_field(field, path)
+                value = self._read_field(field, path)
                 key = path + field
                 if reset == 1:
                     self._baseline[key] = value
@@ -766,20 +797,21 @@ class DebugfsProvider(Provider):
                     self._baseline[key] = 0
                 if self._baseline.get(key, -1) == -1:
                     self._baseline[key] = value
-                increment = (results.get(field, 0) + value -
-                             self._baseline.get(key, 0))
-                if by_guest:
-                    pid = key.split('-')[0]
-                    if pid in results:
-                        results[pid] += increment
-                    else:
-                        results[pid] = increment
+                parent = ARCH.debugfs_is_child(field)
+                if parent:
+                    field = field + ' ' + parent
+                else:
+                    if by_guest:
+                        field = key.split('-')[0]    # set 'field' to 'pid'
+                increment = value - self._baseline.get(key, 0)
+                if field in results:
+                    results[field] += increment
                 else:
                     results[field] = increment
 
         return results
 
-    def read_field(self, field, path):
+    def _read_field(self, field, path):
         """Returns the value of a single field from a specific VM."""
         try:
             return int(open(os.path.join(PATH_DEBUGFS_KVM,
@@ -794,12 +826,15 @@ class DebugfsProvider(Provider):
         self._baseline = {}
         self.read(1)
 
-    def restore(self):
+    def _restore(self):
         """Reset field counters"""
         self._baseline = {}
         self.read(2)
 
 
+EventStat = namedtuple('EventStat', ['value', 'delta'])
+
+
 class Stats(object):
     """Manages the data providers and the data they provide.
 
@@ -808,13 +843,13 @@ class Stats(object):
 
     """
     def __init__(self, options):
-        self.providers = self.get_providers(options)
+        self.providers = self._get_providers(options)
         self._pid_filter = options.pid
         self._fields_filter = options.fields
         self.values = {}
+        self._child_events = False
 
-    @staticmethod
-    def get_providers(options):
+    def _get_providers(self, options):
         """Returns a list of data providers depending on the passed options."""
         providers = []
 
@@ -826,7 +861,7 @@ class Stats(object):
 
         return providers
 
-    def update_provider_filters(self):
+    def _update_provider_filters(self):
         """Propagates fields filters to providers."""
         # As we reset the counters when updating the fields we can
         # also clear the cache of old values.
@@ -847,7 +882,7 @@ class Stats(object):
     def fields_filter(self, fields_filter):
         if fields_filter != self._fields_filter:
             self._fields_filter = fields_filter
-            self.update_provider_filters()
+            self._update_provider_filters()
 
     @property
     def pid_filter(self):
@@ -861,16 +896,33 @@ class Stats(object):
             for provider in self.providers:
                 provider.pid = self._pid_filter
 
+    @property
+    def child_events(self):
+        return self._child_events
+
+    @child_events.setter
+    def child_events(self, val):
+        self._child_events = val
+        for provider in self.providers:
+            provider.child_events = val
+
     def get(self, by_guest=0):
         """Returns a dict with field -> (value, delta to last value) of all
-        provider data."""
+        provider data.
+        Key formats:
+          * plain: 'key' is event name
+          * child-parent: 'key' is in format '<child> <parent>'
+          * pid: 'key' is the pid of the guest, and the record contains the
+               aggregated event data
+        These formats are generated by the providers, and handled in class TUI.
+        """
         for provider in self.providers:
             new = provider.read(by_guest=by_guest)
-            for key in new if by_guest else provider.fields:
-                oldval = self.values.get(key, (0, 0))[0]
+            for key in new:
+                oldval = self.values.get(key, EventStat(0, 0)).value
                 newval = new.get(key, 0)
                 newdelta = newval - oldval
-                self.values[key] = (newval, newdelta)
+                self.values[key] = EventStat(newval, newdelta)
         return self.values
 
     def toggle_display_guests(self, to_pid):
@@ -899,10 +951,10 @@ class Stats(object):
         self.get(to_pid)
         return 0
 
+
 DELAY_DEFAULT = 3.0
 MAX_GUEST_NAME_LEN = 48
 MAX_REGEX_LEN = 44
-DEFAULT_REGEX = r'^[^\(]*$'
 SORT_DEFAULT = 0
 
 
@@ -969,7 +1021,7 @@ class Tui(object):
 
         return res
 
-    def print_all_gnames(self, row):
+    def _print_all_gnames(self, row):
         """Print a list of all running guests along with their pids."""
         self.screen.addstr(row, 2, '%8s  %-60s' %
                            ('Pid', 'Guest Name (fuzzy list, might be '
@@ -1032,19 +1084,13 @@ class Tui(object):
 
         return name
 
-    def update_drilldown(self):
-        """Sets or removes a filter that only allows fields without braces."""
-        if not self.stats.fields_filter:
-            self.stats.fields_filter = DEFAULT_REGEX
-
-        elif self.stats.fields_filter == DEFAULT_REGEX:
-            self.stats.fields_filter = None
-
-    def update_pid(self, pid):
+    def _update_pid(self, pid):
         """Propagates pid selection to stats object."""
+        self.screen.addstr(4, 1, 'Updating pid filter...')
+        self.screen.refresh()
         self.stats.pid_filter = pid
 
-    def refresh_header(self, pid=None):
+    def _refresh_header(self, pid=None):
         """Refreshes the header."""
         if pid is None:
             pid = self.stats.pid_filter
@@ -1059,8 +1105,7 @@ class Tui(object):
                                .format(pid, gname), curses.A_BOLD)
         else:
             self.screen.addstr(0, 0, 'kvm statistics - summary', curses.A_BOLD)
-        if self.stats.fields_filter and self.stats.fields_filter \
-           != DEFAULT_REGEX:
+        if self.stats.fields_filter:
             regex = self.stats.fields_filter
             if len(regex) > MAX_REGEX_LEN:
                 regex = regex[:MAX_REGEX_LEN] + '...'
@@ -1075,56 +1120,99 @@ class Tui(object):
         self.screen.addstr(4, 1, 'Collecting data...')
         self.screen.refresh()
 
-    def refresh_body(self, sleeptime):
+    def _refresh_body(self, sleeptime):
+        def is_child_field(field):
+            return field.find('(') != -1
+
+        def insert_child(sorted_items, child, values, parent):
+            num = len(sorted_items)
+            for i in range(0, num):
+                # only add child if parent is present
+                if parent.startswith(sorted_items[i][0]):
+                    sorted_items.insert(i + 1, ('  ' + child, values))
+
+        def get_sorted_events(self, stats):
+            """ separate parent and child events """
+            if self._sorting == SORT_DEFAULT:
+                def sortkey((_k, v)):
+                    # sort by (delta value, overall value)
+                    return (v.delta, v.value)
+            else:
+                def sortkey((_k, v)):
+                    # sort by overall value
+                    return v.value
+
+            childs = []
+            sorted_items = []
+            # we can't rule out child events to appear prior to parents even
+            # when sorted - separate out all children first, and add in later
+            for key, values in sorted(stats.items(), key=sortkey,
+                                      reverse=True):
+                if values == (0, 0):
+                    continue
+                if key.find(' ') != -1:
+                    if not self.stats.child_events:
+                        continue
+                    childs.insert(0, (key, values))
+                else:
+                    sorted_items.append((key, values))
+            if self.stats.child_events:
+                for key, values in childs:
+                    (child, parent) = key.split(' ')
+                    insert_child(sorted_items, child, values, parent)
+
+            return sorted_items
+
         row = 3
         self.screen.move(row, 0)
         self.screen.clrtobot()
         stats = self.stats.get(self._display_guests)
-
-        def sortCurAvg(x):
-            # sort by current events if available
-            if stats[x][1]:
-                return (-stats[x][1], -stats[x][0])
+        total = 0.
+        ctotal = 0.
+        for key, values in stats.items():
+            if self._display_guests:
+                if self.get_gname_from_pid(key):
+                    total += values.value
+                continue
+            if not key.find(' ') != -1:
+                total += values.value
             else:
-                return (0, -stats[x][0])
+                ctotal += values.value
+        if total == 0.:
+            # we don't have any fields, or all non-child events are filtered
+            total = ctotal
 
-        def sortTotal(x):
-            # sort by totals
-            return (0, -stats[x][0])
-        total = 0.
-        for key in stats.keys():
-            if key.find('(') is -1:
-                total += stats[key][0]
-        if self._sorting == SORT_DEFAULT:
-            sortkey = sortCurAvg
-        else:
-            sortkey = sortTotal
+        # print events
         tavg = 0
-        for key in sorted(stats.keys(), key=sortkey):
-            if row >= self.screen.getmaxyx()[0] - 1:
-                break
-            values = stats[key]
-            if not values[0] and not values[1]:
+        tcur = 0
+        for key, values in get_sorted_events(self, stats):
+            if row >= self.screen.getmaxyx()[0] - 1 or values == (0, 0):
                 break
-            if values[0] is not None:
-                cur = int(round(values[1] / sleeptime)) if values[1] else ''
-                if self._display_guests:
-                    key = self.get_gname_from_pid(key)
-                self.screen.addstr(row, 1, '%-40s %10d%7.1f %8s' %
-                                   (key, values[0], values[0] * 100 / total,
-                                    cur))
-                if cur is not '' and key.find('(') is -1:
-                    tavg += cur
+            if self._display_guests:
+                key = self.get_gname_from_pid(key)
+                if not key:
+                    continue
+            cur = int(round(values.delta / sleeptime)) if values.delta else ''
+            if key[0] != ' ':
+                if values.delta:
+                    tcur += values.delta
+                ptotal = values.value
+                ltotal = total
+            else:
+                ltotal = ptotal
+            self.screen.addstr(row, 1, '%-40s %10d%7.1f %8s' % (key,
+                               values.value,
+                               values.value * 100 / float(ltotal), cur))
             row += 1
         if row == 3:
             self.screen.addstr(4, 1, 'No matching events reported yet')
-        else:
+        if row > 4:
+            tavg = int(round(tcur / sleeptime)) if tcur > 0 else ''
             self.screen.addstr(row, 1, '%-40s %10d        %8s' %
-                               ('Total', total, tavg if tavg else ''),
-                               curses.A_BOLD)
+                               ('Total', total, tavg), curses.A_BOLD)
         self.screen.refresh()
 
-    def show_msg(self, text):
+    def _show_msg(self, text):
         """Display message centered text and exit on key press"""
         hint = 'Press any key to continue'
         curses.cbreak()
@@ -1139,16 +1227,16 @@ class Tui(object):
                            curses.A_STANDOUT)
         self.screen.getkey()
 
-    def show_help_interactive(self):
+    def _show_help_interactive(self):
         """Display help with list of interactive commands"""
         msg = ('   b     toggle events by guests (debugfs only, honors'
                ' filters)',
                '   c     clear filter',
                '   f     filter by regular expression',
-               '   g     filter by guest name',
+               '   g     filter by guest name/PID',
                '   h     display interactive commands reference',
                '   o     toggle sorting order (Total vs CurAvg/s)',
-               '   p     filter by PID',
+               '   p     filter by guest name/PID',
                '   q     quit',
                '   r     reset stats',
                '   s     set update interval',
@@ -1165,14 +1253,15 @@ class Tui(object):
             self.screen.addstr(row, 0, line)
             row += 1
         self.screen.getkey()
-        self.refresh_header()
+        self._refresh_header()
 
-    def show_filter_selection(self):
+    def _show_filter_selection(self):
         """Draws filter selection mask.
 
         Asks for a valid regex and sets the fields filter accordingly.
 
         """
+        msg = ''
         while True:
             self.screen.erase()
             self.screen.addstr(0, 0,
@@ -1181,61 +1270,25 @@ class Tui(object):
             self.screen.addstr(2, 0,
                                "Current regex: {0}"
                                .format(self.stats.fields_filter))
+            self.screen.addstr(5, 0, msg)
             self.screen.addstr(3, 0, "New regex: ")
             curses.echo()
             regex = self.screen.getstr().decode(ENCODING)
             curses.noecho()
             if len(regex) == 0:
-                self.stats.fields_filter = DEFAULT_REGEX
-                self.refresh_header()
+                self.stats.fields_filter = ''
+                self._refresh_header()
                 return
             try:
                 re.compile(regex)
                 self.stats.fields_filter = regex
-                self.refresh_header()
+                self._refresh_header()
                 return
             except re.error:
+                msg = '"' + regex + '": Not a valid regular expression'
                 continue
 
-    def show_vm_selection_by_pid(self):
-        """Draws PID selection mask.
-
-        Asks for a pid until a valid pid or 0 has been entered.
-
-        """
-        msg = ''
-        while True:
-            self.screen.erase()
-            self.screen.addstr(0, 0,
-                               'Show statistics for specific pid.',
-                               curses.A_BOLD)
-            self.screen.addstr(1, 0,
-                               'This might limit the shown data to the trace '
-                               'statistics.')
-            self.screen.addstr(5, 0, msg)
-            self.print_all_gnames(7)
-
-            curses.echo()
-            self.screen.addstr(3, 0, "Pid [0 or pid]: ")
-            pid = self.screen.getstr().decode(ENCODING)
-            curses.noecho()
-
-            try:
-                if len(pid) > 0:
-                    pid = int(pid)
-                    if pid != 0 and not os.path.isdir(os.path.join('/proc/',
-                                                                   str(pid))):
-                        msg = '"' + str(pid) + '": Not a running process'
-                        continue
-                else:
-                    pid = 0
-                self.refresh_header(pid)
-                self.update_pid(pid)
-                break
-            except ValueError:
-                msg = '"' + str(pid) + '": Not a valid pid'
-
-    def show_set_update_interval(self):
+    def _show_set_update_interval(self):
         """Draws update interval selection mask."""
         msg = ''
         while True:
@@ -1265,60 +1318,67 @@ class Tui(object):
 
             except ValueError:
                 msg = '"' + str(val) + '": Invalid value'
-        self.refresh_header()
+        self._refresh_header()
 
-    def show_vm_selection_by_guest_name(self):
+    def _show_vm_selection_by_guest(self):
         """Draws guest selection mask.
 
-        Asks for a guest name until a valid guest name or '' is entered.
+        Asks for a guest name or pid until a valid guest name or '' is entered.
 
         """
         msg = ''
         while True:
             self.screen.erase()
             self.screen.addstr(0, 0,
-                               'Show statistics for specific guest.',
+                               'Show statistics for specific guest or pid.',
                                curses.A_BOLD)
             self.screen.addstr(1, 0,
                                'This might limit the shown data to the trace '
                                'statistics.')
             self.screen.addstr(5, 0, msg)
-            self.print_all_gnames(7)
+            self._print_all_gnames(7)
             curses.echo()
-            self.screen.addstr(3, 0, "Guest [ENTER or guest]: ")
-            gname = self.screen.getstr().decode(ENCODING)
+            curses.curs_set(1)
+            self.screen.addstr(3, 0, "Guest or pid [ENTER exits]: ")
+            guest = self.screen.getstr().decode(ENCODING)
             curses.noecho()
 
-            if not gname:
-                self.refresh_header(0)
-                self.update_pid(0)
+            pid = 0
+            if not guest or guest == '0':
                 break
-            else:
-                pids = []
-                try:
-                    pids = self.get_pid_from_gname(gname)
-                except:
-                    msg = '"' + gname + '": Internal error while searching, ' \
-                          'use pid filter instead'
-                    continue
-                if len(pids) == 0:
-                    msg = '"' + gname + '": Not an active guest'
+            if guest.isdigit():
+                if not os.path.isdir(os.path.join('/proc/', guest)):
+                    msg = '"' + guest + '": Not a running process'
                     continue
-                if len(pids) > 1:
-                    msg = '"' + gname + '": Multiple matches found, use pid ' \
-                          'filter instead'
-                    continue
-                self.refresh_header(pids[0])
-                self.update_pid(pids[0])
+                pid = int(guest)
                 break
+            pids = []
+            try:
+                pids = self.get_pid_from_gname(guest)
+            except:
+                msg = '"' + guest + '": Internal error while searching, ' \
+                      'use pid filter instead'
+                continue
+            if len(pids) == 0:
+                msg = '"' + guest + '": Not an active guest'
+                continue
+            if len(pids) > 1:
+                msg = '"' + guest + '": Multiple matches found, use pid ' \
+                      'filter instead'
+                continue
+            pid = pids[0]
+            break
+        curses.curs_set(0)
+        self._refresh_header(pid)
+        self._update_pid(pid)
 
     def show_stats(self):
         """Refreshes the screen and processes user input."""
         sleeptime = self._delay_initial
-        self.refresh_header()
+        self._refresh_header()
         start = 0.0  # result based on init value never appears on screen
         while True:
-            self.refresh_body(time.time() - start)
+            self._refresh_body(time.time() - start)
             curses.halfdelay(int(sleeptime * 10))
             start = time.time()
             sleeptime = self._delay_regular
@@ -1327,47 +1387,39 @@ class Tui(object):
                 if char == 'b':
                     self._display_guests = not self._display_guests
                     if self.stats.toggle_display_guests(self._display_guests):
-                        self.show_msg(['Command not available with tracepoints'
-                                       ' enabled', 'Restart with debugfs only '
-                                       '(see option \'-d\') and try again!'])
+                        self._show_msg(['Command not available with '
+                                        'tracepoints enabled', 'Restart with '
+                                        'debugfs only (see option \'-d\') and '
+                                        'try again!'])
                         self._display_guests = not self._display_guests
-                    self.refresh_header()
+                    self._refresh_header()
                 if char == 'c':
-                    self.stats.fields_filter = DEFAULT_REGEX
-                    self.refresh_header(0)
-                    self.update_pid(0)
+                    self.stats.fields_filter = ''
+                    self._refresh_header(0)
+                    self._update_pid(0)
                 if char == 'f':
                     curses.curs_set(1)
-                    self.show_filter_selection()
+                    self._show_filter_selection()
                     curses.curs_set(0)
                     sleeptime = self._delay_initial
-                if char == 'g':
-                    curses.curs_set(1)
-                    self.show_vm_selection_by_guest_name()
-                    curses.curs_set(0)
+                if char == 'g' or char == 'p':
+                    self._show_vm_selection_by_guest()
                     sleeptime = self._delay_initial
                 if char == 'h':
-                    self.show_help_interactive()
+                    self._show_help_interactive()
                 if char == 'o':
                     self._sorting = not self._sorting
-                if char == 'p':
-                    curses.curs_set(1)
-                    self.show_vm_selection_by_pid()
-                    curses.curs_set(0)
-                    sleeptime = self._delay_initial
                 if char == 'q':
                     break
                 if char == 'r':
                     self.stats.reset()
                 if char == 's':
                     curses.curs_set(1)
-                    self.show_set_update_interval()
+                    self._show_set_update_interval()
                     curses.curs_set(0)
                     sleeptime = self._delay_initial
                 if char == 'x':
-                    self.update_drilldown()
-                    # prevents display of current values on next refresh
-                    self.stats.get(self._display_guests)
+                    self.stats.child_events = not self.stats.child_events
             except KeyboardInterrupt:
                 break
             except curses.error:
@@ -1380,9 +1432,9 @@ def batch(stats):
         s = stats.get()
         time.sleep(1)
         s = stats.get()
-        for key in sorted(s.keys()):
-            values = s[key]
-            print('%-42s%10d%10d' % (key, values[0], values[1]))
+        for key, values in sorted(s.items()):
+            print('%-42s%10d%10d' % (key.split(' ')[0], values.value,
+                  values.delta))
     except KeyboardInterrupt:
         pass
 
@@ -1392,14 +1444,14 @@ def log(stats):
     keys = sorted(stats.get().keys())
 
     def banner():
-        for k in keys:
-            print(k, end=' ')
+        for key in keys:
+            print(key.split(' ')[0], end=' ')
         print()
 
     def statline():
         s = stats.get()
-        for k in keys:
-            print(' %9d' % s[k][1], end=' ')
+        for key in keys:
+            print(' %9d' % s[key].delta, end=' ')
         print()
     line = 0
     banner_repeat = 20
@@ -1504,7 +1556,7 @@ Press any other key to refresh statistics immediately.
                          )
     optparser.add_option('-f', '--fields',
                          action='store',
-                         default=DEFAULT_REGEX,
+                         default='',
                          dest='fields',
                          help='''fields to display (regex)
                                  "-f help" for a list of available events''',
@@ -1539,17 +1591,6 @@ Press any other key to refresh statistics immediately.
 
 def check_access(options):
     """Exits if the current user can't access all needed directories."""
-    if not os.path.exists('/sys/kernel/debug'):
-        sys.stderr.write('Please enable CONFIG_DEBUG_FS in your kernel.')
-        sys.exit(1)
-
-    if not os.path.exists(PATH_DEBUGFS_KVM):
-        sys.stderr.write("Please make sure, that debugfs is mounted and "
-                         "readable by the current user:\n"
-                         "('mount -t debugfs debugfs /sys/kernel/debug')\n"
-                         "Also ensure, that the kvm modules are loaded.\n")
-        sys.exit(1)
-
     if not os.path.exists(PATH_DEBUGFS_TRACING) and (options.tracepoints or
                                                      not options.debugfs):
         sys.stderr.write("Please enable CONFIG_TRACING in your kernel "
@@ -1567,7 +1608,33 @@ def check_access(options):
     return options
 
 
+def assign_globals():
+    global PATH_DEBUGFS_KVM
+    global PATH_DEBUGFS_TRACING
+
+    debugfs = ''
+    for line in file('/proc/mounts'):
+        if line.split(' ')[0] == 'debugfs':
+            debugfs = line.split(' ')[1]
+            break
+    if debugfs == '':
+        sys.stderr.write("Please make sure that CONFIG_DEBUG_FS is enabled in "
+                         "your kernel, mounted and\nreadable by the current "
+                         "user:\n"
+                         "('mount -t debugfs debugfs /sys/kernel/debug')\n")
+        sys.exit(1)
+
+    PATH_DEBUGFS_KVM = os.path.join(debugfs, 'kvm')
+    PATH_DEBUGFS_TRACING = os.path.join(debugfs, 'tracing')
+
+    if not os.path.exists(PATH_DEBUGFS_KVM):
+        sys.stderr.write("Please make sure that CONFIG_KVM is enabled in "
+                         "your kernel and that the modules are loaded.\n")
+        sys.exit(1)
+
+
 def main():
+    assign_globals()
     options = get_options()
     options = check_access(options)
 
diff --git a/tools/kvm/kvm_stat/kvm_stat.txt b/tools/kvm/kvm_stat/kvm_stat.txt
index b5b3810c9e945d7f3a39568840fbc5b73f84983b..0811d860fe75000de852c569ccaaa3ae1f0aa944 100644
--- a/tools/kvm/kvm_stat/kvm_stat.txt
+++ b/tools/kvm/kvm_stat/kvm_stat.txt
@@ -35,13 +35,13 @@ INTERACTIVE COMMANDS
 
 *f*::	filter by regular expression
 
-*g*::	filter by guest name
+*g*::	filter by guest name/PID
 
 *h*::	display interactive commands reference
 
 *o*::   toggle sorting order (Total vs CurAvg/s)
 
-*p*::	filter by PID
+*p*::	filter by guest name/PID
 
 *q*::	quit
 
diff --git a/virt/kvm/arm/arch_timer.c b/virt/kvm/arm/arch_timer.c
index 70268c0bec799c0ce85c2f27267e0ab514c5f852..70f4c30918eb2682b638c326977b50e3568c7638 100644
--- a/virt/kvm/arm/arch_timer.c
+++ b/virt/kvm/arm/arch_timer.c
@@ -36,6 +36,8 @@ static struct timecounter *timecounter;
 static unsigned int host_vtimer_irq;
 static u32 host_vtimer_irq_flags;
 
+static DEFINE_STATIC_KEY_FALSE(has_gic_active_state);
+
 static const struct kvm_irq_level default_ptimer_irq = {
 	.irq	= 30,
 	.level	= 1,
@@ -56,6 +58,12 @@ u64 kvm_phys_timer_read(void)
 	return timecounter->cc->read(timecounter->cc);
 }
 
+static inline bool userspace_irqchip(struct kvm *kvm)
+{
+	return static_branch_unlikely(&userspace_irqchip_in_use) &&
+		unlikely(!irqchip_in_kernel(kvm));
+}
+
 static void soft_timer_start(struct hrtimer *hrt, u64 ns)
 {
 	hrtimer_start(hrt, ktime_add_ns(ktime_get(), ns),
@@ -69,25 +77,6 @@ static void soft_timer_cancel(struct hrtimer *hrt, struct work_struct *work)
 		cancel_work_sync(work);
 }
 
-static void kvm_vtimer_update_mask_user(struct kvm_vcpu *vcpu)
-{
-	struct arch_timer_context *vtimer = vcpu_vtimer(vcpu);
-
-	/*
-	 * When using a userspace irqchip with the architected timers, we must
-	 * prevent continuously exiting from the guest, and therefore mask the
-	 * physical interrupt by disabling it on the host interrupt controller
-	 * when the virtual level is high, such that the guest can make
-	 * forward progress.  Once we detect the output level being
-	 * de-asserted, we unmask the interrupt again so that we exit from the
-	 * guest when the timer fires.
-	 */
-	if (vtimer->irq.level)
-		disable_percpu_irq(host_vtimer_irq);
-	else
-		enable_percpu_irq(host_vtimer_irq, 0);
-}
-
 static irqreturn_t kvm_arch_timer_handler(int irq, void *dev_id)
 {
 	struct kvm_vcpu *vcpu = *(struct kvm_vcpu **)dev_id;
@@ -106,9 +95,9 @@ static irqreturn_t kvm_arch_timer_handler(int irq, void *dev_id)
 	if (kvm_timer_should_fire(vtimer))
 		kvm_timer_update_irq(vcpu, true, vtimer);
 
-	if (static_branch_unlikely(&userspace_irqchip_in_use) &&
-	    unlikely(!irqchip_in_kernel(vcpu->kvm)))
-		kvm_vtimer_update_mask_user(vcpu);
+	if (userspace_irqchip(vcpu->kvm) &&
+	    !static_branch_unlikely(&has_gic_active_state))
+		disable_percpu_irq(host_vtimer_irq);
 
 	return IRQ_HANDLED;
 }
@@ -290,8 +279,7 @@ static void kvm_timer_update_irq(struct kvm_vcpu *vcpu, bool new_level,
 	trace_kvm_timer_update_irq(vcpu->vcpu_id, timer_ctx->irq.irq,
 				   timer_ctx->irq.level);
 
-	if (!static_branch_unlikely(&userspace_irqchip_in_use) ||
-	    likely(irqchip_in_kernel(vcpu->kvm))) {
+	if (!userspace_irqchip(vcpu->kvm)) {
 		ret = kvm_vgic_inject_irq(vcpu->kvm, vcpu->vcpu_id,
 					  timer_ctx->irq.irq,
 					  timer_ctx->irq.level,
@@ -350,12 +338,6 @@ static void kvm_timer_update_state(struct kvm_vcpu *vcpu)
 	phys_timer_emulate(vcpu);
 }
 
-static void __timer_snapshot_state(struct arch_timer_context *timer)
-{
-	timer->cnt_ctl = read_sysreg_el0(cntv_ctl);
-	timer->cnt_cval = read_sysreg_el0(cntv_cval);
-}
-
 static void vtimer_save_state(struct kvm_vcpu *vcpu)
 {
 	struct arch_timer_cpu *timer = &vcpu->arch.timer_cpu;
@@ -367,8 +349,10 @@ static void vtimer_save_state(struct kvm_vcpu *vcpu)
 	if (!vtimer->loaded)
 		goto out;
 
-	if (timer->enabled)
-		__timer_snapshot_state(vtimer);
+	if (timer->enabled) {
+		vtimer->cnt_ctl = read_sysreg_el0(cntv_ctl);
+		vtimer->cnt_cval = read_sysreg_el0(cntv_cval);
+	}
 
 	/* Disable the virtual timer */
 	write_sysreg_el0(0, cntv_ctl);
@@ -460,23 +444,43 @@ static void set_cntvoff(u64 cntvoff)
 	kvm_call_hyp(__kvm_timer_set_cntvoff, low, high);
 }
 
-static void kvm_timer_vcpu_load_vgic(struct kvm_vcpu *vcpu)
+static inline void set_vtimer_irq_phys_active(struct kvm_vcpu *vcpu, bool active)
+{
+	int r;
+	r = irq_set_irqchip_state(host_vtimer_irq, IRQCHIP_STATE_ACTIVE, active);
+	WARN_ON(r);
+}
+
+static void kvm_timer_vcpu_load_gic(struct kvm_vcpu *vcpu)
 {
 	struct arch_timer_context *vtimer = vcpu_vtimer(vcpu);
 	bool phys_active;
-	int ret;
 
-	phys_active = kvm_vgic_map_is_active(vcpu, vtimer->irq.irq);
-
-	ret = irq_set_irqchip_state(host_vtimer_irq,
-				    IRQCHIP_STATE_ACTIVE,
-				    phys_active);
-	WARN_ON(ret);
+	if (irqchip_in_kernel(vcpu->kvm))
+		phys_active = kvm_vgic_map_is_active(vcpu, vtimer->irq.irq);
+	else
+		phys_active = vtimer->irq.level;
+	set_vtimer_irq_phys_active(vcpu, phys_active);
 }
 
-static void kvm_timer_vcpu_load_user(struct kvm_vcpu *vcpu)
+static void kvm_timer_vcpu_load_nogic(struct kvm_vcpu *vcpu)
 {
-	kvm_vtimer_update_mask_user(vcpu);
+	struct arch_timer_context *vtimer = vcpu_vtimer(vcpu);
+
+	/*
+	 * When using a userspace irqchip with the architected timers and a
+	 * host interrupt controller that doesn't support an active state, we
+	 * must still prevent continuously exiting from the guest, and
+	 * therefore mask the physical interrupt by disabling it on the host
+	 * interrupt controller when the virtual level is high, such that the
+	 * guest can make forward progress.  Once we detect the output level
+	 * being de-asserted, we unmask the interrupt again so that we exit
+	 * from the guest when the timer fires.
+	 */
+	if (vtimer->irq.level)
+		disable_percpu_irq(host_vtimer_irq);
+	else
+		enable_percpu_irq(host_vtimer_irq, host_vtimer_irq_flags);
 }
 
 void kvm_timer_vcpu_load(struct kvm_vcpu *vcpu)
@@ -487,10 +491,10 @@ void kvm_timer_vcpu_load(struct kvm_vcpu *vcpu)
 	if (unlikely(!timer->enabled))
 		return;
 
-	if (unlikely(!irqchip_in_kernel(vcpu->kvm)))
-		kvm_timer_vcpu_load_user(vcpu);
+	if (static_branch_likely(&has_gic_active_state))
+		kvm_timer_vcpu_load_gic(vcpu);
 	else
-		kvm_timer_vcpu_load_vgic(vcpu);
+		kvm_timer_vcpu_load_nogic(vcpu);
 
 	set_cntvoff(vtimer->cntvoff);
 
@@ -555,18 +559,24 @@ static void unmask_vtimer_irq_user(struct kvm_vcpu *vcpu)
 {
 	struct arch_timer_context *vtimer = vcpu_vtimer(vcpu);
 
-	if (unlikely(!irqchip_in_kernel(vcpu->kvm))) {
-		__timer_snapshot_state(vtimer);
-		if (!kvm_timer_should_fire(vtimer)) {
-			kvm_timer_update_irq(vcpu, false, vtimer);
-			kvm_vtimer_update_mask_user(vcpu);
-		}
+	if (!kvm_timer_should_fire(vtimer)) {
+		kvm_timer_update_irq(vcpu, false, vtimer);
+		if (static_branch_likely(&has_gic_active_state))
+			set_vtimer_irq_phys_active(vcpu, false);
+		else
+			enable_percpu_irq(host_vtimer_irq, host_vtimer_irq_flags);
 	}
 }
 
 void kvm_timer_sync_hwstate(struct kvm_vcpu *vcpu)
 {
-	unmask_vtimer_irq_user(vcpu);
+	struct arch_timer_cpu *timer = &vcpu->arch.timer_cpu;
+
+	if (unlikely(!timer->enabled))
+		return;
+
+	if (unlikely(!irqchip_in_kernel(vcpu->kvm)))
+		unmask_vtimer_irq_user(vcpu);
 }
 
 int kvm_timer_vcpu_reset(struct kvm_vcpu *vcpu)
@@ -753,6 +763,8 @@ int kvm_timer_hyp_init(bool has_gic)
 			kvm_err("kvm_arch_timer: error setting vcpu affinity\n");
 			goto out_free_irq;
 		}
+
+		static_branch_enable(&has_gic_active_state);
 	}
 
 	kvm_info("virtual timer IRQ%d\n", host_vtimer_irq);
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 4501e658e8d6fc97f39b54bceff8f8b1f36cb95a..65dea3ffef68ede4ca8d8e5bb7d9c2d3edd3352a 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -969,8 +969,7 @@ int __kvm_set_memory_region(struct kvm *kvm,
 		/* Check for overlaps */
 		r = -EEXIST;
 		kvm_for_each_memslot(slot, __kvm_memslots(kvm, as_id)) {
-			if ((slot->id >= KVM_USER_MEM_SLOTS) ||
-			    (slot->id == id))
+			if (slot->id == id)
 				continue;
 			if (!((base_gfn + npages <= slot->base_gfn) ||
 			      (base_gfn >= slot->base_gfn + slot->npages)))