Check CPUID before enabling AVX512

This commit is contained in:
Zhang Junyang 2024-07-17 00:15:54 +08:00 committed by Tate, Hongliang Tian
parent da987db700
commit 8a9c012249
2 changed files with 25 additions and 7 deletions

6
Cargo.lock generated
View File

@ -1455,13 +1455,13 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]] [[package]]
name = "tdx-guest" name = "tdx-guest"
version = "0.1.0" version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "675bd99b7b81320678a9e9e524041b1f57fa68c965524413740d6cbe83d687f6" checksum = "d08fda76b8a438b7d926a92217a709ba39ef9031b8544a9a3f3af08d1b3f87e9"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"iced-x86",
"lazy_static", "lazy_static",
"rand_core",
"raw-cpuid", "raw-cpuid",
"x86_64", "x86_64",
] ]

View File

@ -122,6 +122,20 @@ pub fn read_random() -> Option<u64> {
None None
} }
fn has_avx512() -> bool {
use core::arch::x86_64::{__cpuid, __cpuid_count};
let cpuid_result = unsafe { __cpuid(0) };
if cpuid_result.eax < 7 {
// CPUID function 7 is not supported
return false;
}
let cpuid_result = unsafe { __cpuid_count(7, 0) };
// Check for AVX-512 Foundation (bit 16 of ebx)
cpuid_result.ebx & (1 << 16) != 0
}
fn enable_common_cpu_features() { fn enable_common_cpu_features() {
use x86_64::registers::{control::Cr4Flags, model_specific::EferFlags, xcontrol::XCr0Flags}; use x86_64::registers::{control::Cr4Flags, model_specific::EferFlags, xcontrol::XCr0Flags};
let mut cr4 = x86_64::registers::control::Cr4::read(); let mut cr4 = x86_64::registers::control::Cr4::read();
@ -135,10 +149,14 @@ fn enable_common_cpu_features() {
} }
let mut xcr0 = x86_64::registers::xcontrol::XCr0::read(); let mut xcr0 = x86_64::registers::xcontrol::XCr0::read();
xcr0 |= XCr0Flags::AVX | XCr0Flags::SSE;
if has_avx512() {
// TODO: Ensure proper saving and restoring of floating-point states // TODO: Ensure proper saving and restoring of floating-point states
// to correctly support advanced instructions like AVX-512. // to correctly support advanced instructions like AVX-512.
let avx512 = XCr0Flags::OPMASK | XCr0Flags::ZMM_HI256 | XCr0Flags::HI16_ZMM; xcr0 |= XCr0Flags::OPMASK | XCr0Flags::ZMM_HI256 | XCr0Flags::HI16_ZMM;
xcr0 |= XCr0Flags::AVX | XCr0Flags::SSE | avx512; }
unsafe { unsafe {
x86_64::registers::xcontrol::XCr0::write(xcr0); x86_64::registers::xcontrol::XCr0::write(xcr0);
} }