1 // =================================
  2 // Copyright (c) 2021 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <sngcm/ast/Merge.hpp>
  7 #include <sngcm/ast/CompileUnit.hpp>
  8 #include <sngcm/ast/Identifier.hpp>
  9 #include <sngcm/ast/Class.hpp>
 10 #include <sngcm/ast/Visitor.hpp>
 11 
 12 namespace sngcm { namespace ast {
 13 
 14 class NodeSelectorVisitor public Visitor
 15 {
 16 public:
 17     NodeSelectorVisitor(const std::u32string& nodeName_NodeType nodeType_);
 18     Node* GetSelectedNode() const { return selectedNode; }
 19     void Visit(NamespaceNode& namespaceNode) override;
 20     void Visit(ClassNode& classNode) override;
 21 private:
 22     std::u32string nodeName;
 23     NodeType nodeType;
 24     Node* selectedNode;
 25 };
 26 
 27 
 28 NodeSelectorVisitor::NodeSelectorVisitor(const std::u32string& nodeName_NodeType nodeType_) : nodeName(nodeName_)nodeType(nodeType_)selectedNode(nullptr)
 29 {
 30 }
 31 
 32 void NodeSelectorVisitor::Visit(NamespaceNode& namespaceNode)
 33 {
 34     if (nodeType == NodeType::namespaceNode)
 35     {
 36         if (namespaceNode.Id()->Str() == nodeName)
 37         {
 38             selectedNode = &namespaceNode;
 39         }
 40     }
 41     if (!selectedNode)
 42     {
 43         int n = namespaceNode.Members().Count();
 44         for (int i = 0; i < n; ++i)
 45         {
 46             namespaceNode.Members()[i]->Accept(*this);
 47             if (selectedNode) return;
 48         }
 49     }
 50 }
 51 
 52 void NodeSelectorVisitor::Visit(ClassNode& classNode)
 53 {
 54     if (nodeType == NodeType::classNode)
 55     {
 56         if (classNode.Id()->Str() == nodeName)
 57         {
 58             selectedNode = &classNode;
 59         }
 60         if (!selectedNode)
 61         {
 62             int n = classNode.Members().Count();
 63             for (int i = 0; i < n; ++i)
 64             {
 65                 classNode.Members()[i]->Accept(*this);
 66                 if (selectedNode) return;
 67             }
 68         }
 69     }
 70 }
 71 
 72 class MergeVisitor public Visitor
 73 {
 74 public:
 75     MergeVisitor(Node* targetContainer_);
 76     void Visit(NamespaceNode& namespaceNode) override;
 77     void Visit(ClassNode& classNode) override;
 78 private:
 79     Node* targetContainer;
 80 };
 81 
 82 MergeVisitor::MergeVisitor(Node* targetContainer_) : targetContainer(targetContainer_)
 83 {
 84 }
 85 
 86 void MergeVisitor::Visit(NamespaceNode& namespaceNode)
 87 {
 88     bool added = false;
 89     if (!namespaceNode.Id()->Str().empty())
 90     {
 91         NodeSelectorVisitor visitor(namespaceNode.Id()->Str()NodeType::namespaceNode);
 92         targetContainer->Accept(visitor);
 93         Node* selectedNode = visitor.GetSelectedNode();
 94         if (selectedNode)
 95         {
 96             targetContainer = selectedNode;
 97         }
 98         else
 99         {
100             if (targetContainer->GetNodeType() == NodeType::namespaceNode)
101             {
102                 NamespaceNode* targetNamespace = static_cast<NamespaceNode*>(targetContainer);
103                 CloneContext cloneContext;
104                 Node* clonedSource = namespaceNode.Clone(cloneContext);
105                 targetNamespace->AddMember(clonedSource);
106                 added = true;
107             }
108         }
109     }
110     if (!added && targetContainer->GetNodeType() == NodeType::namespaceNode)
111     {
112         NamespaceNode* targetNamespace = static_cast<NamespaceNode*>(targetContainer);
113         int n = namespaceNode.Members().Count();
114         for (int i = 0; i < n; ++i)
115         {
116             Node* member = namespaceNode.Members()[i];
117             if (member->GetNodeType() == NodeType::namespaceNode)
118             {
119                 member->Accept(*this);
120             }
121             else if (member->GetNodeType() == NodeType::classNode)
122             {
123                 member->Accept(*this);
124             }
125             else
126             {
127                 CloneContext cloneContext;
128                 Node* clonedSource = member->Clone(cloneContext);
129                 targetNamespace->AddMember(clonedSource);
130             }
131         }
132     }
133 }
134 
135 void MergeVisitor::Visit(ClassNode& classNode)
136 {
137     bool added = false;
138     NodeSelectorVisitor visitor(classNode.Id()->Str()NodeType::classNode);
139     targetContainer->Accept(visitor);
140     Node* selectedNode = visitor.GetSelectedNode();
141     if (selectedNode)
142     {
143         targetContainer = selectedNode;
144     }
145     else
146     {
147         if (targetContainer->GetNodeType() == NodeType::namespaceNode)
148         {
149             NamespaceNode* targetNamespace = static_cast<NamespaceNode*>(targetContainer);
150             CloneContext cloneContext;
151             Node* clonedSource = classNode.Clone(cloneContext);
152             targetNamespace->AddMember(clonedSource);
153             added = true;
154         }
155         else if (targetContainer->GetNodeType() == NodeType::classNode)
156         {
157             ClassNode* targetClass = static_cast<ClassNode*>(targetContainer);
158             CloneContext cloneContext;
159             Node* clonedSource = classNode.Clone(cloneContext);
160             targetClass->AddMember(clonedSource);
161             added = true;
162         }
163     }
164     if (!added && targetContainer->GetNodeType() == NodeType::classNode)
165     {
166         ClassNode* targetClass = static_cast<ClassNode*>(targetContainer);
167         int n = classNode.Members().Count();
168         for (int i = 0; i < n; ++i)
169         {
170             Node* member = classNode.Members()[i];
171             if (member->GetNodeType() == NodeType::classNode)
172             {
173                 member->Accept(*this);
174             }
175             else
176             {
177                 CloneContext cloneContext;
178                 Node* clonedSource = member->Clone(cloneContext);
179                 targetClass->AddMember(clonedSource);
180             }
181         }
182     }
183 }
184 
185 void Merge(CompileUnitNode& sourceCompileUnitNode& target)
186 {
187     MergeVisitor visitor(target.GlobalNs());
188     source.GlobalNs()->Accept(visitor);
189 }
190 
191 } } // namespace sngcm::ast