diff --git a/kernel/comps/pci/src/common_device.rs b/kernel/comps/pci/src/common_device.rs index 3aad94cbf..13f6a2104 100644 --- a/kernel/comps/pci/src/common_device.rs +++ b/kernel/comps/pci/src/common_device.rs @@ -52,6 +52,11 @@ impl PciCommonDevice { self.header_type.device_type() } + /// Checks whether the device is a multi-function device + pub fn has_multi_funcs(&self) -> bool { + self.header_type.has_multi_funcs() + } + /// Gets the PCI Command pub fn command(&self) -> Command { Command::from_bits_truncate(self.location.read16(PciCommonCfgOffset::Command as u16)) diff --git a/kernel/comps/pci/src/device_info.rs b/kernel/comps/pci/src/device_info.rs index 41ea52e03..76e7faef2 100644 --- a/kernel/comps/pci/src/device_info.rs +++ b/kernel/comps/pci/src/device_info.rs @@ -19,33 +19,17 @@ impl PciDeviceLocation { // TODO: Find a proper way to obtain the bus range. For example, if the PCI bus is identified // from a device tree, this information can be obtained from the `bus-range` field (e.g., // `bus-range = <0x00 0x7f>`). - const MIN_BUS: u8 = 0; + pub const MIN_BUS: u8 = 0; #[cfg(not(target_arch = "loongarch64"))] - const MAX_BUS: u8 = 255; + pub const MAX_BUS: u8 = 255; #[cfg(target_arch = "loongarch64")] - const MAX_BUS: u8 = 127; + pub const MAX_BUS: u8 = 127; - const MIN_DEVICE: u8 = 0; - const MAX_DEVICE: u8 = 31; + pub const MIN_DEVICE: u8 = 0; + pub const MAX_DEVICE: u8 = 31; - const MIN_FUNCTION: u8 = 0; - const MAX_FUNCTION: u8 = 7; - - /// Returns an iterator that enumerates all possible PCI device locations. - pub fn all() -> impl Iterator { - let all_bus = Self::MIN_BUS..=Self::MAX_BUS; - let all_dev = Self::MIN_DEVICE..=Self::MAX_DEVICE; - let all_func = Self::MIN_FUNCTION..=Self::MAX_FUNCTION; - - all_bus - .flat_map(move |bus| all_dev.clone().map(move |dev| (bus, dev))) - .flat_map(move |(bus, dev)| all_func.clone().map(move |func| (bus, dev, func))) - .map(|(bus, dev, func)| PciDeviceLocation { - bus, - device: dev, - function: func, - }) - } + pub const MIN_FUNCTION: u8 = 0; + pub const MAX_FUNCTION: u8 = 7; } impl PciDeviceLocation { diff --git a/kernel/comps/pci/src/lib.rs b/kernel/comps/pci/src/lib.rs index 764b81944..7736f9b71 100644 --- a/kernel/comps/pci/src/lib.rs +++ b/kernel/comps/pci/src/lib.rs @@ -96,10 +96,34 @@ fn init() { } let mut lock = PCI_BUS.lock(); - for location in PciDeviceLocation::all() { - let Some(device) = PciCommonDevice::new(location) else { - continue; - }; - lock.register_common_device(device); + + let all_bus = PciDeviceLocation::MIN_BUS..=PciDeviceLocation::MAX_BUS; + let all_dev = PciDeviceLocation::MIN_DEVICE..=PciDeviceLocation::MAX_DEVICE; + let all_func = PciDeviceLocation::MIN_FUNCTION..=PciDeviceLocation::MAX_FUNCTION; + + for bus in all_bus { + for device in all_dev.clone() { + let mut device_location = PciDeviceLocation { + bus, + device, + function: PciDeviceLocation::MIN_FUNCTION, + }; + + let Some(first_function_device) = PciCommonDevice::new(device_location) else { + continue; + }; + let has_multi_function = first_function_device.has_multi_funcs(); + // Register function 0 in advance + lock.register_common_device(first_function_device); + + if has_multi_function { + for function in all_func.clone().skip(1) { + device_location.function = function; + if let Some(common_device) = PciCommonDevice::new(device_location) { + lock.register_common_device(common_device); + } + } + } + } } }