1
2
3
4
5
6 #include <cmajor/binder/FunctionTemplateRepository.hpp>
7 #include <cmajor/binder/BoundCompileUnit.hpp>
8 #include <cmajor/binder/TypeBinder.hpp>
9 #include <cmajor/binder/StatementBinder.hpp>
10 #include <cmajor/binder/BoundStatement.hpp>
11 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
12 #include <cmajor/symbols/TemplateSymbol.hpp>
13 #include <cmajor/symbols/GlobalFlags.hpp>
14 #include <sngcm/ast/Identifier.hpp>
15 #include <soulng/util/Util.hpp>
16
17 namespace cmajor { namespace binder {
18
19 using namespace soulng::util;
20 using namespace cmajor::symbols;
21
22 bool operator==(const FunctionTemplateKey& left, const FunctionTemplateKey& right)
23 {
24 if (left.functionTemplate != right.functionTemplate) return false;
25 if (left.templateArgumentTypes.size() != right.templateArgumentTypes.size()) return false;
26 int n = left.templateArgumentTypes.size();
27 for (int i = 0; i < n; ++i)
28 {
29 if (!TypesEqual(left.templateArgumentTypes[i], right.templateArgumentTypes[i])) return false;
30 }
31 return true;
32 }
33
34 bool operator!=(const FunctionTemplateKey& left, const FunctionTemplateKey& right)
35 {
36 return !(left == right);
37 }
38
39 FunctionTemplateRepository::FunctionTemplateRepository(BoundCompileUnit& boundCompileUnit_) : boundCompileUnit(boundCompileUnit_)
40 {
41 }
42
43 FunctionSymbol* FunctionTemplateRepository::Instantiate(FunctionSymbol* functionTemplate, const std::std::unordered_map<TemplateParameterSymbol*, TypeSymbol*>&templateParameterMapping,
44 const Span& span, const boost::uuids::uuid& moduleId)
45 {
46 std::vector<TypeSymbol*> templateArgumentTypes;
47 for (TemplateParameterSymbol* templateParameter : functionTemplate->TemplateParameters())
48 {
49 auto it = templateParameterMapping.find(templateParameter);
50 if (it != templateParameterMapping.cend())
51 {
52 TypeSymbol* templateArgumentType = it->second;
53 templateArgumentTypes.push_back(templateArgumentType);
54 }
55 else
56 {
57 throw Exception("template parameter type not found", span, moduleId);
58 }
59 }
60 FunctionTemplateKey key(functionTemplate, templateArgumentTypes);
61 auto it = functionTemplateMap.find(key);
62 if (it != functionTemplateMap.cend())
63 {
64 return it->second;
65 }
66 SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
67 Node* node = symbolTable.GetNodeNoThrow(functionTemplate);
68 if (!node)
69 {
70 node = functionTemplate->GetFunctionNode();
71 symbolTable.MapNode(node, functionTemplate);
72 Assert(node, "function node not read");
73 }
74 Assert(node->GetNodeType() == NodeType::functionNode, "function node expected");
75 FunctionNode* functionNode = static_cast<FunctionNode*>(node);
76 std::unique_ptr<NamespaceNode> globalNs(new NamespaceNode(functionNode->GetSpan(), functionNode->ModuleId(), new IdentifierNode(functionNode->GetSpan(), functionNode->ModuleId(), U"")));
77 NamespaceNode* currentNs = globalNs.get();
78 CloneContext cloneContext;
79 cloneContext.SetInstantiateFunctionNode();
80 int n = functionTemplate->UsingNodes().Count();
81 for (int i = 0; i < n; ++i)
82 {
83 Node* usingNode = functionTemplate->UsingNodes()[i];
84 globalNs->AddMember(usingNode->Clone(cloneContext));
85 }
86 bool fileScopeAdded = false;
87 if (!functionTemplate->Ns()->IsGlobalNamespace())
88 {
89 FileScope* primaryFileScope = new FileScope();
90 primaryFileScope->AddContainerScope(functionTemplate->Ns()->GetContainerScope());
91 boundCompileUnit.AddFileScope(primaryFileScope);
92 fileScopeAdded = true;
93 std::u32string fullNsName = functionTemplate->Ns()->FullName();
94 std::vector<std::u32string> nsComponents = Split(fullNsName, '.');
95 for (const std::u32string& nsComponent : nsComponents)
96 {
97 NamespaceNode* nsNode = new NamespaceNode(functionNode->GetSpan(), functionNode->ModuleId(), new IdentifierNode(functionNode->GetSpan(), functionNode->ModuleId(), nsComponent));
98 currentNs->AddMember(nsNode);
99 currentNs = nsNode;
100 }
101 }
102 FunctionNode* functionInstanceNode = static_cast<FunctionNode*>(functionNode->Clone(cloneContext));
103 currentNs->AddMember(functionInstanceNode);
104 symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
105 SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
106 symbolCreatorVisitor.InsertTracer(functionInstanceNode->Body());
107 globalNs->Accept(symbolCreatorVisitor);
108 Symbol* symbol = symbolTable.GetSymbol(functionInstanceNode);
109 Assert(symbol->GetSymbolType() == SymbolType::functionSymbol, "function symbol expected");
110 FunctionSymbol* functionSymbol = static_cast<FunctionSymbol*>(symbol);
111 functionSymbol->SetLinkOnceOdrLinkage();
112 functionSymbol->SetTemplateSpecialization();
113 functionSymbol->SetFunctionTemplate(functionTemplate);
114 functionSymbol->SetTemplateArgumentTypes(templateArgumentTypes);
115 functionTemplateMap[key] = functionSymbol;
116 for (TemplateParameterSymbol* templateParameter : functionTemplate->TemplateParameters())
117 {
118 auto it = templateParameterMapping.find(templateParameter);
119 if (it != templateParameterMapping.cend())
120 {
121 TypeSymbol* boundType = it->second;
122 BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(span, moduleId, templateParameter->Name());
123 boundTemplateParameter->SetType(boundType);
124 functionSymbol->AddMember(boundTemplateParameter);
125 }
126 else
127 {
128 throw Exception("template parameter type not found", span, moduleId);
129 }
130 }
131 TypeBinder typeBinder(boundCompileUnit);
132 globalNs->Accept(typeBinder);
133 StatementBinder statementBinder(boundCompileUnit);
134 globalNs->Accept(statementBinder);
135 if (fileScopeAdded)
136 {
137 boundCompileUnit.RemoveLastFileScope();
138 }
139 boundCompileUnit.AddGlobalNs(std::move(globalNs));
140 functionSymbol->SetFlag(FunctionSymbolFlags::dontReuse);
141 if (functionTemplate->IsSystemDefault())
142 {
143 functionSymbol->SetSystemDefault();
144 }
145 boundCompileUnit.SetCanReuse(functionSymbol);
146 return functionSymbol;
147 }
148
149 } }