1
2
3
4
5
6 #include <sngcpp/binder/VirtualBinder.hpp>
7 #include <sngcpp/symbols/ParameterSymbol.hpp>
8
9 namespace sngcpp { namespace binder {
10
11 bool Overrides(sngcpp::symbols::FunctionSymbol* derivedFun, sngcpp::symbols::FunctionSymbol* baseFun)
12 {
13 if (derivedFun == baseFun) return false;
14 if (derivedFun->Arity() != baseFun->Arity()) return false;
15 if (derivedFun->Name() != baseFun->Name()) return false;
16 int n = derivedFun->Arity();
17 for (int i = 0; i < n; ++i)
18 {
19 sngcpp::symbols::ParameterSymbol* derivedParam = derivedFun->Parameters()[i];
20 sngcpp::symbols::ParameterSymbol* baseParam = baseFun->Parameters()[i];
21 sngcpp::symbols::TypeSymbol* derivedType = derivedParam->GetType();
22 sngcpp::symbols::TypeSymbol* baseType = baseParam->GetType();
23 if (derivedType->Id() != baseType->Id()) return false;
24 }
25 return true;
26 }
27
28 void ResolveOverrideSets(sngcpp::symbols::FunctionSymbol* derivedFun, sngcpp::symbols::ClassTypeSymbol* parentClass)
29 {
30 for (sngcpp::symbols::TypeSymbol* baseClassType : parentClass->BaseClasses())
31 {
32 if (baseClassType->IsClassTypeSymbol())
33 {
34 sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
35 for (sngcpp::symbols::FunctionSymbol* baseFun : baseClass->VirtualFunctions())
36 {
37 if (Overrides(derivedFun, baseFun))
38 {
39 derivedFun->AddOverridden(baseFun);
40 baseFun->AddOverride(derivedFun);
41 }
42 }
43 ResolveOverrideSets(derivedFun, baseClass);
44 }
45 }
46 }
47
48 void CollectPureVirtualFunctions(sngcpp::symbols::ClassTypeSymbol* cls, std::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&pureVirtualFunctions)
49 {
50 for (sngcpp::symbols::TypeSymbol* baseClassType : cls->BaseClasses())
51 {
52 if (baseClassType->IsClassTypeSymbol())
53 {
54 sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
55 CollectPureVirtualFunctions(baseClass, pureVirtualFunctions);
56 }
57 }
58 for (sngcpp::symbols::FunctionSymbol* virtualFun : cls->VirtualFunctions())
59 {
60 if (virtualFun->IsPureVirtual())
61 {
62 pureVirtualFunctions.insert(virtualFun);
63 }
64 }
65 }
66
67 void AddOverriddenPureVirtualFunctions(sngcpp::symbols::ClassTypeSymbol* cls, std::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&pureVirtualFunctions,
68 std::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&overriddenPureVirtuals)
69 {
70 for (sngcpp::symbols::TypeSymbol* baseClassType : cls->BaseClasses())
71 {
72 if (baseClassType->IsClassTypeSymbol())
73 {
74 sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
75 AddOverriddenPureVirtualFunctions(baseClass, pureVirtualFunctions, overriddenPureVirtuals);
76 }
77 }
78 for (sngcpp::symbols::FunctionSymbol* pureVirtualFunction : pureVirtualFunctions)
79 {
80 for (sngcpp::symbols::FunctionSymbol* virtualFun : cls->VirtualFunctions())
81 {
82 if (Overrides(virtualFun, pureVirtualFunction))
83 {
84 overriddenPureVirtuals.insert(pureVirtualFunction);
85 }
86 }
87 }
88 }
89
90 void ResolveOverrideSets(sngcpp::symbols::ClassTypeSymbol* cls)
91 {
92 std::unordered_set<sngcpp::symbols::FunctionSymbol*> pureVirtualFunctions;
93 CollectPureVirtualFunctions(cls, pureVirtualFunctions);
94 std::unordered_set<sngcpp::symbols::FunctionSymbol*> overriddenPureVirtuals;
95 AddOverriddenPureVirtualFunctions(cls, pureVirtualFunctions, overriddenPureVirtuals);
96 for (sngcpp::symbols::FunctionSymbol* overriddenPureVirtual : overriddenPureVirtuals)
97 {
98 pureVirtualFunctions.erase(overriddenPureVirtual);
99 }
100 if (!pureVirtualFunctions.empty())
101 {
102 cls->SetAbstract();
103 }
104 for (sngcpp::symbols::FunctionSymbol* derivedFun : cls->VirtualFunctions())
105 {
106 ResolveOverrideSets(derivedFun, cls);
107 }
108 }
109
110 void ResolveOverrideSets(const std::std::unordered_set<sngcpp::symbols::ClassTypeSymbol*>&classes)
111 {
112 for (sngcpp::symbols::ClassTypeSymbol* cls : classes)
113 {
114 ResolveOverrideSets(cls);
115 }
116 }
117
118 } }