1
1
Fork 0
forked from suyu/suyu

kernel: be more careful about kernel address keys

This commit is contained in:
Liam 2023-02-23 20:32:03 -05:00
parent c4ba088a5d
commit 97f7f7bad5
5 changed files with 23 additions and 11 deletions

View file

@ -113,7 +113,7 @@ Result KConditionVariable::SignalToAddress(VAddr addr) {
// Remove waiter thread. // Remove waiter thread.
bool has_waiters{}; bool has_waiters{};
KThread* const next_owner_thread = KThread* const next_owner_thread =
owner_thread->RemoveWaiterByKey(std::addressof(has_waiters), addr); owner_thread->RemoveUserWaiterByKey(std::addressof(has_waiters), addr);
// Determine the next tag. // Determine the next tag.
u32 next_value{}; u32 next_value{};
@ -283,7 +283,7 @@ Result KConditionVariable::Wait(VAddr addr, u64 key, u32 value, s64 timeout) {
// Remove waiter thread. // Remove waiter thread.
bool has_waiters{}; bool has_waiters{};
KThread* next_owner_thread = KThread* next_owner_thread =
cur_thread->RemoveWaiterByKey(std::addressof(has_waiters), addr); cur_thread->RemoveUserWaiterByKey(std::addressof(has_waiters), addr);
// Update for the next owner thread. // Update for the next owner thread.
u32 next_value{}; u32 next_value{};

View file

@ -91,7 +91,7 @@ void KLightLock::UnlockSlowPath(uintptr_t _cur_thread) {
// Get the next owner. // Get the next owner.
bool has_waiters; bool has_waiters;
KThread* next_owner = owner_thread->RemoveWaiterByKey( KThread* next_owner = owner_thread->RemoveKernelWaiterByKey(
std::addressof(has_waiters), reinterpret_cast<uintptr_t>(std::addressof(tag))); std::addressof(has_waiters), reinterpret_cast<uintptr_t>(std::addressof(tag)));
// Pass the lock to the next owner. // Pass the lock to the next owner.

View file

@ -157,7 +157,7 @@ bool KProcess::ReleaseUserException(KThread* thread) {
// Remove waiter thread. // Remove waiter thread.
bool has_waiters{}; bool has_waiters{};
if (KThread* next = thread->RemoveWaiterByKey( if (KThread* next = thread->RemoveKernelWaiterByKey(
std::addressof(has_waiters), std::addressof(has_waiters),
reinterpret_cast<uintptr_t>(std::addressof(exception_thread))); reinterpret_cast<uintptr_t>(std::addressof(exception_thread)));
next != nullptr) { next != nullptr) {

View file

@ -933,12 +933,14 @@ void KThread::AddHeldLock(LockWithPriorityInheritanceInfo* lock_info) {
held_lock_info_list.push_front(*lock_info); held_lock_info_list.push_front(*lock_info);
} }
KThread::LockWithPriorityInheritanceInfo* KThread::FindHeldLock(VAddr address_key_) { KThread::LockWithPriorityInheritanceInfo* KThread::FindHeldLock(VAddr address_key_,
bool is_kernel_address_key_) {
ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(kernel)); ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(kernel));
// Try to find an existing held lock. // Try to find an existing held lock.
for (auto& held_lock : held_lock_info_list) { for (auto& held_lock : held_lock_info_list) {
if (held_lock.GetAddressKey() == address_key_) { if (held_lock.GetAddressKey() == address_key_ &&
held_lock.GetIsKernelAddressKey() == is_kernel_address_key_) {
return std::addressof(held_lock); return std::addressof(held_lock);
} }
} }
@ -961,7 +963,7 @@ void KThread::AddWaiterImpl(KThread* thread) {
} }
// Get the relevant lock info. // Get the relevant lock info.
auto* lock_info = this->FindHeldLock(address_key_); auto* lock_info = this->FindHeldLock(address_key_, is_kernel_address_key_);
if (lock_info == nullptr) { if (lock_info == nullptr) {
// Create a new lock for the address key. // Create a new lock for the address key.
lock_info = lock_info =
@ -1067,11 +1069,11 @@ void KThread::RemoveWaiter(KThread* thread) {
} }
} }
KThread* KThread::RemoveWaiterByKey(bool* out_has_waiters, VAddr key) { KThread* KThread::RemoveWaiterByKey(bool* out_has_waiters, VAddr key, bool is_kernel_address_key_) {
ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(kernel)); ASSERT(KScheduler::IsSchedulerLockedByCurrentThread(kernel));
// Get the relevant lock info. // Get the relevant lock info.
auto* lock_info = this->FindHeldLock(key); auto* lock_info = this->FindHeldLock(key, is_kernel_address_key_);
if (lock_info == nullptr) { if (lock_info == nullptr) {
*out_has_waiters = false; *out_has_waiters = false;
return nullptr; return nullptr;

View file

@ -595,7 +595,13 @@ public:
[[nodiscard]] Result GetThreadContext3(std::vector<u8>& out); [[nodiscard]] Result GetThreadContext3(std::vector<u8>& out);
[[nodiscard]] KThread* RemoveWaiterByKey(bool* out_has_waiters, VAddr key); [[nodiscard]] KThread* RemoveUserWaiterByKey(bool* out_has_waiters, VAddr key) {
return this->RemoveWaiterByKey(out_has_waiters, key, false);
}
[[nodiscard]] KThread* RemoveKernelWaiterByKey(bool* out_has_waiters, VAddr key) {
return this->RemoveWaiterByKey(out_has_waiters, key, true);
}
[[nodiscard]] VAddr GetAddressKey() const { [[nodiscard]] VAddr GetAddressKey() const {
return address_key; return address_key;
@ -666,6 +672,9 @@ public:
} }
private: private:
[[nodiscard]] KThread* RemoveWaiterByKey(bool* out_has_waiters, VAddr key,
bool is_kernel_address_key);
static constexpr size_t PriorityInheritanceCountMax = 10; static constexpr size_t PriorityInheritanceCountMax = 10;
union SyncObjectBuffer { union SyncObjectBuffer {
std::array<KSynchronizationObject*, Svc::ArgumentHandleCountMax> sync_objects{}; std::array<KSynchronizationObject*, Svc::ArgumentHandleCountMax> sync_objects{};
@ -850,7 +859,7 @@ public:
} }
void AddHeldLock(LockWithPriorityInheritanceInfo* lock_info); void AddHeldLock(LockWithPriorityInheritanceInfo* lock_info);
LockWithPriorityInheritanceInfo* FindHeldLock(VAddr address_key); LockWithPriorityInheritanceInfo* FindHeldLock(VAddr address_key, bool is_kernel_address_key);
private: private:
using LockWithPriorityInheritanceInfoList = using LockWithPriorityInheritanceInfoList =
@ -926,6 +935,7 @@ public:
condvar_key = cv_key; condvar_key = cv_key;
address_key = address; address_key = address;
address_key_value = value; address_key_value = value;
is_kernel_address_key = false;
} }
void ClearConditionVariable() { void ClearConditionVariable() {