diff options
Diffstat (limited to 'src/tests/common')
| -rw-r--r-- | src/tests/common/fibers.cpp | 71 | 
1 files changed, 40 insertions, 31 deletions
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp index 4fd92428f..4757dd2b4 100644 --- a/src/tests/common/fibers.cpp +++ b/src/tests/common/fibers.cpp @@ -6,18 +6,40 @@  #include <cstdlib>  #include <functional>  #include <memory> +#include <mutex> +#include <stdexcept>  #include <thread>  #include <unordered_map>  #include <vector>  #include <catch2/catch.hpp> -#include <math.h> +  #include "common/common_types.h"  #include "common/fiber.h" -#include "common/spin_lock.h"  namespace Common { +class ThreadIds { +public: +    void Register(u32 id) { +        const auto thread_id = std::this_thread::get_id(); +        std::scoped_lock lock{mutex}; +        if (ids.contains(thread_id)) { +            throw std::logic_error{"Registering the same thread twice"}; +        } +        ids.emplace(thread_id, id); +    } + +    [[nodiscard]] u32 Get() const { +        std::scoped_lock lock{mutex}; +        return ids.at(std::this_thread::get_id()); +    } + +private: +    mutable std::mutex mutex; +    std::unordered_map<std::thread::id, u32> ids; +}; +  class TestControl1 {  public:      TestControl1() = default; @@ -26,7 +48,7 @@ public:      void ExecuteThread(u32 id); -    std::unordered_map<std::thread::id, u32> ids; +    ThreadIds thread_ids;      std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;      std::vector<std::shared_ptr<Common::Fiber>> work_fibers;      std::vector<u32> items; @@ -39,8 +61,7 @@ static void WorkControl1(void* control) {  }  void TestControl1::DoWork() { -    std::thread::id this_id = std::this_thread::get_id(); -    u32 id = ids[this_id]; +    const u32 id = thread_ids.Get();      u32 value = items[id];      for (u32 i = 0; i < id; i++) {          value++; @@ -50,8 +71,7 @@ void TestControl1::DoWork() {  }  void TestControl1::ExecuteThread(u32 id) { -    std::thread::id this_id = std::this_thread::get_id(); -    ids[this_id] = id; +    thread_ids.Register(id);      auto thread_fiber = Fiber::ThreadToFiber();      thread_fibers[id] = thread_fiber;      work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this); @@ -98,8 +118,7 @@ public:              value1 += i;          }          Fiber::YieldTo(fiber1, fiber3); -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          assert1 = id == 1;          value2 += 5000;          Fiber::YieldTo(fiber1, thread_fibers[id]); @@ -115,8 +134,7 @@ public:      }      void DoWork3() { -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          assert2 = id == 0;          value1 += 1000;          Fiber::YieldTo(fiber3, thread_fibers[id]); @@ -125,14 +143,12 @@ public:      void ExecuteThread(u32 id);      void CallFiber1() { -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          Fiber::YieldTo(thread_fibers[id], fiber1);      }      void CallFiber2() { -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          Fiber::YieldTo(thread_fibers[id], fiber2);      } @@ -145,7 +161,7 @@ public:      u32 value2{};      std::atomic<bool> trap{true};      std::atomic<bool> trap2{true}; -    std::unordered_map<std::thread::id, u32> ids; +    ThreadIds thread_ids;      std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;      std::shared_ptr<Common::Fiber> fiber1;      std::shared_ptr<Common::Fiber> fiber2; @@ -168,15 +184,13 @@ static void WorkControl2_3(void* control) {  }  void TestControl2::ExecuteThread(u32 id) { -    std::thread::id this_id = std::this_thread::get_id(); -    ids[this_id] = id; +    thread_ids.Register(id);      auto thread_fiber = Fiber::ThreadToFiber();      thread_fibers[id] = thread_fiber;  }  void TestControl2::Exit() { -    std::thread::id this_id = std::this_thread::get_id(); -    u32 id = ids[this_id]; +    const u32 id = thread_ids.Get();      thread_fibers[id]->Exit();  } @@ -228,24 +242,21 @@ public:      void DoWork1() {          value1 += 1;          Fiber::YieldTo(fiber1, fiber2); -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          value3 += 1;          Fiber::YieldTo(fiber1, thread_fibers[id]);      }      void DoWork2() {          value2 += 1; -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          Fiber::YieldTo(fiber2, thread_fibers[id]);      }      void ExecuteThread(u32 id);      void CallFiber1() { -        std::thread::id this_id = std::this_thread::get_id(); -        u32 id = ids[this_id]; +        const u32 id = thread_ids.Get();          Fiber::YieldTo(thread_fibers[id], fiber1);      } @@ -254,7 +265,7 @@ public:      u32 value1{};      u32 value2{};      u32 value3{}; -    std::unordered_map<std::thread::id, u32> ids; +    ThreadIds thread_ids;      std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;      std::shared_ptr<Common::Fiber> fiber1;      std::shared_ptr<Common::Fiber> fiber2; @@ -271,15 +282,13 @@ static void WorkControl3_2(void* control) {  }  void TestControl3::ExecuteThread(u32 id) { -    std::thread::id this_id = std::this_thread::get_id(); -    ids[this_id] = id; +    thread_ids.Register(id);      auto thread_fiber = Fiber::ThreadToFiber();      thread_fibers[id] = thread_fiber;  }  void TestControl3::Exit() { -    std::thread::id this_id = std::this_thread::get_id(); -    u32 id = ids[this_id]; +    const u32 id = thread_ids.Get();      thread_fibers[id]->Exit();  }  | 
