1 // =================================
  2 // Copyright (c) 2021 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <cmajor/symbols/Sources.hpp>
  7 #include <cmajor/symbols/Exception.hpp>
  8 #include <cmajor/symbols/Module.hpp>
  9 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
 10 #include <sngcm/cmnothrowlexer/CmajorNothrowLexer.hpp>
 11 #include <sngcm/cmnothrowparser/CompileUnit.hpp>
 12 #include <sngcm/cmnothrowparser/NothrowParsingContext.hpp>
 13 #include <soulng/util/MappedInputFile.hpp>
 14 #include <soulng/util/TextUtils.hpp>
 15 #include <soulng/util/Unicode.hpp>
 16 #include <thread>
 17 #include <mutex>
 18 #include <sstream>
 19 
 20 namespace cmajor { namespace symbols {
 21 
 22 using namespace soulng::util;
 23 using namespace soulng::unicode;
 24 
 25 TypeBindingFunction typeBindingFunction;
 26 
 27 void SetTypeBindingFunction(TypeBindingFunction typeBindingFunc)
 28 {
 29     typeBindingFunction = typeBindingFunc;
 30 }
 31 
 32 bool IsValidCCSymbol(Symbol* symbolModule* moduleFunctionSymbol* fromFunctionstd::u32string& functionGroup)
 33 {
 34     AccessCheckFunction hasAccess = GetAccessCheckFunction();
 35     switch (symbol->GetSymbolType())
 36     {
 37         case SymbolType::functionGroupSymbol:
 38         {
 39             FunctionGroupSymbol* group = static_cast<FunctionGroupSymbol*>(symbol);
 40             if (group->IsValidCCFunctionGroup(fromFunction))
 41             {
 42                 functionGroup = group->FullName();
 43                 return true;
 44             }
 45             else
 46             {
 47                 return false;
 48             }
 49         }
 50         case SymbolType::classGroupTypeSymbol:
 51         {
 52             ClassGroupTypeSymbol* group = static_cast<ClassGroupTypeSymbol*>(symbol);
 53             return group->IsValidCCClassGroup(modulefromFunction);
 54         }
 55         case SymbolType::templateParameterSymbol:
 56         case SymbolType::boundTemplateParameterSymbol:
 57         {
 58             return false;
 59         }
 60         case SymbolType::memberVariableSymbol:
 61         {
 62             return hasAccess(fromFunctionsymbol);
 63         }
 64         default:
 65         {
 66             return !StartsWith(symbol->Name()U"@");
 67         }
 68     }
 69     return true;
 70 }
 71 
 72 Source::Source(const std::string& filePath_) : filePath(filePath_)cursorScope(nullptr)cursorContainer(nullptr)synchronized(false)
 73 {
 74 }
 75 
 76 void Source::Read()
 77 {
 78     std::string str = ReadFile(filePath);
 79     content = ToUtf32(str);
 80 }
 81 
 82 void Source::Parse(const boost::uuids::uuid& moduleIdint index)
 83 {
 84     errors.clear();
 85     CmajorNothrowLexer lexer(Start()End()FilePath()index);
 86     boost::uuids::uuid mid = moduleId;
 87     NothrowParsingContext parsingContext;
 88     std::unique_ptr<CompileUnitNode> parsedCompileUnit = NothrowCompileUnitParser::Parse(lexer&mid&parsingContext);
 89     std::vector<std::std::unique_ptr<std::exception>>parsingErrors=lexer.Errors();
 90     if (!parsingErrors.empty())
 91     {
 92         CmajorNothrowLexer lexer(Start()End()FilePath()index);
 93         lexer.SetFlag(LexerFlags::synchronize);
 94         NothrowParsingContext parsingContext;
 95         parsedCompileUnit = NothrowCompileUnitParser::Parse(lexer&mid&parsingContext);
 96         parsingErrors = lexer.Errors();
 97         synchronized = lexer.GetFlag(LexerFlags::synchronizedAtLeastOnce);
 98     }
 99     else
100     {
101         synchronized = lexer.GetFlag(LexerFlags::synchronizedAtLeastOnce);
102     }
103     for (const std::std::unique_ptr<std::exception>&ex : parsingErrors)
104     {
105         errors.push_back(ex->what());
106     }
107     compileUnit = std::move(parsedCompileUnit);
108 }
109 
110 void Source::SetContent(const std::u32string& content_)
111 {
112     content = content_;
113 }
114 
115 void Source::AddSymbol(Symbol* symbol)
116 {
117     symbols.push_back(symbol);
118 }
119 
120 void Source::AddSymbols(Module* module)
121 {
122     try
123     {
124         if (compileUnit)
125         {
126             symbols.clear();
127             aliasNodes.clear();
128             namespaceImports.clear();
129             cursorContainer = nullptr;
130             SymbolTable& symbolTable = module->GetSymbolTable();
131             symbolTable.ResetMainFunctionSymbol();
132             SymbolCreatorVisitor visitor(symbolTable);
133             visitor.SetEditMode();
134             visitor.SetSource(this);
135             compileUnit->Accept(visitor);
136             aliasNodes = symbolTable.AliasNodes();
137             namespaceImports = symbolTable.NamespaceImports();
138             cursorContainer = symbolTable.CursorContainer();
139         }
140     }
141     catch (const Exception& ex;)
142     {
143         errors.push_back(ex.Message());
144     }
145     catch (const std::exception& ex;)
146     {
147         errors.push_back(ex.what());
148     }
149     catch (...)
150     {
151         errors.push_back("unknown error occurred");
152     }
153 }
154 
155 void Source::GetScopes(Module* module)
156 {
157     try
158     {
159         cursorScope = nullptr;
160         fileScope.reset(new FileScope());
161         SymbolTable& symbolTable = module->GetSymbolTable();
162         if (cursorContainer)
163         {
164             cursorScope = cursorContainer->GetContainerScope();
165         }
166         else
167         {
168             cursorScope = symbolTable.GlobalNs().GetContainerScope();
169         }
170         for (AliasNode* aliasNode : aliasNodes)
171         {
172             try
173             {
174                 fileScope->InstallAlias(symbolTable.GlobalNs().GetContainerScope()aliasNode);
175             }
176             catch (const Exception& ex;)
177             {
178                 errors.push_back(ex.Message());
179             }
180         }
181         for (NamespaceImportNode* namespaceImport : namespaceImports)
182         {
183             try
184             {
185                 fileScope->InstallNamespaceImport(symbolTable.GlobalNs().GetContainerScope()namespaceImport);
186             }
187             catch (const Exception& ex;)
188             {
189                 errors.push_back(ex.Message());
190             }
191         }
192     }
193     catch (const Exception& ex;)
194     {
195         errors.push_back(ex.Message());
196     }
197     catch (const std::exception& ex;)
198     {
199         errors.push_back(ex.what());
200     }
201     catch (...)
202     {
203         errors.push_back("unknown error occurred");
204     }
205 }
206 
207 void Source::RemoveSymbols()
208 {
209     int n = symbols.size();
210     for (int i = n - 1; i >= 0; --i)
211     {
212         std::unique_ptr<Symbol> symbol = symbols[i]->RemoveFromParent();
213     }
214     symbols.clear();
215 }
216 
217 void Source::BindTypes(Module* module)
218 {
219     try
220     {
221         if (typeBindingFunction && compileUnit.get())
222         {
223             std::vector<std::string> e = typeBindingFunction(modulecompileUnit.get());
224             errors.insert(errors.end()e.cbegin()e.cend());
225         }
226     }
227     catch (const Exception& ex;)
228     {
229         errors.push_back(ex.Message());
230     }
231     catch (const std::exception& ex;)
232     {
233         errors.push_back(ex.what());
234     }
235     catch (...)
236     {
237         errors.push_back("unknown error occurred");
238     }
239 }
240 
241 std::std::vector<CCSymbolEntry>Source::LookupSymbolsBeginningWith(conststd::u32string&prefix)
242 {
243     if (!cursorScope || !fileScope) return std::vector<CCSymbolEntry>();
244     std::vector<CCSymbolEntry> matches = cursorScope->LookupBeginWith(prefixScopeLookup::this_and_base_and_parent);
245     std::vector<CCSymbolEntry> m = fileScope->LookupBeginWith(prefix);
246     AddMatches(matchesm);
247     return matches;
248 }
249 
250 std::string Source::GetCCList(Module* moduleconst std::string& ccText)
251 {
252     FunctionSymbol* fromFunction = nullptr;
253     if (cursorContainer)
254     {
255         fromFunction = cursorContainer->FunctionNoThrow();
256     }
257     std::u32string prefix = ToUtf32(ccText);
258     std::vector<CCSymbolEntry> ccSymbolEntries = LookupSymbolsBeginningWith(prefix);
259     sngxml::dom::Document ccListDoc;
260     sngxml::dom::Element* ccListElement = new sngxml::dom::Element(U"ccList");
261     ccListDoc.AppendChild(std::unique_ptr<sngxml::dom::Node>(ccListElement));
262     for (const CCSymbolEntry& ccSymbolEntry : ccSymbolEntries)
263     {
264         Symbol* symbol = ccSymbolEntry.symbol;
265         int ccPrefixLength = ccSymbolEntry.ccPrefixLen;
266         const std::u32string& replacement = ccSymbolEntry.replacement;
267         std::u32string functionGroup;
268         if (IsValidCCSymbol(symbolmodulefromFunctionfunctionGroup))
269         {
270             sngxml::dom::Element* ccElement = symbol->ToCCElement(ccPrefixLengthreplacementfunctionGroup);
271             ccListElement->AppendChild(std::unique_ptr<sngxml::dom::Node>(ccElement));
272         }
273     }
274     std::stringstream s;
275     CodeFormatter formatter(s);
276     ccListDoc.Write(formatter);
277     return s.str();
278 }
279 
280 struct ParserData 
281 {
282     ParserData(bool& stop_std::std::list<int>&indexQueue_std::std::vector<std::exception_ptr>&exceptions_Sources&sources_constboost::uuids::uuid&moduleId_):
283         stop(stop_)indexQueue(indexQueue_)exceptions(exceptions_)sources(sources_)moduleId(moduleId_)
284     {
285     }
286     bool& stop;
287     std::std::list<int>&indexQueue;
288     std::std::vector<std::exception_ptr>&exceptions;
289     Sources& sources;
290     boost::uuids::uuid moduleId;
291     std::mutex mtx;
292 };
293 
294 void DoParseSource(ParserData* parserData)
295 {
296     int index = -1;
297     try
298     {
299         while (!parserData->stop)
300         {
301             {
302                 std::lock_guard<std::mutex> lock(parserData->mtx);
303                 if (parserData->indexQueue.empty()) return;
304                 index = parserData->indexQueue.front();
305                 parserData->indexQueue.pop_front();
306             }
307             Source* source = parserData->sources.GetSource(index);
308             source->Read();
309             source->Parse(parserData->moduleIdindex);
310         }
311     }
312     catch (...)
313     {
314         if (index != -1)
315         {
316             parserData->exceptions[index] = std::current_exception();
317             parserData->stop = true;
318         }
319     }
320 }
321 
322 Sources::Sources(const std::std::vector<std::string>&filePaths)
323 {
324     int n = filePaths.size();
325     for (int i = 0; i < n; ++i)
326     {
327         std::unique_ptr<Source> source(new Source(filePaths[i]));
328         sources.push_back(std::move(source));
329     }
330     MakeSourceIndexMap();
331 }
332 
333 int Sources::GetSourceIndex(const std::string& filePath)
334 {
335     auto it = sourceIndexMap.find(filePath);
336     if (it != sourceIndexMap.cend())
337     {
338         return it->second;
339     }
340     else
341     {
342         return -1;
343     }
344 }
345 
346 void Sources::MakeSourceIndexMap()
347 {
348     sourceIndexMap.clear();
349     int n = sources.size();
350     for (int i = 0; i < n; ++i)
351     {
352         Source* source = sources[i].get();
353         sourceIndexMap[source->FilePath()] = i;
354     }
355 }
356 
357 ParseResult Sources::Parse(Module* module)
358 {
359     ParseResult result;
360     result.start = std::chrono::steady_clock::now();
361     try
362     {
363         bool stop = false;
364         std::list<int> indexQueue;
365         std::vector<std::exception_ptr> exceptions;
366         int n = Count();
367         exceptions.resize(n);
368         for (int i = 0; i < n; ++i)
369         {
370             indexQueue.push_back(i);
371         }
372         boost::uuids::uuid moduleId = boost::uuids::nil_uuid();
373         if (module)
374         {
375             moduleId = module->Id();
376         }
377         ParserData parserData(stopindexQueueexceptions*thismoduleId);
378         std::vector<std::thread> threads;
379         int numThreads = std::thread::hardware_concurrency();
380         if (numThreads <= 0)
381         {
382             numThreads = 1;
383         }
384         for (int i = 0; i < numThreads; ++i)
385         {
386             threads.push_back(std::thread(DoParseSource&parserData));
387             if (parserData.stop) break;
388         }
389         int numStartedThreads = threads.size();
390         for (int i = 0; i < numStartedThreads; ++i)
391         {
392             if (threads[i].joinable())
393             {
394                 threads[i].join();
395             }
396         }
397         for (int i = 0; i < n; ++i)
398         {
399             if (exceptions[i])
400             {
401                 std::rethrow_exception(exceptions[i]);
402             }
403         }
404     }
405     catch (const Exception& ex;)
406     {
407         result.ok = false;
408         result.error = StringStr(ex.Message());
409     }
410     catch (const std::exception& ex;)
411     {
412         result.ok = false;
413         result.error = StringStr(ex.what());
414     }
415     catch (...)
416     {
417         result.ok = false;
418         result.error = "unknown error occurred";
419     }
420     result.end = std::chrono::steady_clock::now();
421     return result;
422 }
423 
424 void Sources::AddSymbols(Module* module)
425 {
426     for (int i = 0; i < sources.size(); ++i)
427     {
428         Source* source = GetSource(i);
429         source->AddSymbols(module);
430     }
431 }
432 
433 void Sources::GetScopes(Module* module)
434 {
435     for (int i = 0; i < sources.size(); ++i)
436     {
437         Source* source = GetSource(i);
438         source->GetScopes(module);
439     }
440 }
441 
442 void Sources::BindTypes(Module* module)
443 {
444     for (int i = 0; i < sources.size(); ++i)
445     {
446         Source* source = GetSource(i);
447         source->BindTypes(module);
448     }
449 }
450 
451 int Sources::GetNumberOfErrors()
452 {
453     int numberOfErrors = 0;
454     for (int i = 0; i < sources.size(); ++i)
455     {
456         Source* source = GetSource(i);
457         numberOfErrors += source->Errors().size();
458     }
459     return numberOfErrors;
460 }
461 
462 bool Sources::Synchronized()
463 {
464     bool synchronized = false;
465     for (int i = 0; i < sources.size(); ++i)
466     {
467         Source* source = GetSource(i);
468         if (source->Synchronized())
469         {
470             synchronized = true;
471             break;
472         }
473     }
474     return synchronized;
475 }
476 
477 ParseResult Sources::ParseSource(Module* moduleconst std::string& sourceFilePathconst std::u32string& sourceCode)
478 {
479     ParseResult result;
480     result.ok = true;
481     result.start = std::chrono::steady_clock::now();
482     try
483     {
484         int index = GetSourceIndex(sourceFilePath);
485         if (index == -1)
486         {
487             result.ok = false;
488             result.error = "source file path '" + sourceFilePath + "' not found";
489             return result;
490         }
491         bool moveSource = false;
492         Source* src = sources[index].get();
493         if (index < sources.size() - 1)
494         {
495             moveSource = true;
496             for (int i = sources.size() - 1; i >= 0; --i)
497             {
498                 Source* s = GetSource(i);
499                 s->RemoveSymbols();
500             }
501         }
502         else
503         {
504             src->RemoveSymbols();
505         }
506         if (moveSource)
507         {
508             std::unique_ptr<Source> source = std::move(sources[index]);
509             sources.erase(sources.begin() + index);
510             sources.push_back(std::move(source));
511             MakeSourceIndexMap();
512         }
513         src->SetContent(sourceCode);
514         src->Parse(module->Id()sources.size());
515         if (moveSource)
516         {
517             for (int i = 0; i < sources.size(); ++i)
518             {
519                 Source* s = GetSource(i);
520                 s->AddSymbols(module);
521             }
522         }
523         else
524         {
525             src->AddSymbols(module);
526         }
527         if (moveSource)
528         {
529             for (int i = 0; i < sources.size(); ++i)
530             {
531                 Source* s = GetSource(i);
532                 s->GetScopes(module);
533             }
534         }
535         else
536         {
537             src->GetScopes(module);
538         }
539         if (moveSource)
540         {
541             for (int i = 0; i < sources.size(); ++i)
542             {
543                 Source* s = GetSource(i);
544                 s->BindTypes(module);
545             }
546         }
547         else
548         {
549             src->BindTypes(module);
550         }
551         result.numberOfErrors = src->Errors().size();
552         for (const std::string& error : src->Errors())
553         {
554             result.errors.push_back(StringStr(error));
555         }
556         result.synchronized = src->Synchronized();
557         if (src->CursorContainer())
558         {
559             result.cursorContainer = ToUtf8(src->CursorContainer()->FullName());
560         }
561     }
562     catch (const Exception& ex;)
563     {
564         result.ok = false;
565         result.error = StringStr(ex.Message());
566         MakeSourceIndexMap();
567     }
568     catch (const std::exception& ex;)
569     {
570         result.ok = false;
571         result.error = StringStr(ex.what());
572         MakeSourceIndexMap();
573     }
574     catch (...)
575     {
576         result.ok = false;
577         result.error = "unknown error occurred";
578         MakeSourceIndexMap();
579     }
580     result.end = std::chrono::steady_clock::now();
581     return result;
582 }
583 
584 std::string Sources::GetCCList(Module* moduleconst std::string& sourceFilePathconst std::string& ccText)
585 {
586     int index = GetSourceIndex(sourceFilePath);
587     if (index == -1)
588     {
589         throw std::runtime_error("source file path '" + sourceFilePath + "' not found");
590     }
591     Source* source = GetSource(index);
592     return source->GetCCList(moduleccText);
593 }
594 
595 } } // namespace cmajor::symbols