1
1
Fork 0
forked from suyu/suyu

Tests: Add tests for fibers and refactor/fix Fiber class

This commit is contained in:
Fernando Sahmkow 2020-02-05 14:13:16 -04:00
parent bc266a9d98
commit 8d0e3c5422
4 changed files with 247 additions and 19 deletions

View file

@ -3,18 +3,21 @@
// Refer to the license.txt file included. // Refer to the license.txt file included.
#include "common/fiber.h" #include "common/fiber.h"
#ifdef _MSC_VER
#include <windows.h>
#else
#include <boost/context/detail/fcontext.hpp>
#endif
namespace Common { namespace Common {
#ifdef _MSC_VER #ifdef _MSC_VER
#include <windows.h>
struct Fiber::FiberImpl { struct Fiber::FiberImpl {
LPVOID handle = nullptr; LPVOID handle = nullptr;
}; };
void Fiber::_start([[maybe_unused]] void* parameter) { void Fiber::start() {
guard.lock();
if (previous_fiber) { if (previous_fiber) {
previous_fiber->guard.unlock(); previous_fiber->guard.unlock();
previous_fiber = nullptr; previous_fiber = nullptr;
@ -22,10 +25,10 @@ void Fiber::_start([[maybe_unused]] void* parameter) {
entry_point(start_parameter); entry_point(start_parameter);
} }
static void __stdcall FiberStartFunc(LPVOID lpFiberParameter) void __stdcall Fiber::FiberStartFunc(void* fiber_parameter)
{ {
auto fiber = static_cast<Fiber *>(lpFiberParameter); auto fiber = static_cast<Fiber *>(fiber_parameter);
fiber->_start(nullptr); fiber->start();
} }
Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter)
@ -74,30 +77,26 @@ std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
#else #else
#include <boost/context/detail/fcontext.hpp>
constexpr std::size_t default_stack_size = 1024 * 1024 * 4; // 4MB constexpr std::size_t default_stack_size = 1024 * 1024 * 4; // 4MB
struct Fiber::FiberImpl { struct alignas(64) Fiber::FiberImpl {
boost::context::detail::fcontext_t context;
std::array<u8, default_stack_size> stack; std::array<u8, default_stack_size> stack;
boost::context::detail::fcontext_t context;
}; };
void Fiber::_start(void* parameter) { void Fiber::start(boost::context::detail::transfer_t& transfer) {
guard.lock();
boost::context::detail::transfer_t* transfer = static_cast<boost::context::detail::transfer_t*>(parameter);
if (previous_fiber) { if (previous_fiber) {
previous_fiber->impl->context = transfer->fctx; previous_fiber->impl->context = transfer.fctx;
previous_fiber->guard.unlock(); previous_fiber->guard.unlock();
previous_fiber = nullptr; previous_fiber = nullptr;
} }
entry_point(start_parameter); entry_point(start_parameter);
} }
static void FiberStartFunc(boost::context::detail::transfer_t transfer) void Fiber::FiberStartFunc(boost::context::detail::transfer_t transfer)
{ {
auto fiber = static_cast<Fiber *>(transfer.data); auto fiber = static_cast<Fiber *>(transfer.data);
fiber->_start(&transfer); fiber->start(transfer);
} }
Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter)
@ -139,6 +138,7 @@ void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) {
std::shared_ptr<Fiber> Fiber::ThreadToFiber() { std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
std::shared_ptr<Fiber> fiber = std::shared_ptr<Fiber>{new Fiber()}; std::shared_ptr<Fiber> fiber = std::shared_ptr<Fiber>{new Fiber()};
fiber->guard.lock();
fiber->is_thread_fiber = true; fiber->is_thread_fiber = true;
return fiber; return fiber;
} }

View file

@ -10,6 +10,12 @@
#include "common/common_types.h" #include "common/common_types.h"
#include "common/spin_lock.h" #include "common/spin_lock.h"
#ifndef _MSC_VER
namespace boost::context::detail {
struct transfer_t;
}
#endif
namespace Common { namespace Common {
class Fiber { class Fiber {
@ -31,9 +37,6 @@ public:
/// Only call from main thread's fiber /// Only call from main thread's fiber
void Exit(); void Exit();
/// Used internally but required to be public, Shall not be used
void _start(void* parameter);
/// Changes the start parameter of the fiber. Has no effect if the fiber already started /// Changes the start parameter of the fiber. Has no effect if the fiber already started
void SetStartParameter(void* new_parameter) { void SetStartParameter(void* new_parameter) {
start_parameter = new_parameter; start_parameter = new_parameter;
@ -42,6 +45,16 @@ public:
private: private:
Fiber(); Fiber();
#ifdef _MSC_VER
void start();
static void FiberStartFunc(void* fiber_parameter);
#else
void start(boost::context::detail::transfer_t& transfer);
static void FiberStartFunc(boost::context::detail::transfer_t transfer);
#endif
struct FiberImpl; struct FiberImpl;
SpinLock guard; SpinLock guard;

View file

@ -1,6 +1,7 @@
add_executable(tests add_executable(tests
common/bit_field.cpp common/bit_field.cpp
common/bit_utils.cpp common/bit_utils.cpp
common/fibers.cpp
common/multi_level_queue.cpp common/multi_level_queue.cpp
common/param_package.cpp common/param_package.cpp
common/ring_buffer.cpp common/ring_buffer.cpp

214
src/tests/common/fibers.cpp Normal file
View file

@ -0,0 +1,214 @@
// Copyright 2020 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <atomic>
#include <cstdlib>
#include <functional>
#include <memory>
#include <thread>
#include <unordered_map>
#include <vector>
#include <catch2/catch.hpp>
#include <math.h>
#include "common/common_types.h"
#include "common/fiber.h"
#include "common/spin_lock.h"
namespace Common {
class TestControl1 {
public:
TestControl1() = default;
void DoWork();
void ExecuteThread(u32 id);
std::unordered_map<std::thread::id, u32> ids;
std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
std::vector<u32> items;
std::vector<u32> results;
};
static void WorkControl1(void* control) {
TestControl1* test_control = static_cast<TestControl1*>(control);
test_control->DoWork();
}
void TestControl1::DoWork() {
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
u32 value = items[id];
for (u32 i = 0; i < id; i++) {
value++;
}
results[id] = value;
Fiber::YieldTo(work_fibers[id], thread_fibers[id]);
}
void TestControl1::ExecuteThread(u32 id) {
std::thread::id this_id = std::this_thread::get_id();
ids[this_id] = id;
auto thread_fiber = Fiber::ThreadToFiber();
thread_fibers[id] = thread_fiber;
work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
items[id] = rand() % 256;
Fiber::YieldTo(thread_fibers[id], work_fibers[id]);
thread_fibers[id]->Exit();
}
static void ThreadStart1(u32 id, TestControl1& test_control) {
test_control.ExecuteThread(id);
}
TEST_CASE("Fibers::Setup", "[common]") {
constexpr u32 num_threads = 7;
TestControl1 test_control{};
test_control.thread_fibers.resize(num_threads, nullptr);
test_control.work_fibers.resize(num_threads, nullptr);
test_control.items.resize(num_threads, 0);
test_control.results.resize(num_threads, 0);
std::vector<std::thread> threads;
for (u32 i = 0; i < num_threads; i++) {
threads.emplace_back(ThreadStart1, i, std::ref(test_control));
}
for (u32 i = 0; i < num_threads; i++) {
threads[i].join();
}
for (u32 i = 0; i < num_threads; i++) {
REQUIRE(test_control.items[i] + i == test_control.results[i]);
}
}
class TestControl2 {
public:
TestControl2() = default;
void DoWork1() {
trap2 = false;
while (trap.load());
for (u32 i = 0; i < 12000; i++) {
value1 += i;
}
Fiber::YieldTo(fiber1, fiber3);
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
assert1 = id == 1;
value2 += 5000;
Fiber::YieldTo(fiber1, thread_fibers[id]);
}
void DoWork2() {
while (trap2.load());
value2 = 2000;
trap = false;
Fiber::YieldTo(fiber2, fiber1);
assert3 = false;
}
void DoWork3() {
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
assert2 = id == 0;
value1 += 1000;
Fiber::YieldTo(fiber3, thread_fibers[id]);
}
void ExecuteThread(u32 id);
void CallFiber1() {
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
Fiber::YieldTo(thread_fibers[id], fiber1);
}
void CallFiber2() {
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
Fiber::YieldTo(thread_fibers[id], fiber2);
}
void Exit();
bool assert1{};
bool assert2{};
bool assert3{true};
u32 value1{};
u32 value2{};
std::atomic<bool> trap{true};
std::atomic<bool> trap2{true};
std::unordered_map<std::thread::id, u32> ids;
std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
std::shared_ptr<Common::Fiber> fiber1;
std::shared_ptr<Common::Fiber> fiber2;
std::shared_ptr<Common::Fiber> fiber3;
};
static void WorkControl2_1(void* control) {
TestControl2* test_control = static_cast<TestControl2*>(control);
test_control->DoWork1();
}
static void WorkControl2_2(void* control) {
TestControl2* test_control = static_cast<TestControl2*>(control);
test_control->DoWork2();
}
static void WorkControl2_3(void* control) {
TestControl2* test_control = static_cast<TestControl2*>(control);
test_control->DoWork3();
}
void TestControl2::ExecuteThread(u32 id) {
std::thread::id this_id = std::this_thread::get_id();
ids[this_id] = id;
auto thread_fiber = Fiber::ThreadToFiber();
thread_fibers[id] = thread_fiber;
}
void TestControl2::Exit() {
std::thread::id this_id = std::this_thread::get_id();
u32 id = ids[this_id];
thread_fibers[id]->Exit();
}
static void ThreadStart2_1(u32 id, TestControl2& test_control) {
test_control.ExecuteThread(id);
test_control.CallFiber1();
test_control.Exit();
}
static void ThreadStart2_2(u32 id, TestControl2& test_control) {
test_control.ExecuteThread(id);
test_control.CallFiber2();
test_control.Exit();
}
TEST_CASE("Fibers::InterExchange", "[common]") {
TestControl2 test_control{};
test_control.thread_fibers.resize(2, nullptr);
test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control);
test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control);
test_control.fiber3 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control);
std::thread thread1(ThreadStart2_1, 0, std::ref(test_control));
std::thread thread2(ThreadStart2_2, 1, std::ref(test_control));
thread1.join();
thread2.join();
REQUIRE(test_control.assert1);
REQUIRE(test_control.assert2);
REQUIRE(test_control.assert3);
REQUIRE(test_control.value2 == 7000);
u32 cal_value = 0;
for (u32 i = 0; i < 12000; i++) {
cal_value += i;
}
cal_value += 1000;
REQUIRE(test_control.value1 == cal_value);
}
} // namespace Common