1 // =================================
  2 // Copyright (c) 2021 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <cmajor/binder/ClassTemplateRepository.hpp>
  7 #include <cmajor/binder/BoundCompileUnit.hpp>
  8 #include <cmajor/binder/TypeResolver.hpp>
  9 #include <cmajor/binder/TypeBinder.hpp>
 10 #include <cmajor/binder/StatementBinder.hpp>
 11 #include <cmajor/binder/BoundClass.hpp>
 12 #include <cmajor/binder/BoundFunction.hpp>
 13 #include <cmajor/binder/BoundStatement.hpp>
 14 #include <cmajor/binder/Concept.hpp>
 15 #include <cmajor/symbols/TemplateSymbol.hpp>
 16 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
 17 #include <sngcm/ast/Identifier.hpp>
 18 #include <soulng/util/Util.hpp>
 19 #include <soulng/util/Unicode.hpp>
 20 
 21 namespace cmajor { namespace binder {
 22 
 23 using namespace soulng::util;
 24 using namespace soulng::unicode;
 25 
 26 size_t ClassIdMemberFunctionIndexHash::operator()(const std::std::pair<boost::uuids::uuidint>&p) const
 27 {
 28     return boost::hash<boost::uuids::uuid>()(p.first) ^ boost::hash<int>()(p.second);
 29 }
 30 
 31 ClassTemplateRepository::ClassTemplateRepository(BoundCompileUnit& boundCompileUnit_) : boundCompileUnit(boundCompileUnit_)
 32 {
 33 }
 34 
 35 void ClassTemplateRepository::ResolveDefaultTemplateArguments(std::std::vector<TypeSymbol*>&templateArgumentTypesClassTypeSymbol*classTemplateContainerScope*containerScope
 36     const Span& spanconst boost::uuids::uuid& moduleId)
 37 {
 38     int n = classTemplate->TemplateParameters().size();
 39     int m = templateArgumentTypes.size();
 40     if (m == n) return;
 41     SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
 42     Node* node = symbolTable.GetNodeNoThrow(classTemplate);
 43     if (!node)
 44     {
 45         node = classTemplate->GetClassNode();
 46         Assert(node"class node not read");
 47     }
 48     Assert(node->GetNodeType() == NodeType::classNode"class node expected");
 49     ClassNode* classNode = static_cast<ClassNode*>(node);
 50     int numFileScopeAdded = 0;
 51     int nu = classTemplate->UsingNodes().Count();
 52     if (nu > 0)
 53     {
 54         FileScope* fileScope = new FileScope();
 55         for (int i = 0; i < nu; ++i)
 56         {
 57             Node* usingNode = classTemplate->UsingNodes()[i];
 58             if (usingNode->GetNodeType() == NodeType::namespaceImportNode)
 59             {
 60                 NamespaceImportNode* namespaceImportNode = static_cast<NamespaceImportNode*>(usingNode);
 61                 fileScope->InstallNamespaceImport(containerScopenamespaceImportNode);
 62             }
 63             else if (usingNode->GetNodeType() == NodeType::aliasNode)
 64             {
 65                 AliasNode* aliasNode = static_cast<AliasNode*>(usingNode);
 66                 fileScope->InstallAlias(containerScopealiasNode);
 67             }
 68         }
 69         boundCompileUnit.AddFileScope(fileScope);
 70         ++numFileScopeAdded;
 71     }
 72     if (!classTemplate->Ns()->IsGlobalNamespace())
 73     {
 74         FileScope* primaryFileScope = new FileScope();
 75         primaryFileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
 76         boundCompileUnit.AddFileScope(primaryFileScope);
 77         ++numFileScopeAdded;
 78     }
 79     ContainerScope resolveScope;
 80     resolveScope.SetParentScope(containerScope);
 81     std::vector<std::std::unique_ptr<BoundTemplateParameterSymbol>>boundTemplateParameters;
 82     for (int i = 0; i < n; ++i)
 83     {
 84         TemplateParameterSymbol* templateParameterSymbol = classTemplate->TemplateParameters()[i];
 85         BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(spanmoduleIdtemplateParameterSymbol->Name());
 86         boundTemplateParameters.push_back(std::unique_ptr<BoundTemplateParameterSymbol>(boundTemplateParameter));
 87         if (i < m)
 88         {
 89             boundTemplateParameter->SetType(templateArgumentTypes[i]);
 90             resolveScope.Install(boundTemplateParameter);
 91         }
 92         else
 93         {
 94             if (i >= classNode->TemplateParameters().Count())
 95             {
 96                 throw Exception("too few template arguments"spanmoduleId);
 97             }
 98             Node* defaultTemplateArgumentNode = classNode->TemplateParameters()[i]->DefaultTemplateArgument();
 99             if (!defaultTemplateArgumentNode)
100             {
101                 throw Exception("too few template arguments"spanmoduleId);
102             }
103             TypeSymbol* templateArgumentType = ResolveType(defaultTemplateArgumentNodeboundCompileUnit&resolveScope);
104             templateArgumentTypes.push_back(templateArgumentType);
105         }
106     }
107     for (int i = 0; i < numFileScopeAdded; ++i)
108     {
109         boundCompileUnit.RemoveLastFileScope();
110     }
111 }
112 
113 void ClassTemplateRepository::BindClassTemplateSpecialization(ClassTemplateSpecializationSymbol* classTemplateSpecializationContainerScope* containerScope
114     const Span& span const boost::uuids::uuid& moduleId)
115 {
116     if (classTemplateSpecialization->IsBound()) return;
117     if (classTemplateSpecialization->FullName() == U"String<char>")
118     {
119         int x = 0;
120     }
121     SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
122     ClassTypeSymbol* classTemplate = classTemplateSpecialization->GetClassTemplate();
123     Node* node = symbolTable.GetNodeNoThrow(classTemplate);
124     if (!node)
125     {
126         node = classTemplate->GetClassNode();
127         Assert(node"class node not read");
128     }
129     Assert(node->GetNodeType() == NodeType::classNode"class node expected");
130     ClassNode* classNode = static_cast<ClassNode*>(node);
131     std::unique_ptr<NamespaceNode> globalNs(new NamespaceNode(classNode->GetSpan()classNode->ModuleId()new IdentifierNode(classNode->GetSpan()classNode->ModuleId()U"")));
132     NamespaceNode* currentNs = globalNs.get();
133     CloneContext cloneContext;
134     cloneContext.SetInstantiateClassNode();
135     int nu = classTemplate->UsingNodes().Count();
136     for (int i = 0; i < nu; ++i)
137     {
138         Node* usingNode = classTemplate->UsingNodes()[i];
139         globalNs->AddMember(usingNode->Clone(cloneContext));
140     }
141     bool fileScopeAdded = false;
142     if (!classTemplate->Ns()->IsGlobalNamespace())
143     {
144         FileScope* primaryFileScope = new FileScope();
145         primaryFileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
146         boundCompileUnit.AddFileScope(primaryFileScope);
147         fileScopeAdded = true;
148         std::u32string fullNsName = classTemplate->Ns()->FullName();
149         std::vector<std::u32string> nsComponents = Split(fullNsName'.');
150         for (const std::u32string& nsComponent : nsComponents)
151         {
152             NamespaceNode* nsNode = new NamespaceNode(classNode->GetSpan()classNode->ModuleId()new IdentifierNode(classNode->GetSpan()classNode->ModuleId()nsComponent));
153             currentNs->AddMember(nsNode);
154             currentNs = nsNode;
155         }
156     }
157     ClassNode* classInstanceNode = static_cast<ClassNode*>(classNode->Clone(cloneContext));
158     currentNs->AddMember(classInstanceNode);
159     int n = classTemplate->TemplateParameters().size();
160     int m = classTemplateSpecialization->TemplateArgumentTypes().size();
161     if (n != m)
162     {
163         throw Exception("wrong number of template arguments"spanmoduleId);
164     }
165     bool templateParameterBinding = false;
166     ContainerScope resolveScope;
167     resolveScope.SetParentScope(containerScope);
168     for (int i = 0; i < n; ++i)
169     {
170         TemplateParameterSymbol* templateParameter = classTemplate->TemplateParameters()[i];
171         BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(spanmoduleIdtemplateParameter->Name());
172         boundTemplateParameter->SetParent(classTemplateSpecialization);
173         TypeSymbol* templateArgumentType = classTemplateSpecialization->TemplateArgumentTypes()[i];
174         boundTemplateParameter->SetType(templateArgumentType);
175         if (templateArgumentType->GetSymbolType() == SymbolType::templateParameterSymbol)
176         {
177             templateParameterBinding = true;
178             if (classTemplateSpecialization->IsPrototype())
179             {
180                 if (classTemplateSpecialization->IsProject())
181                 {
182                     resolveScope.Install(boundTemplateParameter);
183                     TemplateParameterNode* templateParameterNode = classNode->TemplateParameters()[i];
184                     Node* defaultTemplateArgumentNode = templateParameterNode->DefaultTemplateArgument();
185                     if (defaultTemplateArgumentNode)
186                     {
187                         TypeSymbol* templateArgumentType = ResolveType(defaultTemplateArgumentNodeboundCompileUnit&resolveScope);
188                         templateParameter->SetDefaultType(templateArgumentType);
189                     }
190                 }
191             }
192         }
193         classTemplateSpecialization->AddMember(boundTemplateParameter);
194     }
195     symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
196     SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
197     symbolCreatorVisitor.SetClassInstanceNode(classInstanceNode);
198     symbolCreatorVisitor.SetClassTemplateSpecialization(classTemplateSpecialization);
199     globalNs->Accept(symbolCreatorVisitor);
200     TypeBinder typeBinder(boundCompileUnit);
201     if (templateParameterBinding)
202     {
203         typeBinder.CreateMemberSymbols();
204     }
205     typeBinder.SetContainerScope(classTemplateSpecialization->GetContainerScope());
206     globalNs->Accept(typeBinder);
207     if (templateParameterBinding)
208     {
209         classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
210         if (fileScopeAdded)
211         {
212             boundCompileUnit.RemoveLastFileScope();
213         }
214     }
215     else if (boundCompileUnit.BindingTypes())
216     {
217         classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
218         classTemplateSpecialization->SetStatementsNotBound();
219         if (fileScopeAdded)
220         {
221             FileScope* fileScope = boundCompileUnit.ReleaseLastFileScope();
222             classTemplateSpecialization->SetFileScope(fileScope);
223         }
224     }
225     else
226     {
227         StatementBinder statementBinder(boundCompileUnit);
228         globalNs->Accept(statementBinder);
229         classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
230         if (fileScopeAdded)
231         {
232             boundCompileUnit.RemoveLastFileScope();
233         }
234     }
235 }
236 
237 bool ClassTemplateRepository::Instantiate(FunctionSymbol* memberFunctionContainerScope* containerScopeBoundFunction* currentFunctionconst Span& spanconst boost::uuids::uuid& moduleId)
238 {
239     if (instantiatedMemberFunctions.find(memberFunction) != instantiatedMemberFunctions.cend()) return true;
240     instantiatedMemberFunctions.insert(memberFunction);
241     try
242     {
243         SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
244         Symbol* parent = memberFunction->Parent();
245         Assert(parent->GetSymbolType() == SymbolType::classTemplateSpecializationSymbol"class template specialization expected");
246         ClassTemplateSpecializationSymbol* classTemplateSpecialization = static_cast<ClassTemplateSpecializationSymbol*>(parent);
247         std::pair<boost::uuids::uuidint> classIdMemFunIndexPair = std::make_pair(classTemplateSpecialization->TypeId()memberFunction->GetIndex());
248         if (classIdMemberFunctionIndexSet.find(classIdMemFunIndexPair) != classIdMemberFunctionIndexSet.cend())
249         {
250 //          If <parent class id, member function index> pair is found from the classIdMemberFunctionIndexSet, the member function is already instantiated 
251 //          for this compile unit, so return true.
252             instantiatedMemberFunctions.insert(memberFunction);
253             return true;
254         }
255         Assert(classTemplateSpecialization->IsBound()"class template specialization not bound");
256         Node* node = symbolTable.GetNodeNoThrow(memberFunction);
257         if (!node)
258         {
259             return false;
260         }
261         boundCompileUnit.FinalizeBinding(classTemplateSpecialization);
262         ClassTypeSymbol* classTemplate = classTemplateSpecialization->GetClassTemplate();
263         std::unordered_map<TemplateParameterSymbol*TypeSymbol*> templateParameterMap;
264         int n = classTemplateSpecialization->TemplateArgumentTypes().size();
265         for (int i = 0; i < n; ++i)
266         {
267             TemplateParameterSymbol* templateParameter = classTemplate->TemplateParameters()[i];
268             TypeSymbol* templateArgument = classTemplateSpecialization->TemplateArgumentTypes()[i];
269             templateParameterMap[templateParameter] = templateArgument;
270         }
271         if (!classTemplateSpecialization->IsConstraintChecked())
272         {
273             classTemplateSpecialization->SetConstraintChecked();
274             if (classTemplate->Constraint())
275             {
276                 std::unique_ptr<BoundConstraint> boundConstraint;
277                 std::unique_ptr<Exception> conceptCheckException;
278                 if (!CheckConstraint(classTemplate->Constraint()classTemplate->UsingNodes()boundCompileUnitcontainerScopecurrentFunctionclassTemplate->TemplateParameters()
279                     templateParameterMapboundConstraintspanmoduleIdmemberFunctionconceptCheckException))
280                 {
281                     if (conceptCheckException)
282                     {
283                         throw Exception("concept check of class template specialization '" + ToUtf8(classTemplateSpecialization->FullName()) + "' failed: " + conceptCheckException->Message()span
284                             moduleIdconceptCheckException->References());
285                     }
286                     else
287                     {
288                         throw Exception("concept check of class template specialization '" + ToUtf8(classTemplateSpecialization->FullName()) + "' failed."spanmoduleId);
289                     }
290                 }
291             }
292         }
293         FileScope* fileScope = new FileScope();
294         int nu = classTemplate->UsingNodes().Count();
295         for (int i = 0; i < nu; ++i)
296         {
297             Node* usingNode = classTemplate->UsingNodes()[i];
298             if (usingNode->GetNodeType() == NodeType::namespaceImportNode)
299             {
300                 NamespaceImportNode* namespaceImportNode = static_cast<NamespaceImportNode*>(usingNode);
301                 fileScope->InstallNamespaceImport(containerScopenamespaceImportNode);
302             }
303             else if (usingNode->GetNodeType() == NodeType::aliasNode)
304             {
305                 AliasNode* aliasNode = static_cast<AliasNode*>(usingNode);
306                 fileScope->InstallAlias(containerScopealiasNode);
307             }
308         }
309         if (!classTemplate->Ns()->IsGlobalNamespace())
310         {
311             fileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
312         }
313         boundCompileUnit.AddFileScope(fileScope);
314         Assert(node->IsFunctionNode()"function node expected");
315         FunctionNode* functionInstanceNode = static_cast<FunctionNode*>(node);
316         if (memberFunction->IsDefault())
317         {
318             functionInstanceNode->SetBodySource(new sngcm::ast::CompoundStatementNode(spanmoduleId));
319         }
320         Assert(functionInstanceNode->BodySource()"body source expected");
321         CloneContext cloneContext;
322         functionInstanceNode->SetBody(static_cast<CompoundStatementNode*>(functionInstanceNode->BodySource()->Clone(cloneContext)));
323         if (functionInstanceNode->WhereConstraint())
324         {
325             std::unique_ptr<BoundConstraint> boundConstraint;
326             std::unique_ptr<Exception> conceptCheckException;
327             FileScope* classTemplateScope = new FileScope();
328             classTemplateScope->AddContainerScope(classTemplateSpecialization->GetContainerScope());
329             boundCompileUnit.AddFileScope(classTemplateScope);
330             if (!CheckConstraint(functionInstanceNode->WhereConstraint()classTemplate->UsingNodes()boundCompileUnitcontainerScopecurrentFunctionclassTemplate->TemplateParameters()
331                 templateParameterMapboundConstraintspanmoduleIdmemberFunctionconceptCheckException))
332             {
333                 boundCompileUnit.RemoveLastFileScope();
334                 if (conceptCheckException)
335                 {
336                     std::vector<std::std::pair<Spanboost::uuids::uuid>>references;
337                     references.push_back(std::make_pair(conceptCheckException->Defined()conceptCheckException->DefinedModuleId()));
338                     references.insert(references.end()conceptCheckException->References().begin()conceptCheckException->References().end());
339                     throw Exception("concept check of class template member function '" + ToUtf8(memberFunction->FullName()) + "' failed: " + conceptCheckException->Message()spanmoduleIdreferences);
340                 }
341                 else
342                 {
343                     throw Exception("concept check of class template template member function '" + ToUtf8(memberFunction->FullName()) + "' failed."spanmoduleId);
344                 }
345             }
346             else
347             {
348                 boundCompileUnit.RemoveLastFileScope();
349             }
350         }
351         std::lock_guard<std::recursive_mutex> lock(boundCompileUnit.GetModule().GetLock());
352         FunctionSymbol* master = memberFunction;
353         master->ResetImmutable();
354         memberFunction = master->Copy();
355         boundCompileUnit.GetSymbolTable().AddFunctionSymbol(std::unique_ptr<FunctionSymbol>(memberFunction));
356         master->SetImmutable();
357         symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
358         SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
359         symbolTable.BeginContainer(memberFunction);
360         symbolTable.MapNode(functionInstanceNodememberFunction);
361         symbolCreatorVisitor.InsertTracer(functionInstanceNode->Body());
362         functionInstanceNode->Body()->Accept(symbolCreatorVisitor);
363         symbolTable.EndContainer();
364         TypeBinder typeBinder(boundCompileUnit);
365         typeBinder.SetContainerScope(memberFunction->GetContainerScope());
366         typeBinder.SetCurrentFunctionSymbol(memberFunction);
367         functionInstanceNode->Body()->Accept(typeBinder);
368         StatementBinder statementBinder(boundCompileUnit);
369         std::unique_ptr<BoundClass> boundClass(new BoundClass(classTemplateSpecialization));
370         statementBinder.SetCurrentClass(boundClass.get());
371         std::unique_ptr<BoundFunction> boundFunction(new BoundFunction(&boundCompileUnitmemberFunction));
372         statementBinder.SetCurrentFunction(boundFunction.get());
373         statementBinder.SetContainerScope(memberFunction->GetContainerScope());
374         if (memberFunction->GetSymbolType() == SymbolType::constructorSymbol)
375         {
376             ConstructorSymbol* constructorSymbol = static_cast<ConstructorSymbol*>(memberFunction);
377             Node* node = symbolTable.GetNode(memberFunction);
378             Assert(node->GetNodeType() == NodeType::constructorNode"constructor node expected");
379             ConstructorNode* constructorNode = static_cast<ConstructorNode*>(node);
380             statementBinder.SetCurrentConstructor(constructorSymbolconstructorNode);
381         }
382         else if (memberFunction->GetSymbolType() == SymbolType::destructorSymbol)
383         {
384             DestructorSymbol* destructorSymbol = static_cast<DestructorSymbol*>(memberFunction);
385             Node* node = symbolTable.GetNode(memberFunction);
386             Assert(node->GetNodeType() == NodeType::destructorNode"destructor node expected");
387             DestructorNode* destructorNode = static_cast<DestructorNode*>(node);
388             statementBinder.SetCurrentDestructor(destructorSymboldestructorNode);
389         }
390         else if (memberFunction->GetSymbolType() == SymbolType::memberFunctionSymbol)
391         {
392             MemberFunctionSymbol* memberFunctionSymbol = static_cast<MemberFunctionSymbol*>(memberFunction);
393             Node* node = symbolTable.GetNode(memberFunction);
394             Assert(node->GetNodeType() == NodeType::memberFunctionNode"member function node expected");
395             MemberFunctionNode* memberFunctionNode = static_cast<MemberFunctionNode*>(node);
396             statementBinder.SetCurrentMemberFunction(memberFunctionSymbolmemberFunctionNode);
397         }
398         functionInstanceNode->Body()->Accept(statementBinder);
399         BoundStatement* boundStatement = statementBinder.ReleaseStatement();
400         Assert(boundStatement->GetBoundNodeType() == BoundNodeType::boundCompoundStatement"bound compound statement expected");
401         BoundCompoundStatement* compoundStatement = static_cast<BoundCompoundStatement*>(boundStatement);
402         boundFunction->SetBody(std::unique_ptr<BoundCompoundStatement>(compoundStatement));
403         std::u32string instantiatedMemberFunctionMangledName = boundFunction->GetFunctionSymbol()->MangledName();
404         boundClass->AddMember(std::move(boundFunction));
405         classIdMemberFunctionIndexSet.insert(classIdMemFunIndexPair);
406         boundCompileUnit.AddBoundNode(std::move(boundClass));
407         boundCompileUnit.RemoveLastFileScope();
408         return InstantiateDestructorAndVirtualFunctions(classTemplateSpecializationcontainerScopecurrentFunctionspanmoduleId);
409     }
410     catch (const Exception& ex;)
411     {
412         std::vector<std::std::pair<Spanboost::uuids::uuid>>references;
413         references.push_back(std::make_pair(memberFunction->GetSpan()memberFunction->SourceModuleId()));
414         references.push_back(std::make_pair(ex.Defined()ex.DefinedModuleId()));
415         references.insert(references.end()ex.References().begin()ex.References().end());
416         throw Exception("could not instantiate member function '" + ToUtf8(memberFunction->FullName()) + "'. Reason: " + ex.Message()spanmoduleIdreferences);
417     }
418 }
419 
420 bool ClassTemplateRepository::InstantiateDestructorAndVirtualFunctions(ClassTemplateSpecializationSymbol* classTemplateSpecializationContainerScope* containerScopeBoundFunction* currentFunction
421     const Span& spanconst boost::uuids::uuid& moduleId)
422 {
423     for (FunctionSymbol* virtualMemberFunction : classTemplateSpecialization->Vmt())
424     {
425         if (virtualMemberFunction->Parent() == classTemplateSpecialization && !virtualMemberFunction->IsGeneratedFunction())
426         {
427             if (!Instantiate(virtualMemberFunctioncontainerScopecurrentFunctionspanmoduleId))
428             {
429                 return false;
430             }
431         }
432     }
433     if (classTemplateSpecialization->Destructor())
434     {
435         if (!classTemplateSpecialization->Destructor()->IsGeneratedFunction())
436         {
437             if (!Instantiate(classTemplateSpecialization->Destructor()containerScopecurrentFunctionspanmoduleId))
438             {
439                 return false;
440             }
441         }
442     }
443     return true;
444 }
445 
446 } } // namespace cmajor::binder