From c5b8eae77e67c47e411024bad7382be12381413e Mon Sep 17 00:00:00 2001
From: Chen Chengjun <chenchengjun.ccj@antgroup.com>
Date: Mon, 3 Mar 2025 14:32:19 +0800
Subject: [PATCH] Make the IRQ state within the trap correct

---
 ostd/src/arch/x86/trap/mod.rs | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/ostd/src/arch/x86/trap/mod.rs b/ostd/src/arch/x86/trap/mod.rs
index fe570d3c..500783d0 100644
--- a/ostd/src/arch/x86/trap/mod.rs
+++ b/ostd/src/arch/x86/trap/mod.rs
@@ -26,6 +26,7 @@ use log::debug;
 
 use super::ex_table::ExTable;
 use crate::{
+    arch::irq::{disable_local, enable_local},
     cpu::{CpuException, CpuExceptionInfo, PageFaultErrorCode},
     cpu_local_cell,
     mm::{
@@ -220,16 +221,38 @@ pub fn is_kernel_interrupted() -> bool {
 /// Handle traps (only from kernel).
 #[no_mangle]
 extern "sysv64" fn trap_handler(f: &mut TrapFrame) {
+    fn enable_local_if(cond: bool) {
+        if cond {
+            enable_local();
+        }
+    }
+
+    fn disable_local_if(cond: bool) {
+        if cond {
+            disable_local();
+        }
+    }
+
+    // The IRQ state before trapping. We need to ensure that the IRQ state
+    // during exception handling is consistent with the state before the trap.
+    let was_irq_enabled =
+        f.rflags as u64 & x86_64::registers::rflags::RFlags::INTERRUPT_FLAG.bits() > 0;
+
     match CpuException::to_cpu_exception(f.trap_num as u16) {
         #[cfg(feature = "cvm_guest")]
         Some(CpuException::VIRTUALIZATION_EXCEPTION) => {
             let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n");
+            // We need to enable interrupts only after `tdcall::get_veinfo` is called
+            // to avoid nested `#VE`s.
+            enable_local_if(was_irq_enabled);
             let mut trapframe_wrapper = TrapFrameWrapper(&mut *f);
             handle_virtual_exception(&mut trapframe_wrapper, &ve_info);
             *f = *trapframe_wrapper.0;
+            disable_local_if(was_irq_enabled);
         }
         Some(CpuException::PAGE_FAULT) => {
             let page_fault_addr = x86_64::registers::control::Cr2::read_raw();
+            enable_local_if(was_irq_enabled);
             // The actual user space implementation should be responsible
             // for providing mechanism to treat the 0 virtual address.
             if (0..MAX_USERSPACE_VADDR).contains(&(page_fault_addr as usize)) {
@@ -237,8 +260,10 @@ extern "sysv64" fn trap_handler(f: &mut TrapFrame) {
             } else {
                 handle_kernel_page_fault(f, page_fault_addr);
             }
+            disable_local_if(was_irq_enabled);
         }
         Some(exception) => {
+            enable_local_if(was_irq_enabled);
             panic!(
                 "cannot handle kernel CPU exception: {:?}, trapframe: {:?}",
                 exception, f