// =================================
// Copyright (c) 2021 Seppo Laakko
// Distributed under the MIT license
// =================================
  5 
  6 #include <cmajor/binder/InlineFunctionRepository.hpp>
  7 #include <cmajor/binder/BoundCompileUnit.hpp>
  8 #include <cmajor/binder/TypeBinder.hpp>
  9 #include <cmajor/binder/StatementBinder.hpp>
 10 #include <cmajor/binder/BoundClass.hpp>
 11 #include <cmajor/binder/BoundFunction.hpp>
 12 #include <cmajor/binder/BoundStatement.hpp>
 13 #include <cmajor/symbols/SymbolCreatorVisitor.hpp>
 14 #include <sngcm/ast/Identifier.hpp>
 15 #include <soulng/util/Util.hpp>
 16 
 17 namespace cmajor { namespace binder {
 18 
 19 using namespace soulng::util;
 20 
 21 InlineFunctionRepository::InlineFunctionRepository(BoundCompileUnit& boundCompileUnit_) : boundCompileUnit(boundCompileUnit_)
 22 {
 23 }
 24 
 25 FunctionSymbol* InlineFunctionRepository::Instantiate(FunctionSymbol* inlineFunctionContainerScope* containerScopeconst Span& spanconst boost::uuids::uuid& moduleId)
 26 {
 27     if (inlineFunction->GetCompileUnit() == boundCompileUnit.GetCompileUnitNode()) return inlineFunction;
 28     while (inlineFunction->Master())
 29     {
 30         inlineFunction = inlineFunction->Master();
 31     }
 32     auto it = inlineFunctionMap.find(inlineFunction);
 33     if (it != inlineFunctionMap.cend())
 34     {
 35         return it->second;
 36     }
 37     SymbolTable& symbolTable = boundCompileUnit.GetSymbolTable();
 38     Node* node = symbolTable.GetNodeNoThrow(inlineFunction);
 39     if (!node)
 40     {
 41         node = inlineFunction->GetFunctionNode();
 42         symbolTable.MapNode(nodeinlineFunction);
 43         Assert(node"function node not read");
 44     }
 45     FunctionNode* functionNode = static_cast<FunctionNode*>(node);
 46     std::unique_ptr<NamespaceNode> globalNs(new NamespaceNode(functionNode->GetSpan()functionNode->ModuleId()new IdentifierNode(functionNode->GetSpan()functionNode->ModuleId()U"")));
 47     NamespaceNode* currentNs = globalNs.get();
 48     CloneContext cloneContext;
 49     cloneContext.SetInstantiateFunctionNode();
 50     bool fileScopeAdded = false;
 51     int n = inlineFunction->UsingNodes().Count();
 52     if (!inlineFunction->Ns()->IsGlobalNamespace() || n > 0)
 53     {
 54         FileScope* primaryFileScope = new FileScope();
 55         if (!inlineFunction->Ns()->IsGlobalNamespace())
 56         {
 57             primaryFileScope->AddContainerScope(inlineFunction->Ns()->GetContainerScope());
 58         }
 59         for (int i = 0; i < n; ++i)
 60         {
 61             Node* usingNode = inlineFunction->UsingNodes()[i];
 62             if (usingNode->GetNodeType() == NodeType::namespaceImportNode)
 63             {
 64                 primaryFileScope->InstallNamespaceImport(containerScopestatic_cast<NamespaceImportNode*>(usingNode));
 65             }
 66             else if (usingNode->GetNodeType() == NodeType::aliasNode)
 67             {
 68                 primaryFileScope->InstallAlias(containerScopestatic_cast<AliasNode*>(usingNode));
 69             }
 70         }
 71         boundCompileUnit.AddFileScope(primaryFileScope);
 72         fileScopeAdded = true;
 73         std::u32string fullNsName = inlineFunction->Ns()->FullName();
 74         std::vector<std::u32string> nsComponents = Split(fullNsName'.');
 75         for (const std::u32string& nsComponent : nsComponents)
 76         {
 77             NamespaceNode* nsNode = new NamespaceNode(functionNode->GetSpan()functionNode->ModuleId()new IdentifierNode(functionNode->GetSpan()functionNode->ModuleId()nsComponent));
 78             currentNs->AddMember(nsNode);
 79             currentNs = nsNode;
 80         }
 81     }
 82     FunctionNode* functionInstanceNode = static_cast<FunctionNode*>(functionNode->Clone(cloneContext));
 83     if (inlineFunction->IsDefault())
 84     {
 85         functionInstanceNode->SetBody(new sngcm::ast::CompoundStatementNode(spanmoduleId));
 86         inlineFunction->SetHasArtificialBody();
 87     }
 88     currentNs->AddMember(functionInstanceNode);
 89     std::lock_guard<std::recursive_mutex> lock(boundCompileUnit.GetModule().GetLock());
 90     symbolTable.SetCurrentCompileUnit(boundCompileUnit.GetCompileUnitNode());
 91     if (!inlineFunction->Parent()->IsClassTypeSymbol())
 92     {
 93         SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
 94         symbolCreatorVisitor.SetLeaveFunction();
 95         globalNs->Accept(symbolCreatorVisitor);
 96         std::unique_ptr<FunctionSymbol> functionSymbol(symbolTable.GetCreatedFunctionSymbol());
 97         if (inlineFunction->IsDefault())
 98         {
 99             functionSymbol->SetHasArtificialBody();
100         }
101         functionSymbol->SetParent(inlineFunction->Parent());
102         functionSymbol->SetLinkOnceOdrLinkage();
103         if (inlineFunction->IsSystemDefault())
104         {
105             functionSymbol->SetSystemDefault();
106         }
107         TypeBinder typeBinder(boundCompileUnit);
108         typeBinder.SetContainerScope(functionSymbol->GetContainerScope());
109         typeBinder.SetCurrentFunctionSymbol(functionSymbol.get());
110         functionInstanceNode->Accept(typeBinder);
111         StatementBinder statementBinder(boundCompileUnit);
112         std::unique_ptr<BoundFunction> boundFunction(new BoundFunction(&boundCompileUnitfunctionSymbol.get()));
113         statementBinder.SetCurrentFunction(boundFunction.get());
114         statementBinder.SetContainerScope(functionSymbol->GetContainerScope());
115         functionInstanceNode->Body()->Accept(statementBinder);
116         BoundStatement* boundStatement = statementBinder.ReleaseStatement();
117         Assert(boundStatement->GetBoundNodeType() == BoundNodeType::boundCompoundStatement"bound compound statement expected");
118         BoundCompoundStatement* compoundStatement = static_cast<BoundCompoundStatement*>(boundStatement);
119         boundFunction->SetBody(std::unique_ptr<BoundCompoundStatement>(compoundStatement));
120         boundCompileUnit.AddBoundNode(std::move(boundFunction));
121         if (fileScopeAdded)
122         {
123             boundCompileUnit.RemoveLastFileScope();
124         }
125         FunctionSymbol* result = functionSymbol.get();
126         boundCompileUnit.GetSymbolTable().AddFunctionSymbol(std::move(functionSymbol));
127         boundCompileUnit.AddGlobalNs(std::move(globalNs));
128         inlineFunctionMap[inlineFunction] = result;
129         result->SetFunctionId(inlineFunction->FunctionId());
130         result->SetMaster(inlineFunction);
131         result->SetCopy();
132         return result;
133     }
134     else
135     {
136         ClassTypeSymbol* classTypeSymbol = static_cast<ClassTypeSymbol*>(inlineFunction->Parent());
137         symbolTable.SetCurrentClass(classTypeSymbol);
138         SymbolCreatorVisitor symbolCreatorVisitor(symbolTable);
139         symbolCreatorVisitor.SetLeaveFunction();
140         globalNs->Accept(symbolCreatorVisitor);
141         std::unique_ptr<FunctionSymbol> functionSymbol(symbolTable.GetCreatedFunctionSymbol());
142         functionSymbol->SetVmtIndex(inlineFunction->VmtIndex());
143         functionSymbol->SetImtIndex(inlineFunction->ImtIndex());
144         if (inlineFunction->IsDefault())
145         {
146             functionSymbol->SetHasArtificialBody();
147         }
148         functionSymbol->SetParent(classTypeSymbol);
149         functionSymbol->SetLinkOnceOdrLinkage();
150         if (inlineFunction->IsSystemDefault())
151         {
152             functionSymbol->SetSystemDefault();
153         }
154         TypeBinder typeBinder(boundCompileUnit);
155         typeBinder.SetContainerScope(functionSymbol->GetContainerScope());
156         functionInstanceNode->Accept(typeBinder);
157         StatementBinder statementBinder(boundCompileUnit);
158         std::unique_ptr<BoundClass> boundClass(new BoundClass(classTypeSymbol));
159         boundClass->SetInlineFunctionContainer();
160         statementBinder.SetCurrentClass(boundClass.get());
161         std::unique_ptr<BoundFunction> boundFunction(new BoundFunction(&boundCompileUnitfunctionSymbol.get()));
162         statementBinder.SetCurrentFunction(boundFunction.get());
163         statementBinder.SetContainerScope(functionSymbol->GetContainerScope());
164         if (functionSymbol->GetSymbolType() == SymbolType::constructorSymbol)
165         {
166             ConstructorSymbol* constructorSymbol = static_cast<ConstructorSymbol*>(functionSymbol.get());
167             Node* node = symbolTable.GetNode(functionSymbol.get());
168             Assert(node->GetNodeType() == NodeType::constructorNode"constructor node expected");
169             ConstructorNode* constructorNode = static_cast<ConstructorNode*>(node);
170             statementBinder.SetCurrentConstructor(constructorSymbolconstructorNode);
171         }
172         else if (functionSymbol->GetSymbolType() == SymbolType::destructorSymbol)
173         {
174             DestructorSymbol* destructorSymbol = static_cast<DestructorSymbol*>(functionSymbol.get());
175             Node* node = symbolTable.GetNode(functionSymbol.get());
176             Assert(node->GetNodeType() == NodeType::destructorNode"destructor node expected");
177             DestructorNode* destructorNode = static_cast<DestructorNode*>(node);
178             statementBinder.SetCurrentDestructor(destructorSymboldestructorNode);
179         }
180         else if (functionSymbol->GetSymbolType() == SymbolType::memberFunctionSymbol)
181         {
182             MemberFunctionSymbol* memberFunctionSymbol = static_cast<MemberFunctionSymbol*>(functionSymbol.get());
183             Node* node = symbolTable.GetNode(functionSymbol.get());
184             Assert(node->GetNodeType() == NodeType::memberFunctionNode"member function node expected");
185             MemberFunctionNode* memberFunctionNode = static_cast<MemberFunctionNode*>(node);
186             statementBinder.SetCurrentMemberFunction(memberFunctionSymbolmemberFunctionNode);
187         }
188         functionInstanceNode->Body()->Accept(statementBinder);
189         BoundStatement* boundStatement = statementBinder.ReleaseStatement();
190         Assert(boundStatement->GetBoundNodeType() == BoundNodeType::boundCompoundStatement"bound compound statement expected");
191         BoundCompoundStatement* compoundStatement = static_cast<BoundCompoundStatement*>(boundStatement);
192         boundFunction->SetBody(std::unique_ptr<BoundCompoundStatement>(compoundStatement));
193         boundClass->AddMember(std::move(boundFunction));
194         boundCompileUnit.AddBoundNode(std::move(boundClass));
195         FunctionSymbol* result = functionSymbol.get();
196         boundCompileUnit.AddGlobalNs(std::move(globalNs));
197         boundCompileUnit.GetSymbolTable().AddFunctionSymbol(std::move(functionSymbol));
198         if (fileScopeAdded)
199         {
200             boundCompileUnit.RemoveLastFileScope();
201         }
202         inlineFunctionMap[inlineFunction] = result;
203         result->SetFunctionId(inlineFunction->FunctionId());
204         result->SetMaster(inlineFunction);
205         result->SetCopy();
206         return result;
207     }
208 }
209 
210 } } // namespace cmajor::binder