1 // =================================
  2 // Copyright (c) 2020 Seppo Laakko
  3 // Distributed under the MIT license
  4 // =================================
  5 
  6 #include <sngcpp/binder/VirtualBinder.hpp>
  7 #include <sngcpp/symbols/ParameterSymbol.hpp>
  8 
  9 namespace sngcpp { namespace binder {
 10 
 11 bool Overrides(sngcpp::symbols::FunctionSymbol* derivedFunsngcpp::symbols::FunctionSymbol* baseFun)
 12 {
 13     if (derivedFun == baseFun) return false;
 14     if (derivedFun->Arity() != baseFun->Arity()) return false;
 15     if (derivedFun->Name() != baseFun->Name()) return false;
 16     int n = derivedFun->Arity();
 17     for (int i = 0; i < n; ++i)
 18     {
 19         sngcpp::symbols::ParameterSymbol* derivedParam = derivedFun->Parameters()[i];
 20         sngcpp::symbols::ParameterSymbol* baseParam = baseFun->Parameters()[i];
 21         sngcpp::symbols::TypeSymbol* derivedType = derivedParam->GetType();
 22         sngcpp::symbols::TypeSymbol* baseType = baseParam->GetType();
 23         if (derivedType->Id() != baseType->Id()) return false;
 24     }
 25     return true;
 26 }
 27 
 28 void ResolveOverrideSets(sngcpp::symbols::FunctionSymbol* derivedFunsngcpp::symbols::ClassTypeSymbol* parentClass)
 29 {
 30     for (sngcpp::symbols::TypeSymbol* baseClassType : parentClass->BaseClasses())
 31     {
 32         if (baseClassType->IsClassTypeSymbol())
 33         {
 34             sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
 35             for (sngcpp::symbols::FunctionSymbol* baseFun : baseClass->VirtualFunctions())
 36             {
 37                 if (Overrides(derivedFunbaseFun))
 38                 {
 39                     derivedFun->AddOverridden(baseFun);
 40                     baseFun->AddOverride(derivedFun);
 41                 }
 42             }
 43             ResolveOverrideSets(derivedFunbaseClass);
 44         }
 45     }
 46 }
 47 
 48 void CollectPureVirtualFunctions(sngcpp::symbols::ClassTypeSymbol* clsstd::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&pureVirtualFunctions)
 49 {
 50     for (sngcpp::symbols::TypeSymbol* baseClassType : cls->BaseClasses())
 51     {
 52         if (baseClassType->IsClassTypeSymbol())
 53         {
 54             sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
 55             CollectPureVirtualFunctions(baseClasspureVirtualFunctions);
 56         }
 57     }
 58     for (sngcpp::symbols::FunctionSymbol* virtualFun : cls->VirtualFunctions())
 59     {
 60         if (virtualFun->IsPureVirtual())
 61         {
 62             pureVirtualFunctions.insert(virtualFun);
 63         }
 64     }
 65 }
 66 
 67 void AddOverriddenPureVirtualFunctions(sngcpp::symbols::ClassTypeSymbol* clsstd::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&pureVirtualFunctions
 68     std::std::unordered_set<sngcpp::symbols::FunctionSymbol*>&overriddenPureVirtuals)
 69 {
 70     for (sngcpp::symbols::TypeSymbol* baseClassType : cls->BaseClasses())
 71     {
 72         if (baseClassType->IsClassTypeSymbol())
 73         {
 74             sngcpp::symbols::ClassTypeSymbol* baseClass = static_cast<sngcpp::symbols::ClassTypeSymbol*>(baseClassType);
 75             AddOverriddenPureVirtualFunctions(baseClasspureVirtualFunctionsoverriddenPureVirtuals);
 76         }
 77     }
 78     for (sngcpp::symbols::FunctionSymbol* pureVirtualFunction : pureVirtualFunctions)
 79     {
 80         for (sngcpp::symbols::FunctionSymbol* virtualFun : cls->VirtualFunctions())
 81         {
 82             if (Overrides(virtualFunpureVirtualFunction))
 83             {
 84                 overriddenPureVirtuals.insert(pureVirtualFunction);
 85             }
 86         }
 87     }
 88 }
 89 
 90 void ResolveOverrideSets(sngcpp::symbols::ClassTypeSymbol* cls)
 91 {
 92     std::unordered_set<sngcpp::symbols::FunctionSymbol*> pureVirtualFunctions;
 93     CollectPureVirtualFunctions(clspureVirtualFunctions);
 94     std::unordered_set<sngcpp::symbols::FunctionSymbol*> overriddenPureVirtuals;
 95     AddOverriddenPureVirtualFunctions(clspureVirtualFunctionsoverriddenPureVirtuals);
 96     for (sngcpp::symbols::FunctionSymbol* overriddenPureVirtual : overriddenPureVirtuals)
 97     {
 98         pureVirtualFunctions.erase(overriddenPureVirtual);
 99     }
100     if (!pureVirtualFunctions.empty())
101     {
102         cls->SetAbstract();
103     }
104     for (sngcpp::symbols::FunctionSymbol* derivedFun : cls->VirtualFunctions())
105     {
106         ResolveOverrideSets(derivedFuncls);
107     }
108 }
109 
110 void ResolveOverrideSets(const std::std::unordered_set<sngcpp::symbols::ClassTypeSymbol*>&classes)
111 {
112     for (sngcpp::symbols::ClassTypeSymbol* cls : classes)
113     {
114         ResolveOverrideSets(cls);
115     }
116 }
117 
118 } } // namespace sngcpp::binder