diff --git a/kernel/src/sched/sched_class/fair.rs b/kernel/src/sched/sched_class/fair.rs
index 2e5440e99..745574e57 100644
--- a/kernel/src/sched/sched_class/fair.rs
+++ b/kernel/src/sched/sched_class/fair.rs
@@ -3,7 +3,7 @@
 use alloc::{collections::BinaryHeap, sync::Arc};
 use core::{
     cmp::{self, Reverse},
-    sync::atomic::{AtomicU64, Ordering::Relaxed},
+    sync::atomic::{AtomicU64, Ordering},
 };
 
 use ostd::{
@@ -24,6 +24,9 @@ use crate::{
 };
 
 const WEIGHT_0: u64 = 1024;
+
+const HAS_PENDING: u64 = 1 << (u64::BITS - 1);
+
 pub const fn nice_to_weight(nice: Nice) -> u64 {
     // Calculated by the formula below:
     //
@@ -53,6 +56,7 @@ pub const fn nice_to_weight(nice: Nice) -> u64 {
             WEIGHT_0 * numerator / denominator
         }
     };
+    assert!(ret[index] & HAS_PENDING == 0);
     index += 1;
     nice += 1;
@@ -93,9 +97,38 @@
 ///
 ///     period_delta > time_slice
 ///         || vruntime > rq_min_vruntime + normalized_time_slice
+///
+/// # The weight update process
+///
+/// The weight of a thread can be updated via the `sched_setattr` family of
+/// syscalls from any thread. This makes it difficult to re-evaluate the run
+/// queue's data immediately after the update, since the thread holds no
+/// backward reference to its run queue (this cannot be represented in safe Rust).
+///
+/// To handle this problem, we use a `pending_weight` field to store the new weight.
+/// When the thread is next scheduled within its run queue, we check whether the weight
+/// needs to be updated, since both the old and new weights are needed for re-evaluation.
+///
+/// To indicate whether the weight needs to be updated, we pack the `weight`
+/// field with a bit flag, `HAS_PENDING`. The overall mechanism resembles an
+/// optimized spin lock. When accessing the `weight` field:
+///
+/// - If the weight does not need to be updated (i.e. `weight & HAS_PENDING == 0`),
+///   we simply return the weight.
+/// - If the weight needs to be updated (i.e. `weight & HAS_PENDING != 0`), we
+///   try to store the new weight into the `weight` field with `HAS_PENDING`
+///   cleared via a `compare_exchange_weak` loop. The loop should not spin for
+///   long, since weight updates are relatively infrequent.
+/// - If the loop observes that another thread has already cleared
+///   `HAS_PENDING`, we return the weight directly.
+/// - After a successful update, we re-evaluate the data of the run queue.
+///
+/// This method makes access to the weight lock-free and ensures that only one
+/// load is needed most of the time.
 #[derive(Debug)]
 pub struct FairAttr {
     weight: AtomicU64,
+    pending_weight: AtomicU64,
     vruntime: AtomicU64,
 }
@@ -103,19 +136,53 @@ impl FairAttr {
     pub fn new(nice: Nice) -> Self {
         FairAttr {
             weight: nice_to_weight(nice).into(),
+            pending_weight: Default::default(),
             vruntime: Default::default(),
         }
     }
 
     pub fn update(&self, nice: Nice) {
-        self.weight.store(nice_to_weight(nice), Relaxed);
+        self.pending_weight
+            .store(nice_to_weight(nice), Ordering::Relaxed);
+        self.weight.fetch_or(HAS_PENDING, Ordering::Release);
     }
 
-    fn update_vruntime(&self, delta: u64) -> (u64, u64) {
-        let weight = self.weight.load(Relaxed);
+    fn update_vruntime(&self, delta: u64, weight: u64) -> u64 {
         let delta = delta * WEIGHT_0 / weight;
-        let vruntime = self.vruntime.fetch_add(delta, Relaxed) + delta;
-        (vruntime, weight)
+        self.vruntime.fetch_add(delta, Ordering::Relaxed) + delta
+    }
+
+    fn fetch_weight(&self) -> (u64, u64) {
+        let mut weight = self.weight.load(Ordering::Acquire);
+        if weight & HAS_PENDING == 0 {
+            return (weight, weight);
+        }
+
+        let mut new_weight = self.pending_weight.load(Ordering::Relaxed);
+        loop {
+            match self.weight.compare_exchange_weak(
+                weight,
+                new_weight,
+                Ordering::AcqRel,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => break,
+                Err(failure) => {
+                    if failure & HAS_PENDING == 0 {
+                        return (failure, failure);
+                    }
+                    weight = failure;
+                    new_weight = self.pending_weight.load(Ordering::Relaxed);
+                }
+            }
+        }
+        let old_weight = weight & !HAS_PENDING;
+
+        let vruntime = self.vruntime.load(Ordering::Relaxed);
+        self.vruntime
+            .store(vruntime * old_weight / new_weight, Ordering::Relaxed);
+
+        (old_weight, new_weight)
+    }
 }
@@ -226,12 +293,14 @@ impl SchedClassRq for FairClassRq {
             Some(EnqueueFlags::Spawn) => self.min_vruntime + self.vtime_slice(),
             _ => self.min_vruntime,
         };
+        let (_old_weight, weight) = fair_attr.fetch_weight();
+
         let vruntime = fair_attr
             .vruntime
-            .fetch_max(vruntime, Relaxed)
+            .fetch_max(vruntime, Ordering::Relaxed)
             .max(vruntime);
 
-        self.total_weight += fair_attr.weight.load(Relaxed);
+        self.total_weight += weight;
         self.entities.push(Reverse(FairQueueItem(entity, vruntime)));
     }
@@ -247,7 +316,12 @@
         let Reverse(FairQueueItem(entity, _)) = self.entities.pop()?;
         let sched_attr = entity.as_thread().unwrap().sched_attr();
 
-        self.total_weight -= sched_attr.fair.weight.load(Relaxed);
+        let (old_weight, _weight) = sched_attr.fair.fetch_weight();
+        // Equivalent to:
+        //
+        //     self.total_weight = self.total_weight + weight - old_weight;
+        //     self.total_weight -= weight;
+        self.total_weight -= old_weight;
 
         Some(entity)
     }
@@ -261,7 +335,11 @@
         match flags {
             UpdateFlags::Yield => true,
             UpdateFlags::Tick | UpdateFlags::Wait => {
-                let (vruntime, weight) = attr.fair.update_vruntime(rt.delta);
+                let (old_weight, weight) = attr.fair.fetch_weight();
+                if old_weight != weight {
+                    self.total_weight = self.total_weight + weight - old_weight;
+                }
+                let vruntime = attr.fair.update_vruntime(rt.delta, weight);
                 self.min_vruntime = match self.entities.peek() {
                     Some(Reverse(leftmost)) => vruntime.min(leftmost.key()),
                     None => vruntime,
diff --git a/kernel/src/sched/sched_class/mod.rs b/kernel/src/sched/sched_class/mod.rs
index 6f933923b..7a0b31d38 100644
--- a/kernel/src/sched/sched_class/mod.rs
+++ b/kernel/src/sched/sched_class/mod.rs
@@ -179,7 +179,17 @@ impl SchedAttr {
     }
 
     pub fn update_policy<T>(&self, f: impl FnOnce(&mut SchedPolicy) -> T) -> T {
-        self.policy.update(f)
+        self.policy.update(|policy| {
+            let ret = f(policy);
+            match *policy {
+                SchedPolicy::RealTime { rt_prio, rt_policy } => {
+                    self.real_time.update(rt_prio.get(), rt_policy);
+                }
+                SchedPolicy::Fair(nice) => self.fair.update(nice),
+                _ => {}
+            }
+            ret
+        })
     }
 
     fn last_cpu(&self) -> Option<CpuId> {
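
Reviewer's note: the pending-weight scheme in the new doc comment can be exercised in isolation. Below is a minimal, self-contained sketch of the same bit-packing idea, not the patch itself: it uses userspace `std` atomics instead of the kernel's `core` ones, the `Weight` type and the `main` harness are hypothetical, and the vruntime rescaling performed by the real `fetch_weight` is omitted.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// The top bit of the weight doubles as the "update pending" flag. This is
/// sound because real weights (see `nice_to_weight`) never reach 2^63.
const HAS_PENDING: u64 = 1 << (u64::BITS - 1);

struct Weight {
    /// The current weight, possibly tagged with `HAS_PENDING`.
    weight: AtomicU64,
    /// The requested new weight, published by `update`, consumed by `fetch`.
    pending: AtomicU64,
}

impl Weight {
    fn new(initial: u64) -> Self {
        Self {
            weight: AtomicU64::new(initial),
            pending: AtomicU64::new(0),
        }
    }

    /// Requests a weight change from any thread. The `Release` store pairs
    /// with the `Acquire` load in `fetch`, so once the flag is observed the
    /// pending value is visible as well.
    fn update(&self, new: u64) {
        self.pending.store(new, Ordering::Relaxed);
        self.weight.fetch_or(HAS_PENDING, Ordering::Release);
    }

    /// Returns `(old_weight, new_weight)`; both are equal when no update was
    /// pending. The common case costs a single load.
    fn fetch(&self) -> (u64, u64) {
        let mut cur = self.weight.load(Ordering::Acquire);
        if cur & HAS_PENDING == 0 {
            return (cur, cur); // Fast path: nothing pending.
        }
        let mut new = self.pending.load(Ordering::Relaxed);
        loop {
            match self.weight.compare_exchange_weak(
                cur,
                new,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                // We installed the new weight; report the transition so the
                // caller can re-evaluate its run-queue data.
                Ok(_) => return (cur & !HAS_PENDING, new),
                // Another thread already completed the update.
                Err(seen) if seen & HAS_PENDING == 0 => return (seen, seen),
                // Lost a race with a concurrent `update` (or hit a spurious
                // CAS failure); reload the pending value and retry.
                Err(seen) => {
                    cur = seen;
                    new = self.pending.load(Ordering::Relaxed);
                }
            }
        }
    }
}

fn main() {
    let w = Weight::new(1024); // weight for nice 0
    assert_eq!(w.fetch(), (1024, 1024)); // fast path, one load
    w.update(820); // e.g. weight for nice 1
    assert_eq!(w.fetch(), (1024, 820)); // transition observed exactly once
    assert_eq!(w.fetch(), (820, 820)); // flag cleared; fast path again
}
```

The top-bit trick only works because weights never set bit 63, which is exactly what the new `assert!(ret[index] & HAS_PENDING == 0)` in the const weight-table construction enforces at compile time.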