diff --git a/arch/x86/include/asm/traps.h b/arch/x86/include/asm/traps.h
index 707adc6549d82335a20bdf18d18b697fa1fe9eab..3cf525ec762dcc2a68eedbfcf379d83f622b7f44 100644
--- a/arch/x86/include/asm/traps.h
+++ b/arch/x86/include/asm/traps.h
@@ -1,6 +1,7 @@
 #ifndef _ASM_X86_TRAPS_H
 #define _ASM_X86_TRAPS_H
 
+#include <linux/context_tracking_state.h>
 #include <linux/kprobes.h>
 
 #include <asm/debugreg.h>
@@ -110,6 +111,9 @@ asmlinkage void smp_thermal_interrupt(void);
 asmlinkage void mce_threshold_interrupt(void);
 #endif
 
+extern enum ctx_state ist_enter(struct pt_regs *regs);
+extern void ist_exit(struct pt_regs *regs, enum ctx_state prev_state);
+
 /* Interrupts/Exceptions */
 enum {
 	X86_TRAP_DE = 0,	/*  0, Divide-by-zero */
diff --git a/arch/x86/kernel/cpu/mcheck/mce.c b/arch/x86/kernel/cpu/mcheck/mce.c
index d2c611699cd9d2d49bfd1cee5b79c7fedf87ef71..800d423f1e920b33b0dceca0bf939004937d5740 100644
--- a/arch/x86/kernel/cpu/mcheck/mce.c
+++ b/arch/x86/kernel/cpu/mcheck/mce.c
@@ -43,6 +43,7 @@
 #include <linux/export.h>
 
 #include <asm/processor.h>
+#include <asm/traps.h>
 #include <asm/mce.h>
 #include <asm/msr.h>
 
@@ -1063,6 +1064,7 @@ void do_machine_check(struct pt_regs *regs, long error_code)
 {
 	struct mca_config *cfg = &mca_cfg;
 	struct mce m, *final;
+	enum ctx_state prev_state;
 	int i;
 	int worst = 0;
 	int severity;
@@ -1085,6 +1087,8 @@ void do_machine_check(struct pt_regs *regs, long error_code)
 	DECLARE_BITMAP(valid_banks, MAX_NR_BANKS);
 	char *msg = "Unknown";
 
+	prev_state = ist_enter(regs);
+
 	this_cpu_inc(mce_exception_count);
 
 	if (!cfg->banks)
@@ -1216,6 +1220,7 @@ void do_machine_check(struct pt_regs *regs, long error_code)
 	mce_wrmsrl(MSR_IA32_MCG_STATUS, 0);
 out:
 	sync_core();
+	ist_exit(regs, prev_state);
 }
 EXPORT_SYMBOL_GPL(do_machine_check);
 
diff --git a/arch/x86/kernel/cpu/mcheck/p5.c b/arch/x86/kernel/cpu/mcheck/p5.c
index a3042989398c1cdb7d33da49bf7dea1887e00aa5..ec2663a708e40d2e3fa03b36fa7587c860a1871b 100644
--- a/arch/x86/kernel/cpu/mcheck/p5.c
+++ b/arch/x86/kernel/cpu/mcheck/p5.c
@@ -8,6 +8,7 @@
 #include <linux/smp.h>
 
 #include <asm/processor.h>
+#include <asm/traps.h>
 #include <asm/mce.h>
 #include <asm/msr.h>
 
@@ -17,8 +18,11 @@ int mce_p5_enabled __read_mostly;
 /* Machine check handler for Pentium class Intel CPUs: */
 static void pentium_machine_check(struct pt_regs *regs, long error_code)
 {
+	enum ctx_state prev_state;
 	u32 loaddr, hi, lotype;
 
+	prev_state = ist_enter(regs);
+
 	rdmsr(MSR_IA32_P5_MC_ADDR, loaddr, hi);
 	rdmsr(MSR_IA32_P5_MC_TYPE, lotype, hi);
 
@@ -33,6 +37,8 @@ static void pentium_machine_check(struct pt_regs *regs, long error_code)
 	}
 
 	add_taint(TAINT_MACHINE_CHECK, LOCKDEP_NOW_UNRELIABLE);
+
+	ist_exit(regs, prev_state);
 }
 
 /* Set up machine check reporting for processors with Intel style MCE: */
diff --git a/arch/x86/kernel/cpu/mcheck/winchip.c b/arch/x86/kernel/cpu/mcheck/winchip.c
index 7dc5564d0cdf57c0e7ca8c181f87f3ebb6f6ceb2..bd5d46a32210a15deb8895c37e67e5e0dea9d7c5 100644
--- a/arch/x86/kernel/cpu/mcheck/winchip.c
+++ b/arch/x86/kernel/cpu/mcheck/winchip.c
@@ -7,14 +7,19 @@
 #include <linux/types.h>
 
 #include <asm/processor.h>
+#include <asm/traps.h>
 #include <asm/mce.h>
 #include <asm/msr.h>
 
 /* Machine check handler for WinChip C6: */
 static void winchip_machine_check(struct pt_regs *regs, long error_code)
 {
+	enum ctx_state prev_state = ist_enter(regs);
+
 	printk(KERN_EMERG "CPU0: Machine Check Exception.\n");
 	add_taint(TAINT_MACHINE_CHECK, LOCKDEP_NOW_UNRELIABLE);
+
+	ist_exit(regs, prev_state);
 }
 
 /* Set up machine check reporting on the Winchip C6 series */
diff --git a/arch/x86/kernel/traps.c b/arch/x86/kernel/traps.c
index 28f3e5ffc55ddce45e19570fed420c090acfdb6f..b3a9d24dba25cc1c8b276ce3cffb2c8cd83e9fe5 100644
--- a/arch/x86/kernel/traps.c
+++ b/arch/x86/kernel/traps.c
@@ -108,6 +108,39 @@ static inline void preempt_conditional_cli(struct pt_regs *regs)
 	preempt_count_dec();
 }
 
+enum ctx_state ist_enter(struct pt_regs *regs)
+{
+	/*
+	 * We are atomic because we're on the IST stack (or we're on x86_32,
+	 * in which case we still shouldn't schedule.
+	 */
+	preempt_count_add(HARDIRQ_OFFSET);
+
+	if (user_mode_vm(regs)) {
+		/* Other than that, we're just an exception. */
+		return exception_enter();
+	} else {
+		/*
+		 * We might have interrupted pretty much anything.  In
+		 * fact, if we're a machine check, we can even interrupt
+		 * NMI processing.  We don't want in_nmi() to return true,
+		 * but we need to notify RCU.
+		 */
+		rcu_nmi_enter();
+		return IN_KERNEL;  /* the value is irrelevant. */
+	}
+}
+
+void ist_exit(struct pt_regs *regs, enum ctx_state prev_state)
+{
+	preempt_count_sub(HARDIRQ_OFFSET);
+
+	if (user_mode_vm(regs))
+		return exception_exit(prev_state);
+	else
+		rcu_nmi_exit();
+}
+
 static nokprobe_inline int
 do_trap_no_signal(struct task_struct *tsk, int trapnr, char *str,
 		  struct pt_regs *regs,	long error_code)
@@ -251,6 +284,8 @@ dotraplinkage void do_double_fault(struct pt_regs *regs, long error_code)
 	 * end up promoting it to a doublefault.  In that case, modify
 	 * the stack to make it look like we just entered the #GP
 	 * handler from user space, similar to bad_iret.
+	 *
+	 * No need for ist_enter here because we don't use RCU.
 	 */
 	if (((long)regs->sp >> PGDIR_SHIFT) == ESPFIX_PGD_ENTRY &&
 		regs->cs == __KERNEL_CS &&
@@ -263,12 +298,12 @@ dotraplinkage void do_double_fault(struct pt_regs *regs, long error_code)
 		normal_regs->orig_ax = 0;  /* Missing (lost) #GP error code */
 		regs->ip = (unsigned long)general_protection;
 		regs->sp = (unsigned long)&normal_regs->orig_ax;
+
 		return;
 	}
 #endif
 
-	exception_enter();
-	/* Return not checked because double check cannot be ignored */
+	ist_enter(regs);  /* Discard prev_state because we won't return. */
 	notify_die(DIE_TRAP, str, regs, error_code, X86_TRAP_DF, SIGSEGV);
 
 	tsk->thread.error_code = error_code;
@@ -434,7 +469,7 @@ dotraplinkage void notrace do_int3(struct pt_regs *regs, long error_code)
 	if (poke_int3_handler(regs))
 		return;
 
-	prev_state = exception_enter();
+	prev_state = ist_enter(regs);
 #ifdef CONFIG_KGDB_LOW_LEVEL_TRAP
 	if (kgdb_ll_trap(DIE_INT3, "int3", regs, error_code, X86_TRAP_BP,
 				SIGTRAP) == NOTIFY_STOP)
@@ -460,7 +495,7 @@ dotraplinkage void notrace do_int3(struct pt_regs *regs, long error_code)
 	preempt_conditional_cli(regs);
 	debug_stack_usage_dec();
 exit:
-	exception_exit(prev_state);
+	ist_exit(regs, prev_state);
 }
 NOKPROBE_SYMBOL(do_int3);
 
@@ -541,7 +576,7 @@ dotraplinkage void do_debug(struct pt_regs *regs, long error_code)
 	unsigned long dr6;
 	int si_code;
 
-	prev_state = exception_enter();
+	prev_state = ist_enter(regs);
 
 	get_debugreg(dr6, 6);
 
@@ -616,7 +651,7 @@ dotraplinkage void do_debug(struct pt_regs *regs, long error_code)
 	debug_stack_usage_dec();
 
 exit:
-	exception_exit(prev_state);
+	ist_exit(regs, prev_state);
 }
 NOKPROBE_SYMBOL(do_debug);