1
2
3
4
5
6 #include <cmajor/binder/ClassTemplateRepository.hpp>
7 #include <cmajor/binder/BoundCompileUnit.hpp>
8 #include <cmajor/binder/TypeResolver.hpp>
9 #include <cmajor/binder/TypeBinder.hpp>
10 #include <cmajor/binder/StatementBinder.hpp>
11 #include <cmajor/binder/BoundClass.hpp>
12 #include <cmajor/binder/BoundFunction.hpp>
13 #include <cmajor/binder/BoundStatement.hpp>
14 #include <cmajor/binder/Concept.hpp>
15 #include <cmajor/symbols/TemplateSymbol.hpp>
16 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
17 #include <sngcm/ast/Identifier.hpp>
18 #include <soulng/util/Util.hpp>
19 #include <soulng/util/Unicode.hpp>
20
21 namespace cmajor { namespace binder {
22
23 using namespace soulng::util;
24 using namespace soulng::unicode;
25
26 size_t ClassIdMemberFunctionIndexHash::operator()(const std::std::pair<boost::uuids::uuid, int>&p) const
27 {
28 return boost::hash<boost::uuids::uuid>()(p.first) ^ boost::hash<int>()(p.second);
29 }
30
31 ClassTemplateRepository::ClassTemplateRepository(BoundCompileUnit& boundCompileUnit_) : boundCompileUnit(boundCompileUnit_)
32 {
33 }
34
35 void ClassTemplateRepository::ResolveDefaultTemplateArguments(std::std::vector<TypeSymbol*>&templateArgumentTypes, ClassTypeSymbol*classTemplate, ContainerScope*containerScope,
36 const Span& span, const boost::uuids::uuid& moduleId)
37 {
38 int n = classTemplate->TemplateParameters().size();
39 int m = templateArgumentTypes.size();
40 if (m == n) return;
41 SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
42 Node* node = symbolTable.GetNodeNoThrow(classTemplate);
43 if (!node)
44 {
45 node = classTemplate->GetClassNode();
46 Assert(node, "class node not read");
47 }
48 Assert(node->GetNodeType() == NodeType::classNode, "class node expected");
49 ClassNode* classNode = static_cast<ClassNode*>(node);
50 int numFileScopeAdded = 0;
51 int nu = classTemplate->UsingNodes().Count();
52 if (nu > 0)
53 {
54 FileScope* fileScope = new FileScope();
55 for (int i = 0; i < nu; ++i)
56 {
57 Node* usingNode = classTemplate->UsingNodes()[i];
58 if (usingNode->GetNodeType() == NodeType::namespaceImportNode)
59 {
60 NamespaceImportNode* namespaceImportNode = static_cast<NamespaceImportNode*>(usingNode);
61 fileScope->InstallNamespaceImport(containerScope, namespaceImportNode);
62 }
63 else if (usingNode->GetNodeType() == NodeType::aliasNode)
64 {
65 AliasNode* aliasNode = static_cast<AliasNode*>(usingNode);
66 fileScope->InstallAlias(containerScope, aliasNode);
67 }
68 }
69 boundCompileUnit.AddFileScope(fileScope);
70 ++numFileScopeAdded;
71 }
72 if (!classTemplate->Ns()->IsGlobalNamespace())
73 {
74 FileScope* primaryFileScope = new FileScope();
75 primaryFileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
76 boundCompileUnit.AddFileScope(primaryFileScope);
77 ++numFileScopeAdded;
78 }
79 ContainerScope resolveScope;
80 resolveScope.SetParentScope(containerScope);
81 std::vector<std::std::unique_ptr<BoundTemplateParameterSymbol>>boundTemplateParameters;
82 for (int i = 0; i < n; ++i)
83 {
84 TemplateParameterSymbol* templateParameterSymbol = classTemplate->TemplateParameters()[i];
85 BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(span, moduleId, templateParameterSymbol->Name());
86 boundTemplateParameters.push_back(std::unique_ptr<BoundTemplateParameterSymbol>(boundTemplateParameter));
87 if (i < m)
88 {
89 boundTemplateParameter->SetType(templateArgumentTypes[i]);
90 resolveScope.Install(boundTemplateParameter);
91 }
92 else
93 {
94 if (i >= classNode->TemplateParameters().Count())
95 {
96 throw Exception("too few template arguments", span, moduleId);
97 }
98 Node* defaultTemplateArgumentNode = classNode->TemplateParameters()[i]->DefaultTemplateArgument();
99 if (!defaultTemplateArgumentNode)
100 {
101 throw Exception("too few template arguments", span, moduleId);
102 }
103 TypeSymbol* templateArgumentType = ResolveType(defaultTemplateArgumentNode, boundCompileUnit, &resolveScope);
104 templateArgumentTypes.push_back(templateArgumentType);
105 }
106 }
107 for (int i = 0; i < numFileScopeAdded; ++i)
108 {
109 boundCompileUnit.RemoveLastFileScope();
110 }
111 }
112
113 void ClassTemplateRepository::BindClassTemplateSpecialization(ClassTemplateSpecializationSymbol* classTemplateSpecialization, ContainerScope* containerScope,
114 const Span& span, const boost::uuids::uuid& moduleId)
115 {
116 if (classTemplateSpecialization->IsBound()) return;
117 if (classTemplateSpecialization->FullName() == U"String<char>")
118 {
119 int x = 0;
120 }
121 SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
122 ClassTypeSymbol* classTemplate = classTemplateSpecialization->GetClassTemplate();
123 Node* node = symbolTable.GetNodeNoThrow(classTemplate);
124 if (!node)
125 {
126 node = classTemplate->GetClassNode();
127 Assert(node, "class node not read");
128 }
129 Assert(node->GetNodeType() == NodeType::classNode, "class node expected");
130 ClassNode* classNode = static_cast<ClassNode*>(node);
131 std::unique_ptr<NamespaceNode> globalNs(new NamespaceNode(classNode->GetSpan(), classNode->ModuleId(), new IdentifierNode(classNode->GetSpan(), classNode->ModuleId(), U"")));
132 NamespaceNode* currentNs = globalNs.get();
133 CloneContext cloneContext;
134 cloneContext.SetInstantiateClassNode();
135 int nu = classTemplate->UsingNodes().Count();
136 for (int i = 0; i < nu; ++i)
137 {
138 Node* usingNode = classTemplate->UsingNodes()[i];
139 globalNs->AddMember(usingNode->Clone(cloneContext));
140 }
141 bool fileScopeAdded = false;
142 if (!classTemplate->Ns()->IsGlobalNamespace())
143 {
144 FileScope* primaryFileScope = new FileScope();
145 primaryFileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
146 boundCompileUnit.AddFileScope(primaryFileScope);
147 fileScopeAdded = true;
148 std::u32string fullNsName = classTemplate->Ns()->FullName();
149 std::vector<std::u32string> nsComponents = Split(fullNsName, '.');
150 for (const std::u32string& nsComponent : nsComponents)
151 {
152 NamespaceNode* nsNode = new NamespaceNode(classNode->GetSpan(), classNode->ModuleId(), new IdentifierNode(classNode->GetSpan(), classNode->ModuleId(), nsComponent));
153 currentNs->AddMember(nsNode);
154 currentNs = nsNode;
155 }
156 }
157 ClassNode* classInstanceNode = static_cast<ClassNode*>(classNode->Clone(cloneContext));
158 currentNs->AddMember(classInstanceNode);
159 int n = classTemplate->TemplateParameters().size();
160 int m = classTemplateSpecialization->TemplateArgumentTypes().size();
161 if (n != m)
162 {
163 throw Exception("wrong number of template arguments", span, moduleId);
164 }
165 bool templateParameterBinding = false;
166 ContainerScope resolveScope;
167 resolveScope.SetParentScope(containerScope);
168 for (int i = 0; i < n; ++i)
169 {
170 TemplateParameterSymbol* templateParameter = classTemplate->TemplateParameters()[i];
171 BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(span, moduleId, templateParameter->Name());
172 boundTemplateParameter->SetParent(classTemplateSpecialization);
173 TypeSymbol* templateArgumentType = classTemplateSpecialization->TemplateArgumentTypes()[i];
174 boundTemplateParameter->SetType(templateArgumentType);
175 if (templateArgumentType->GetSymbolType() == SymbolType::templateParameterSymbol)
176 {
177 templateParameterBinding = true;
178 if (classTemplateSpecialization->IsPrototype())
179 {
180 if (classTemplateSpecialization->IsProject())
181 {
182 resolveScope.Install(boundTemplateParameter);
183 TemplateParameterNode* templateParameterNode = classNode->TemplateParameters()[i];
184 Node* defaultTemplateArgumentNode = templateParameterNode->DefaultTemplateArgument();
185 if (defaultTemplateArgumentNode)
186 {
187 TypeSymbol* templateArgumentType = ResolveType(defaultTemplateArgumentNode, boundCompileUnit, &resolveScope);
188 templateParameter->SetDefaultType(templateArgumentType);
189 }
190 }
191 }
192 }
193 classTemplateSpecialization->AddMember(boundTemplateParameter);
194 }
195 symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
196 SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
197 symbolCreatorVisitor.SetClassInstanceNode(classInstanceNode);
198 symbolCreatorVisitor.SetClassTemplateSpecialization(classTemplateSpecialization);
199 globalNs->Accept(symbolCreatorVisitor);
200 TypeBinder typeBinder(boundCompileUnit);
201 if (templateParameterBinding)
202 {
203 typeBinder.CreateMemberSymbols();
204 }
205 typeBinder.SetContainerScope(classTemplateSpecialization->GetContainerScope());
206 globalNs->Accept(typeBinder);
207 if (templateParameterBinding)
208 {
209 classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
210 if (fileScopeAdded)
211 {
212 boundCompileUnit.RemoveLastFileScope();
213 }
214 }
215 else if (boundCompileUnit.BindingTypes())
216 {
217 classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
218 classTemplateSpecialization->SetStatementsNotBound();
219 if (fileScopeAdded)
220 {
221 FileScope* fileScope = boundCompileUnit.ReleaseLastFileScope();
222 classTemplateSpecialization->SetFileScope(fileScope);
223 }
224 }
225 else
226 {
227 StatementBinder statementBinder(boundCompileUnit);
228 globalNs->Accept(statementBinder);
229 classTemplateSpecialization->SetGlobalNs(std::move(globalNs));
230 if (fileScopeAdded)
231 {
232 boundCompileUnit.RemoveLastFileScope();
233 }
234 }
235 }
236
237 bool ClassTemplateRepository::Instantiate(FunctionSymbol* memberFunction, ContainerScope* containerScope, BoundFunction* currentFunction, const Span& span, const boost::uuids::uuid& moduleId)
238 {
239 if (instantiatedMemberFunctions.find(memberFunction) != instantiatedMemberFunctions.cend()) return true;
240 instantiatedMemberFunctions.insert(memberFunction);
241 try
242 {
243 SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
244 Symbol* parent = memberFunction->Parent();
245 Assert(parent->GetSymbolType() == SymbolType::classTemplateSpecializationSymbol, "class template specialization expected");
246 ClassTemplateSpecializationSymbol* classTemplateSpecialization = static_cast<ClassTemplateSpecializationSymbol*>(parent);
247 std::pair<boost::uuids::uuid, int> classIdMemFunIndexPair = std::make_pair(classTemplateSpecialization->TypeId(), memberFunction->GetIndex());
248 if (classIdMemberFunctionIndexSet.find(classIdMemFunIndexPair) != classIdMemberFunctionIndexSet.cend())
249 {
250
251
252 instantiatedMemberFunctions.insert(memberFunction);
253 return true;
254 }
255 Assert(classTemplateSpecialization->IsBound(), "class template specialization not bound");
256 Node* node = symbolTable.GetNodeNoThrow(memberFunction);
257 if (!node)
258 {
259 return false;
260 }
261 boundCompileUnit.FinalizeBinding(classTemplateSpecialization);
262 ClassTypeSymbol* classTemplate = classTemplateSpecialization->GetClassTemplate();
263 std::unordered_map<TemplateParameterSymbol*, TypeSymbol*> templateParameterMap;
264 int n = classTemplateSpecialization->TemplateArgumentTypes().size();
265 for (int i = 0; i < n; ++i)
266 {
267 TemplateParameterSymbol* templateParameter = classTemplate->TemplateParameters()[i];
268 TypeSymbol* templateArgument = classTemplateSpecialization->TemplateArgumentTypes()[i];
269 templateParameterMap[templateParameter] = templateArgument;
270 }
271 if (!classTemplateSpecialization->IsConstraintChecked())
272 {
273 classTemplateSpecialization->SetConstraintChecked();
274 if (classTemplate->Constraint())
275 {
276 std::unique_ptr<BoundConstraint> boundConstraint;
277 std::unique_ptr<Exception> conceptCheckException;
278 if (!CheckConstraint(classTemplate->Constraint(), classTemplate->UsingNodes(), boundCompileUnit, containerScope, currentFunction, classTemplate->TemplateParameters(),
279 templateParameterMap, boundConstraint, span, moduleId, memberFunction, conceptCheckException))
280 {
281 if (conceptCheckException)
282 {
283 throw Exception("concept check of class template specialization '" + ToUtf8(classTemplateSpecialization->FullName()) + "' failed: " + conceptCheckException->Message(), span,
284 moduleId, conceptCheckException->References());
285 }
286 else
287 {
288 throw Exception("concept check of class template specialization '" + ToUtf8(classTemplateSpecialization->FullName()) + "' failed.", span, moduleId);
289 }
290 }
291 }
292 }
293 FileScope* fileScope = new FileScope();
294 int nu = classTemplate->UsingNodes().Count();
295 for (int i = 0; i < nu; ++i)
296 {
297 Node* usingNode = classTemplate->UsingNodes()[i];
298 if (usingNode->GetNodeType() == NodeType::namespaceImportNode)
299 {
300 NamespaceImportNode* namespaceImportNode = static_cast<NamespaceImportNode*>(usingNode);
301 fileScope->InstallNamespaceImport(containerScope, namespaceImportNode);
302 }
303 else if (usingNode->GetNodeType() == NodeType::aliasNode)
304 {
305 AliasNode* aliasNode = static_cast<AliasNode*>(usingNode);
306 fileScope->InstallAlias(containerScope, aliasNode);
307 }
308 }
309 if (!classTemplate->Ns()->IsGlobalNamespace())
310 {
311 fileScope->AddContainerScope(classTemplate->Ns()->GetContainerScope());
312 }
313 boundCompileUnit.AddFileScope(fileScope);
314 Assert(node->IsFunctionNode(), "function node expected");
315 FunctionNode* functionInstanceNode = static_cast<FunctionNode*>(node);
316 if (memberFunction->IsDefault())
317 {
318 functionInstanceNode->SetBodySource(new sngcm::ast::CompoundStatementNode(span, moduleId));
319 }
320 Assert(functionInstanceNode->BodySource(), "body source expected");
321 CloneContext cloneContext;
322 functionInstanceNode->SetBody(static_cast<CompoundStatementNode*>(functionInstanceNode->BodySource()->Clone(cloneContext)));
323 if (functionInstanceNode->WhereConstraint())
324 {
325 std::unique_ptr<BoundConstraint> boundConstraint;
326 std::unique_ptr<Exception> conceptCheckException;
327 FileScope* classTemplateScope = new FileScope();
328 classTemplateScope->AddContainerScope(classTemplateSpecialization->GetContainerScope());
329 boundCompileUnit.AddFileScope(classTemplateScope);
330 if (!CheckConstraint(functionInstanceNode->WhereConstraint(), classTemplate->UsingNodes(), boundCompileUnit, containerScope, currentFunction, classTemplate->TemplateParameters(),
331 templateParameterMap, boundConstraint, span, moduleId, memberFunction, conceptCheckException))
332 {
333 boundCompileUnit.RemoveLastFileScope();
334 if (conceptCheckException)
335 {
336 std::vector<std::std::pair<Span, boost::uuids::uuid>>references;
337 references.push_back(std::make_pair(conceptCheckException->Defined(), conceptCheckException->DefinedModuleId()));
338 references.insert(references.end(), conceptCheckException->References().begin(), conceptCheckException->References().end());
339 throw Exception("concept check of class template member function '" + ToUtf8(memberFunction->FullName()) + "' failed: " + conceptCheckException->Message(), span, moduleId, references);
340 }
341 else
342 {
343 throw Exception("concept check of class template template member function '" + ToUtf8(memberFunction->FullName()) + "' failed.", span, moduleId);
344 }
345 }
346 else
347 {
348 boundCompileUnit.RemoveLastFileScope();
349 }
350 }
351 std::lock_guard<std::recursive_mutex> lock(boundCompileUnit.GetModule().GetLock());
352 FunctionSymbol* master = memberFunction;
353 master->ResetImmutable();
354 memberFunction = master->Copy();
355 boundCompileUnit.GetSymbolTable().AddFunctionSymbol(std::unique_ptr<FunctionSymbol>(memberFunction));
356 master->SetImmutable();
357 symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
358 SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
359 symbolTable.BeginContainer(memberFunction);
360 symbolTable.MapNode(functionInstanceNode, memberFunction);
361 symbolCreatorVisitor.InsertTracer(functionInstanceNode->Body());
362 functionInstanceNode->Body()->Accept(symbolCreatorVisitor);
363 symbolTable.EndContainer();
364 TypeBinder typeBinder(boundCompileUnit);
365 typeBinder.SetContainerScope(memberFunction->GetContainerScope());
366 typeBinder.SetCurrentFunctionSymbol(memberFunction);
367 functionInstanceNode->Body()->Accept(typeBinder);
368 StatementBinder statementBinder(boundCompileUnit);
369 std::unique_ptr<BoundClass> boundClass(new BoundClass(classTemplateSpecialization));
370 statementBinder.SetCurrentClass(boundClass.get());
371 std::unique_ptr<BoundFunction> boundFunction(new BoundFunction(&boundCompileUnit, memberFunction));
372 statementBinder.SetCurrentFunction(boundFunction.get());
373 statementBinder.SetContainerScope(memberFunction->GetContainerScope());
374 if (memberFunction->GetSymbolType() == SymbolType::constructorSymbol)
375 {
376 ConstructorSymbol* constructorSymbol = static_cast<ConstructorSymbol*>(memberFunction);
377 Node* node = symbolTable.GetNode(memberFunction);
378 Assert(node->GetNodeType() == NodeType::constructorNode, "constructor node expected");
379 ConstructorNode* constructorNode = static_cast<ConstructorNode*>(node);
380 statementBinder.SetCurrentConstructor(constructorSymbol, constructorNode);
381 }
382 else if (memberFunction->GetSymbolType() == SymbolType::destructorSymbol)
383 {
384 DestructorSymbol* destructorSymbol = static_cast<DestructorSymbol*>(memberFunction);
385 Node* node = symbolTable.GetNode(memberFunction);
386 Assert(node->GetNodeType() == NodeType::destructorNode, "destructor node expected");
387 DestructorNode* destructorNode = static_cast<DestructorNode*>(node);
388 statementBinder.SetCurrentDestructor(destructorSymbol, destructorNode);
389 }
390 else if (memberFunction->GetSymbolType() == SymbolType::memberFunctionSymbol)
391 {
392 MemberFunctionSymbol* memberFunctionSymbol = static_cast<MemberFunctionSymbol*>(memberFunction);
393 Node* node = symbolTable.GetNode(memberFunction);
394 Assert(node->GetNodeType() == NodeType::memberFunctionNode, "member function node expected");
395 MemberFunctionNode* memberFunctionNode = static_cast<MemberFunctionNode*>(node);
396 statementBinder.SetCurrentMemberFunction(memberFunctionSymbol, memberFunctionNode);
397 }
398 functionInstanceNode->Body()->Accept(statementBinder);
399 BoundStatement* boundStatement = statementBinder.ReleaseStatement();
400 Assert(boundStatement->GetBoundNodeType() == BoundNodeType::boundCompoundStatement, "bound compound statement expected");
401 BoundCompoundStatement* compoundStatement = static_cast<BoundCompoundStatement*>(boundStatement);
402 boundFunction->SetBody(std::unique_ptr<BoundCompoundStatement>(compoundStatement));
403 std::u32string instantiatedMemberFunctionMangledName = boundFunction->GetFunctionSymbol()->MangledName();
404 boundClass->AddMember(std::move(boundFunction));
405 classIdMemberFunctionIndexSet.insert(classIdMemFunIndexPair);
406 boundCompileUnit.AddBoundNode(std::move(boundClass));
407 boundCompileUnit.RemoveLastFileScope();
408 return InstantiateDestructorAndVirtualFunctions(classTemplateSpecialization, containerScope, currentFunction, span, moduleId);
409 }
410 catch (const Exception& ex;)
411 {
412 std::vector<std::std::pair<Span, boost::uuids::uuid>>references;
413 references.push_back(std::make_pair(memberFunction->GetSpan(), memberFunction->SourceModuleId()));
414 references.push_back(std::make_pair(ex.Defined(), ex.DefinedModuleId()));
415 references.insert(references.end(), ex.References().begin(), ex.References().end());
416 throw Exception("could not instantiate member function '" + ToUtf8(memberFunction->FullName()) + "'. Reason: " + ex.Message(), span, moduleId, references);
417 }
418 }
419
420 bool ClassTemplateRepository::InstantiateDestructorAndVirtualFunctions(ClassTemplateSpecializationSymbol* classTemplateSpecialization, ContainerScope* containerScope, BoundFunction* currentFunction,
421 const Span& span, const boost::uuids::uuid& moduleId)
422 {
423 for (FunctionSymbol* virtualMemberFunction : classTemplateSpecialization->Vmt())
424 {
425 if (virtualMemberFunction->Parent() == classTemplateSpecialization && !virtualMemberFunction->IsGeneratedFunction())
426 {
427 if (!Instantiate(virtualMemberFunction, containerScope, currentFunction, span, moduleId))
428 {
429 return false;
430 }
431 }
432 }
433 if (classTemplateSpecialization->Destructor())
434 {
435 if (!classTemplateSpecialization->Destructor()->IsGeneratedFunction())
436 {
437 if (!Instantiate(classTemplateSpecialization->Destructor(), containerScope, currentFunction, span, moduleId))
438 {
439 return false;
440 }
441 }
442 }
443 return true;
444 }
445
446 } }