1 // =================================
  2 // Copyright (c) 2021 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <cmajor/rt/Thread.hpp>
  7 #include <cmajor/rt/Io.hpp>
  8 #include <cmajor/rt/CallStack.hpp>
  9 #include <cmajor/rt/Error.hpp>
 10 #include <cmajor/rt/InitDone.hpp>
 11 #include <soulng/util/Error.hpp>
 12 #include <atomic>
 13 #include <mutex>
 14 #include <thread>
 15 #include <vector>
 16 #include <sstream>
 17 #include <unordered_map>
 18 
 19 namespace cmajor { namespace rt {
 20 
 21 typedef void(*ThreadFunction)();
 22 typedef void(*ThreadFunctionWithParam)(void*);
 23 typedef void(*ThreadMethod)(void*);
 24 typedef void(*ThreadMethodWithParam)(void*, void*);
 25 
 26 void ExecuteThreadFunction(ThreadFunction threadFunctionint32_t threadId)
 27 {
 28     try
 29     {
 30         threadFunction();
 31     }
 32     catch (...)
 33     {
 34         std::string str = "exception escaped from thread " + std::to_string(threadId) + "\n";
 35         int32_t errorStringHandle = -1;
 36         void* stdError = RtOpenStdFile(2errorStringHandle);
 37         RtWrite(stdErrorreinterpret_cast<const uint8_t*>(str.c_str())str.length()errorStringHandle);
 38         RtPrintCallStack(stdError);
 39         RtFlush(stdErrorerrorStringHandle);
 40         RtExit(exitCodeExceptionEscapedFromThread);
 41     }
 42 }
 43 
 44 void ExecuteThreadFunctionWithParam(ThreadFunctionWithParam threadFunctionvoid* paramint32_t threadId)
 45 {
 46     try
 47     {
 48         threadFunction(param);
 49     }
 50     catch (...)
 51     {
 52         std::string str = "exception escaped from thread " + std::to_string(threadId) + "\n";
 53         int32_t errorStringHandle = -1;
 54         void* stdError = RtOpenStdFile(2errorStringHandle);
 55         RtWrite(stdErrorreinterpret_cast<const uint8_t*>(str.c_str())str.length()errorStringHandle);
 56         RtPrintCallStack(stdError);
 57         RtFlush(stdErrorerrorStringHandle);
 58         RtExit(exitCodeExceptionEscapedFromThread);
 59     }
 60 }
 61 
 62 void ExecuteThreadMethod(ThreadMethod threadMethodvoid* objectint32_t threadId)
 63 {
 64     try
 65     {
 66         threadMethod(object);
 67     }
 68     catch (...)
 69     {
 70         std::string str = "exception escaped from thread " + std::to_string(threadId) + "\n";
 71         int32_t errorStringHandle = -1;
 72         void* stdError = RtOpenStdFile(2errorStringHandle);
 73         RtWrite(stdErrorreinterpret_cast<const uint8_t*>(str.c_str())str.length()errorStringHandle);
 74         RtPrintCallStack(stdError);
 75         RtFlush(stdErrorerrorStringHandle);
 76         RtExit(exitCodeExceptionEscapedFromThread);
 77     }
 78 }
 79 
 80 void ExecuteThreadMethodWithParam(ThreadMethodWithParam threadMethodvoid* objectvoid* paramint32_t threadId)
 81 {
 82     try
 83     {
 84         threadMethod(objectparam);
 85     }
 86     catch (...)
 87     {
 88         std::string str = "exception escaped from thread " + std::to_string(threadId) + "\n";
 89         int32_t errorStringHandle = -1;
 90         void* stdError = RtOpenStdFile(2errorStringHandle);
 91         RtWrite(stdErrorreinterpret_cast<const uint8_t*>(str.c_str())str.length()errorStringHandle);
 92         RtPrintCallStack(stdError);
 93         RtFlush(stdErrorerrorStringHandle);
 94         RtExit(exitCodeExceptionEscapedFromThread);
 95     }
 96 }
 97 
 98 class ThreadPool 
 99 {
100 public:
101     static void Init();
102     static void Done();
103     void Exit();
104     static ThreadPool& Instance() { Assert(instance"thread pool not initialized"); return *instance; }
105     int32_t StartThreadFunction(ThreadFunction fun);
106     int32_t StartThreadFunction(ThreadFunctionWithParam funvoid* param);
107     int32_t StartThreadMethod(ThreadMethod methodvoid* object);
108     int32_t StartThreadMethod(ThreadMethodWithParam methodvoid* objectvoid* param);
109     bool JoinThread(int32_t threadId);
110 private:
111     static std::unique_ptr<ThreadPool> instance;
112     const int32_t numNoLockThreads = 256;
113     std::atomic<int32_t> nextThreadId;
114     std::vector<std::std::unique_ptr<std::thread>>noLockThreads;
115     std::unordered_map<int32_tstd::std::unique_ptr<std::thread>>threadMap;
116     std::mutex mtx;
117     ThreadPool();
118 };
119 
120 void ThreadPool::Exit()
121 {
122     for (std::std::unique_ptr<std::thread>&t : noLockThreads)
123     {
124         if (t.get())
125         {
126             if (t->joinable())
127             {
128                 t->join();
129             }
130         }
131     }
132     for (auto& p : threadMap)
133     {
134         JoinThread(p.first);
135     }
136 }
137 
138 void ThreadPool::Init()
139 {
140     instance.reset(new ThreadPool());
141 }
142 
143 void ThreadPool::Done()
144 {
145     if (instance)
146     {
147         instance->Exit();
148     }
149     instance.reset();
150 }
151 
152 std::unique_ptr<ThreadPool> ThreadPool::instance;
153 
154 ThreadPool::ThreadPool() : nextThreadId(1)noLockThreads()
155 {
156     noLockThreads.resize(numNoLockThreads);
157 }
158 
159 int32_t ThreadPool::StartThreadFunction(ThreadFunction fun)
160 {
161     int32_t threadId = nextThreadId++;
162     if (threadId < numNoLockThreads)
163     {
164         noLockThreads[threadId].reset(new std::thread(ExecuteThreadFunctionfunthreadId));
165         return threadId;
166     }
167     else
168     {
169         std::lock_guard<std::mutex> lock(mtx);
170         threadMap[threadId].reset(new std::thread(ExecuteThreadFunctionfunthreadId));
171         return threadId;
172     }
173 }
174 
175 int32_t ThreadPool::StartThreadFunction(ThreadFunctionWithParam funvoid* param)
176 {
177     int32_t threadId = nextThreadId++;
178     if (threadId < numNoLockThreads)
179     {
180         noLockThreads[threadId].reset(new std::thread(ExecuteThreadFunctionWithParamfunparamthreadId));
181         return threadId;
182     }
183     else
184     {
185         std::lock_guard<std::mutex> lock(mtx);
186         threadMap[threadId].reset(new std::thread(ExecuteThreadFunctionWithParamfunparamthreadId));
187         return threadId;
188     }
189 }
190 
191 int32_t ThreadPool::StartThreadMethod(ThreadMethod methodvoid* object)
192 {
193     int32_t threadId = nextThreadId++;
194     if (threadId < numNoLockThreads)
195     {
196         noLockThreads[threadId].reset(new std::thread(ExecuteThreadMethodmethodobjectthreadId));
197         return threadId;
198     }
199     else
200     {
201         std::lock_guard<std::mutex> lock(mtx);
202         threadMap[threadId].reset(new std::thread(ExecuteThreadMethodmethodobjectthreadId));
203         return threadId;
204     }
205 }
206 
207 int32_t ThreadPool::StartThreadMethod(ThreadMethodWithParam methodvoid* objectvoid* param)
208 {
209     int32_t threadId = nextThreadId++;
210     if (threadId < numNoLockThreads)
211     {
212         noLockThreads[threadId].reset(new std::thread(ExecuteThreadMethodWithParammethodobjectparamthreadId));
213         return threadId;
214     }
215     else
216     {
217         std::lock_guard<std::mutex> lock(mtx);
218         threadMap[threadId].reset(new std::thread(ExecuteThreadMethodWithParammethodobjectparamthreadId));
219         return threadId;
220     }
221 }
222 
223 bool ThreadPool::JoinThread(int32_t threadId)
224 {
225     if (threadId < numNoLockThreads)
226     {
227         if (noLockThreads[threadId])
228         {
229             if (noLockThreads[threadId]->joinable())
230             {
231                 noLockThreads[threadId]->join();
232             }
233             noLockThreads[threadId].reset();
234             return true;
235         }
236     }
237     else
238     {
239         std::lock_guard<std::mutex> lock(mtx);
240         auto it = threadMap.find(threadId);
241         if (it != threadMap.cend())
242         {
243             std::thread* thread = it->second.get();
244             if (thread)
245             {
246                 if (thread->joinable())
247                 {
248                     thread->join();
249                 }
250                 threadMap.erase(threadId);
251                 return true;
252             }
253         }
254     }
255     return false;
256 }
257 
258 void InitThread()
259 {
260     ThreadPool::Init();
261 }
262 
263 void DoneThread()
264 {
265     ThreadPool::Done();
266 }
267 
268 } } // namespace cmajor::rt
269 
270 extern "C" int32_t RtGetHardwareConcurrency()
271 {
272     return std::thread::hardware_concurrency();
273 }
274 
275 extern "C" int32_t RtStartThreadFunction(void* function)
276 {
277     cmajor::rt::ThreadFunction threadFun = reinterpret_cast<cmajor::rt::ThreadFunction>(function);
278     return cmajor::rt::ThreadPool::Instance().StartThreadFunction(threadFun);
279 }
280 
281 extern "C" int32_t RtStartThreadFunctionWithParam(void* function, void* param)
282 {
283     cmajor::rt::ThreadFunctionWithParam threadFunWithParam = reinterpret_cast<cmajor::rt::ThreadFunctionWithParam>(function);
284     return cmajor::rt::ThreadPool::Instance().StartThreadFunction(threadFunWithParam, param);
285 }
286 
287 struct ClassDelegate 
288 {
289     void* object;
290     void* method;
291 };
292 
293 extern "C" int32_t RtStartThreadMethod(void* classDelegate)
294 {
295     ClassDelegate* clsDlg = reinterpret_cast<ClassDelegate*>(classDelegate);
296     cmajor::rt::ThreadMethod threadMethod = reinterpret_cast<cmajor::rt::ThreadMethod>(clsDlg->method);
297     return cmajor::rt::ThreadPool::Instance().StartThreadMethod(threadMethod, clsDlg->object);
298 }
299 
300 extern "C" int32_t RtStartThreadMethodWithParam(void* classDelegate, void* param)
301 {
302     ClassDelegate* clsDlg = reinterpret_cast<ClassDelegate*>(classDelegate);
303     cmajor::rt::ThreadMethodWithParam threadMethodWithParam = reinterpret_cast<cmajor::rt::ThreadMethodWithParam>(clsDlg->method);
304     return cmajor::rt::ThreadPool::Instance().StartThreadMethod(threadMethodWithParam, clsDlg->object, param);
305 }
306 
307 extern "C" bool RtJoinThread(int32_t threadId)
308 {
309     return cmajor::rt::ThreadPool::Instance().JoinThread(threadId);
310 }
311 
312 std::unordered_map<std::thread::idint> threadIdMap;
313 
314 int nextThreadId = 0;
315 std::mutex threadIdMapMutex;
316 
317 extern "C" int32_t RtThisThreadId()
318 {
319     std::lock_guard<std::mutex> lock(threadIdMapMutex);
320     std::thread::id id = std::this_thread::get_id();
321     auto it = threadIdMap.find(id);
322     if (it != threadIdMap.cend())
323     {
324         return it->second;
325     }
326     int threadId = nextThreadId++;
327     threadIdMap[id] = threadId;
328     return threadId;
329 }