tlds.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. use std::collections::HashMap;
  2. use std::sync::OnceLock;
  /// Whois server overrides taken from Lists.toml entries written as `"io:whois.nic.io"`.
  #[derive(Debug, Clone, Default)]
  pub struct WhoisOverrides {
      /// Whois server per TLD; `get_server` lowercases its argument before looking it up here.
      map: HashMap<String, String>,
  }

  impl WhoisOverrides {
      /// Look up the whois server registered for `tld`.
      ///
      /// The TLD is lowercased before the lookup. Returns `None` when no
      /// override is stored under that key.
      pub fn get_server(&self, tld: &str) -> Option<&str> {
          let key = tld.to_lowercase();
          let server = self.map.get(&key)?;
          Some(server.as_str())
      }
  }
  /// One named TLD list from Lists.toml.
  struct NamedList {
      /// Name of the list, stored exactly as it was supplied when the list was built.
      name: String,
      /// TLDs belonging to this list, in file order.
      tlds: Vec<String>,
  }
  18. struct ParsedLists {
  19. lists: Vec<NamedList>,
  20. whois_overrides: WhoisOverrides,
  21. }
  /// Split a single list entry into its TLD and optional whois server.
  ///
  /// Accepted forms are `"tld"` and `"tld:whois_server"`; the split happens at
  /// the first `:`. Surrounding whitespace is trimmed from both parts, and an
  /// empty server part (for example `"io:"`) is reported as `None` so that no
  /// empty whois host is ever registered as an override.
  fn parse_entry(entry: &str) -> (String, Option<String>) {
      match entry.split_once(':') {
          Some((tld, server)) => {
              let server = server.trim();
              // An entry with nothing after the colon carries no usable override.
              let server = (!server.is_empty()).then(|| server.to_string());
              (tld.trim().to_string(), server)
          }
          None => (entry.trim().to_string(), None),
      }
  }
  30. /// parse entries, pull out TLD names and whois overrides
  31. fn parse_list(entries: &[toml::Value], overrides: &mut HashMap<String, String>) -> Vec<String> {
  32. entries
  33. .iter()
  34. .filter_map(|v| v.as_str())
  35. .map(|entry| {
  36. let (tld, server) = parse_entry(entry);
  37. if let Some(s) = server {
  38. overrides.insert(tld.to_lowercase(), s);
  39. }
  40. tld
  41. })
  42. .collect()
  43. }
  44. static PARSED_LISTS: OnceLock<ParsedLists> = OnceLock::new();
  45. fn parsed_lists() -> &'static ParsedLists {
  46. PARSED_LISTS.get_or_init(|| {
  47. let raw: toml::Value = toml::from_str(include_str!("../Lists.toml"))
  48. .expect("Lists.toml must be valid TOML");
  49. let table = raw.as_table().expect("Lists.toml must be a TOML table");
  50. // Build list names in the order build.rs discovered them
  51. let ordered_names: Vec<&str> = env!("HOARDOM_LIST_NAMES").split(',').collect();
  52. let mut overrides = HashMap::new();
  53. let mut lists = Vec::new();
  54. for name in &ordered_names {
  55. if let Some(toml::Value::Array(arr)) = table.get(*name) {
  56. let tlds = parse_list(arr, &mut overrides);
  57. lists.push(NamedList {
  58. name: name.to_string(),
  59. tlds,
  60. });
  61. }
  62. }
  63. ParsedLists {
  64. lists,
  65. whois_overrides: WhoisOverrides { map: overrides },
  66. }
  67. })
  68. }
  69. /// list names from Lists.toml, in order
  70. pub fn list_names() -> Vec<&'static str> {
  71. parsed_lists()
  72. .lists
  73. .iter()
  74. .map(|l| l.name.as_str())
  75. .collect()
  76. }
  77. /// first list name (the default)
  78. pub fn default_list_name() -> &'static str {
  79. list_names().first().copied().unwrap_or("standard")
  80. }
  81. /// get TLDs for a list name (case insensitive), None if not found
  82. pub fn get_tlds(name: &str) -> Option<Vec<&'static str>> {
  83. let lower = name.to_lowercase();
  84. parsed_lists()
  85. .lists
  86. .iter()
  87. .find(|l| l.name == lower)
  88. .map(|l| l.tlds.iter().map(String::as_str).collect())
  89. }
  90. /// get TLDs for a list name, falls back to default if not found
  91. pub fn get_tlds_or_default(name: &str) -> Vec<&'static str> {
  92. get_tlds(name).unwrap_or_else(|| get_tlds(default_list_name()).unwrap_or_default())
  93. }
  94. /// the builtin whois overrides from Lists.toml
  95. pub fn whois_overrides() -> &'static WhoisOverrides {
  96. &parsed_lists().whois_overrides
  97. }
  /// Reorder `tlds` so that the entries named in `top` come first.
  ///
  /// Entries of `top` are matched case insensitively against `tlds` and are
  /// emitted in the order given by `top`, using the spelling of the first
  /// matching entry in `tlds`. Names in `top` that match nothing are ignored.
  /// The remaining TLDs follow in their original order. Exact duplicates are
  /// emitted only once.
  pub fn apply_top_tlds(tlds: Vec<&'static str>, top: &[String]) -> Vec<&'static str> {
      // Lowercase every TLD once up front; the first spelling seen wins for each key.
      let mut by_lower: HashMap<String, &'static str> = HashMap::with_capacity(tlds.len());
      for &tld in &tlds {
          by_lower.entry(tld.to_lowercase()).or_insert(tld);
      }

      // Tracks what has already been emitted so each membership check is a hash lookup.
      let mut seen = std::collections::HashSet::with_capacity(tlds.len());
      let mut result: Vec<&'static str> = Vec::with_capacity(tlds.len());

      // The requested top entries come first, followed by everything else.
      let promoted = top
          .iter()
          .filter_map(|name| by_lower.get(&name.to_lowercase()).copied());
      for tld in promoted.chain(tlds.iter().copied()) {
          if seen.insert(tld) {
              result.push(tld);
          }
      }
      result
  }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// A bare entry yields the TLD and no whois server.
      #[test]
      fn test_parse_entry_bare() {
          let parsed = parse_entry("com");
          assert_eq!(parsed.0, "com");
          assert_eq!(parsed.1, None);
      }

      /// A `tld:server` entry is split at the colon.
      #[test]
      fn test_parse_entry_with_override() {
          let parsed = parse_entry("io:whois.nic.io");
          assert_eq!(parsed.0, "io");
          assert_eq!(parsed.1, Some("whois.nic.io".to_string()));
      }

      /// The overrides built from Lists.toml contain the expected entries.
      #[test]
      fn test_whois_overrides_populated() {
          let overrides = whois_overrides();
          // io should have an override since our Lists.toml has "io:whois.nic.io"
          assert!(overrides.get_server("io").is_some());
          // com should not (it has RDAP)
          assert!(overrides.get_server("com").is_none());
      }

      /// Top TLDs move to the front in the requested order.
      #[test]
      fn test_top_tlds_reorder() {
          let top = ["ch".to_string(), "de".to_string()];
          let result = apply_top_tlds(vec!["com", "net", "org", "ch", "de"], &top);
          assert_eq!(result, ["ch", "de", "com", "net", "org"]);
      }

      /// Top TLDs that are not in the list are ignored.
      #[test]
      fn test_top_tlds_missing_ignored() {
          let top = ["swiss".to_string()];
          let result = apply_top_tlds(vec!["com", "net"], &top);
          assert_eq!(result, ["com", "net"]);
      }
  }