1 // =================================
  2 // Copyright (c) 2021 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <cmajor/binder/FunctionTemplateRepository.hpp>
  7 #include <cmajor/binder/BoundCompileUnit.hpp>
  8 #include <cmajor/binder/TypeBinder.hpp>
  9 #include <cmajor/binder/StatementBinder.hpp>
 10 #include <cmajor/binder/BoundStatement.hpp>
 11 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
 12 #include <cmajor/symbols/TemplateSymbol.hpp>
 13 #include <cmajor/symbols/GlobalFlags.hpp>
 14 #include <sngcm/ast/Identifier.hpp>
 15 #include <soulng/util/Util.hpp>
 16 
 17 namespace cmajor { namespace binder {
 18 
 19 using namespace soulng::util;
 20 using namespace cmajor::symbols;
 21 
 22 bool operator==(const FunctionTemplateKey& leftconst FunctionTemplateKey& right)
 23 {
 24     if (left.functionTemplate != right.functionTemplate) return false;
 25     if (left.templateArgumentTypes.size() != right.templateArgumentTypes.size()) return false;
 26     int n = left.templateArgumentTypes.size();
 27     for (int i = 0; i < n; ++i)
 28     {
 29         if (!TypesEqual(left.templateArgumentTypes[i]right.templateArgumentTypes[i])) return false;
 30     }
 31     return true;
 32 }
 33 
 34 bool operator!=(const FunctionTemplateKey& leftconst FunctionTemplateKey& right)
 35 {
 36     return !(left == right);
 37 }
 38 
 39 FunctionTemplateRepository::FunctionTemplateRepository(BoundCompileUnit& boundCompileUnit_) : boundCompileUnit(boundCompileUnit_)
 40 {
 41 }
 42 
 43 FunctionSymbol* FunctionTemplateRepository::Instantiate(FunctionSymbol* functionTemplateconst std::std::unordered_map<TemplateParameterSymbol*TypeSymbol*>&templateParameterMapping
 44     const Span& spanconst boost::uuids::uuid& moduleId)
 45 {
 46     std::vector<TypeSymbol*> templateArgumentTypes;
 47     for (TemplateParameterSymbol* templateParameter : functionTemplate->TemplateParameters())
 48     {
 49         auto it = templateParameterMapping.find(templateParameter);
 50         if (it != templateParameterMapping.cend())
 51         {
 52             TypeSymbol* templateArgumentType = it->second;
 53             templateArgumentTypes.push_back(templateArgumentType);
 54         }
 55         else
 56         {
 57             throw Exception("template parameter type not found"spanmoduleId);
 58         }
 59     }
 60     FunctionTemplateKey key(functionTemplatetemplateArgumentTypes);
 61     auto it = functionTemplateMap.find(key);
 62     if (it != functionTemplateMap.cend())
 63     {
 64         return it->second;
 65     }
 66     SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
 67     Node* node = symbolTable.GetNodeNoThrow(functionTemplate);
 68     if (!node)
 69     {
 70         node = functionTemplate->GetFunctionNode();
 71         symbolTable.MapNode(nodefunctionTemplate);
 72         Assert(node"function node not read");
 73     }
 74     Assert(node->GetNodeType() == NodeType::functionNode"function node expected");
 75     FunctionNode* functionNode = static_cast<FunctionNode*>(node);
 76     std::unique_ptr<NamespaceNode> globalNs(new NamespaceNode(functionNode->GetSpan()functionNode->ModuleId()new IdentifierNode(functionNode->GetSpan()functionNode->ModuleId()U"")));
 77     NamespaceNode* currentNs = globalNs.get();
 78     CloneContext cloneContext;
 79     cloneContext.SetInstantiateFunctionNode();
 80     int n = functionTemplate->UsingNodes().Count();
 81     for (int i = 0; i < n; ++i)
 82     {
 83         Node* usingNode = functionTemplate->UsingNodes()[i];
 84         globalNs->AddMember(usingNode->Clone(cloneContext));
 85     }
 86     bool fileScopeAdded = false;
 87     if (!functionTemplate->Ns()->IsGlobalNamespace())
 88     {
 89         FileScope* primaryFileScope = new FileScope();
 90         primaryFileScope->AddContainerScope(functionTemplate->Ns()->GetContainerScope());
 91         boundCompileUnit.AddFileScope(primaryFileScope);
 92         fileScopeAdded = true;
 93         std::u32string fullNsName = functionTemplate->Ns()->FullName();
 94         std::vector<std::u32string> nsComponents = Split(fullNsName'.');
 95         for (const std::u32string& nsComponent : nsComponents)
 96         {
 97             NamespaceNode* nsNode = new NamespaceNode(functionNode->GetSpan()functionNode->ModuleId()new IdentifierNode(functionNode->GetSpan()functionNode->ModuleId()nsComponent));
 98             currentNs->AddMember(nsNode);
 99             currentNs = nsNode;
100         }
101     }
102     FunctionNode* functionInstanceNode = static_cast<FunctionNode*>(functionNode->Clone(cloneContext));
103     currentNs->AddMember(functionInstanceNode);
104     symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
105     SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
106     symbolCreatorVisitor.InsertTracer(functionInstanceNode->Body());
107     globalNs->Accept(symbolCreatorVisitor);
108     Symbol* symbol = symbolTable.GetSymbol(functionInstanceNode);
109     Assert(symbol->GetSymbolType() == SymbolType::functionSymbol"function symbol expected");
110     FunctionSymbol* functionSymbol = static_cast<FunctionSymbol*>(symbol);
111     functionSymbol->SetLinkOnceOdrLinkage();
112     functionSymbol->SetTemplateSpecialization();
113     functionSymbol->SetFunctionTemplate(functionTemplate);
114     functionSymbol->SetTemplateArgumentTypes(templateArgumentTypes);
115     functionTemplateMap[key] = functionSymbol;
116     for (TemplateParameterSymbol* templateParameter : functionTemplate->TemplateParameters())
117     {
118         auto it = templateParameterMapping.find(templateParameter);
119         if (it != templateParameterMapping.cend())
120         {
121             TypeSymbol* boundType = it->second;
122             BoundTemplateParameterSymbol* boundTemplateParameter = new BoundTemplateParameterSymbol(spanmoduleIdtemplateParameter->Name());
123             boundTemplateParameter->SetType(boundType);
124             functionSymbol->AddMember(boundTemplateParameter);
125         }
126         else
127         {
128             throw Exception("template parameter type not found"spanmoduleId);
129         }
130     }
131     TypeBinder typeBinder(boundCompileUnit);
132     globalNs->Accept(typeBinder);
133     StatementBinder statementBinder(boundCompileUnit);
134     globalNs->Accept(statementBinder);
135     if (fileScopeAdded)
136     {
137         boundCompileUnit.RemoveLastFileScope();
138     }
139     boundCompileUnit.AddGlobalNs(std::move(globalNs));
140     functionSymbol->SetFlag(FunctionSymbolFlags::dontReuse);
141     if (functionTemplate->IsSystemDefault())
142     {
143         functionSymbol->SetSystemDefault();
144     }
145     boundCompileUnit.SetCanReuse(functionSymbol);
146     return functionSymbol;
147 }
148 
149 } } // namespace cmajor::binder