# bktree.py
  # bktree: Burkhard-Keller tree implementation.
"""
This module implements Burkhard-Keller trees (BK-trees). BK-trees
allow fast lookup of words that lie within a specified distance of a
query word. For example, this might be used by a spell checker to
find near matches to a misspelled word.

The implementation is based on the description in this article:
http://blog.notdot.net/2007/4/Damn-Cool-Algorithms-Part-1-BK-Trees

Licensed under the PSF license: http://www.python.org/psf/license/

- Adam Hupp <adam@hupp.org>
"""
  # BK-tree container; nodes are (word, {distance: child_node}) tuples.
class BKTree:
    """Burkhard-Keller tree for finding words near a query word under a metric.

    Every node is a ``(word, children)`` tuple, where ``children`` maps a
    distance ``k`` to the subtree whose words all lie at distance ``k`` from
    ``word``.  ``self.tree`` holds the root node, or ``None`` while the tree
    is empty.
    """

    def __init__(self, distfn, words):
        """
        Create a new BK-tree from the given distance function and
        words.

        Arguments:
        distfn: a binary function that returns the distance between
            two words. Return value is a non-negative number (normally
            an integer). The distance function must be a metric;
            query() relies on the triangle inequality to prune subtrees.
        words: an iterable producing values that can be passed to
            distfn. It may be empty, in which case queries return [].
        """
        self.distfn = distfn
        self.tree = None
        for word in words:
            self.add(word)

    def add(self, word):
        """Insert *word* into the tree.

        A word equal to one already present (distance 0) is stored again,
        so it is reported once per insertion by query().
        """
        if self.tree is None:
            self.tree = (word, {})
        else:
            self._add_word(self.tree, word)

    def _add_word(self, parent, word):
        """Insert *word* into the subtree rooted at node *parent*.

        Follows the edge labelled with the distance from *word* to each
        node's word until that edge is free, then attaches a new leaf.
        Iterative so that deep trees cannot hit the recursion limit.
        """
        node = parent
        while True:
            pword, children = node
            d = self.distfn(word, pword)
            child = children.get(d)
            if child is None:
                children[d] = (word, {})
                return
            node = child

    def query(self, word, n):
        """
        Return all words in the tree that are within a distance of ``n``
        from ``word``.

        Arguments:
        word: a word to query on
        n: a non-negative number that specifies the allowed distance
            (inclusive) from the query word.

        Return value is a list of tuples (distance, word), sorted in
        ascending order of distance (ties are broken by comparing the
        words themselves). An empty tree or a negative ``n`` yields [].
        """
        if self.tree is None:
            return []
        distfn = self.distfn
        results = []
        # Explicit stack instead of recursion so deep trees cannot exceed
        # Python's recursion limit.
        pending = [self.tree]
        while pending:
            pword, children = pending.pop()
            d = distfn(word, pword)
            if d <= n:
                results.append((d, pword))
            # Every word under children[k] is at distance k from pword, so by
            # the triangle inequality it can only be within n of `word` when
            # |k - d| <= n; all other subtrees are pruned.
            pending.extend(
                child for k, child in children.items() if d - n <= k <= d + n
            )
        # sort by distance (tuples compare distance first, then word)
        results.sort()
        return results
  # Adapted from http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Python
def levenshtein(s, t):
    """Return the Levenshtein (edit) distance between *s* and *t*.

    The distance is the minimum number of single-item insertions,
    deletions and substitutions needed to turn *s* into *t*.  Works on
    any sequences whose items support ``==`` (strings, lists, ...).
    Only one previous row of the dynamic-programming table is kept, so
    memory use is O(len(t)) rather than O(len(s) * len(t)).
    """
    # previous[j] = distance between the first i items of s and the
    # first j items of t (row i of the classic DP table).
    previous = list(range(len(t) + 1))
    for i, s_item in enumerate(s, start=1):
        current = [i]
        for j, t_item in enumerate(t, start=1):
            cost = 0 if s_item == t_item else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]
  # Word-list loader and script entry point.
def dict_words(dictfile="/usr/share/dict/american-english", *, encoding=None):
    """Return an iterator that produces words in the given dictionary.

    The file is read lazily, one line at a time, instead of being loaded
    into memory up front.  Each line is stripped of surrounding
    whitespace and blank lines are skipped.  The file is closed when the
    iterator is exhausted or explicitly closed.

    Arguments:
    dictfile: path to a newline-separated word list.
    encoding: text encoding passed to open(); None uses the platform
        default, as before.
    """
    with open(dictfile, 'r', encoding=encoding) as f:
        for line in f:
            word = line.strip()
            if word:
                yield word


if __name__ == "__main__":
    # No command-line behaviour; the module is meant to be imported.
    pass