diff --git a/common/util/trie_map.h b/common/util/trie_map.h index 19e67a21bb..6b5b6193e8 100644 --- a/common/util/trie_map.h +++ b/common/util/trie_map.h @@ -13,7 +13,7 @@ class TrieMap { private: // TrieNode structure struct TrieNode { - std::unordered_map> children; + std::map> children; std::vector> elements; }; @@ -72,6 +72,11 @@ class TrieMap { // Remove the specified element from the TrieMap void remove(const std::shared_ptr& element) { remove_element(root, element); } + template + int remove_matching(F&& f) { + return remove_matching_elements(root.get(), f); + } + // Return the total number of elements stored in the TrieMap int size() const { int count = 0; @@ -104,6 +109,36 @@ class TrieMap { } } + /*! + * Remove elements where f(element) == true; + */ + template + int remove_matching_elements(TrieNode* node, F&& f) { + int erase_count = 0; + // remove from this level + for (auto it = node->elements.begin(); it != node->elements.end();) { + if (f(it->get())) { + it = node->elements.erase(it); + erase_count++; + } else { + ++it; + } + } + + // remove children + for (auto it = node->children.begin(); it != node->children.end();) { + erase_count += remove_matching_elements(it->second.get(), f); + // remove child if it's empty + if (it->second->children.empty() && it->second->elements.empty()) { + it = node->children.erase(it); + } else { + ++it; + } + } + + return erase_count; + } + // Recursive function to remove the specified element from the TrieMap bool remove_element(std::shared_ptr node, const std::shared_ptr& element) { // Remove the element if it exists at this node diff --git a/goalc/compiler/symbol_info.cpp b/goalc/compiler/symbol_info.cpp index 0f90467636..afbb7af57d 100644 --- a/goalc/compiler/symbol_info.cpp +++ b/goalc/compiler/symbol_info.cpp @@ -311,9 +311,13 @@ std::vector> SymbolInfoMap::get_all_symbols() const void SymbolInfoMap::evict_symbols_using_file_index(const std::string& file_path) { const auto standardized_path = file_util::convert_to_unix_path_separators(file_path); if (m_file_symbol_index.find(standardized_path) != m_file_symbol_index.end()) { + std::unordered_set sym_infos; for (const auto& symbol : m_file_symbol_index.at(standardized_path)) { - m_symbol_map.remove(symbol); + sym_infos.insert(symbol.get()); } + + m_symbol_map.remove_matching([&](SymbolInfo* info) { return sym_infos.count(info) != 0; }); + m_file_symbol_index.erase(standardized_path); } }