字典树(Trie)原理及实现(Java)

2023-04-25 17:58:52

  字典树(Trie)—— 以字符串作为key的数据结构。由于Trie以特殊的结构存储字符串,与BST、HashTable等相比,它具有独特的优势(如前缀匹配等)。
满足条件
  每个节点只存储一个字母;
  位于同一分支上的字母组成字符串(尾节点需要单独标记)。


实现方式
  Trie有三种实现方式:
  数组 —— 字母有限,每个节点可以使用固定大小的数组作为child节点的索引;
  BST —— 使用BST作为child节点的索引,与数组相比可以节约一定的空间;
  HashTable —— 使用HashTable作为child节点的索引,search操作具有平均常量时间复杂度。


Operation

    public void add(String w); // 将字符串w添加进Trie中存储
    public List<String> allString(); // 返回Trie中存储的所有字符串
    public String longestPrefixOf(String w); // 返回w的最长前缀中能在Trie里沿路径逐字符匹配到的部分(该前缀不要求是已存储的完整字符串)
    public List<String> stringsWithPrefix(String prefix); // 返回所有带有前缀prefix的字符串

java代码
  以BST为例存储Trie的child节点(图片空间有限,只选取了前几个节点)
(原文此处为示意图,图片未随文本保留)
  Trie.java

import java.util.ArrayList;
import java.util.List;

/**
 * A trie (prefix tree) of strings in which every node indexes its children
 * with a {@link BST} keyed by character.
 *
 * <p>Each {@link Node} holds one character; the characters along a path from
 * the root spell a string, and a node flagged as "end" marks a stored string.
 */
public class Trie {
    /** Sentinel root; its key ({@code '0'}) is never part of any stored string. */
    private final Node root;

    Trie() {
        root = new Node('0', null);
    }

    /**
     * Stores {@code w} in the trie. Adding a string that is already present
     * has no further effect. The empty string is accepted and is recorded by
     * flagging the root node as an end node.
     *
     * @param w the string to store
     * @throws NullPointerException if {@code w} is {@code null}
     */
    public void add(String w) {
        if (w == null) {
            throw new NullPointerException("w must not be null");
        }
        Node current = root;
        for (int i = 0; i < w.length(); i++) {
            current = getOrCreateChild(current, w.charAt(i));
        }
        current.setEnd();
    }

    /**
     * Returns every string stored in the trie.
     *
     * <p>A stored string is listed before the longer strings that extend it;
     * siblings appear in the order produced by {@link Node#getNodes()}.
     *
     * @return a new list containing all stored strings
     */
    public List<String> allString() {
        List<String> res = new ArrayList<>();
        collect(root, new StringBuilder(), res);
        return res;
    }

    /**
     * Returns the longest prefix of {@code w} that can be followed character
     * by character along a path in the trie.
     *
     * <p>The result is a path match only: it is not required to be a string
     * that was itself added (for example, with only "same" stored, the result
     * for "sample" is "sam").
     *
     * @param w the string whose prefix is looked up
     * @return the longest matching prefix, or the empty string if the first
     *         character already has no match
     * @throws NullPointerException if {@code w} is {@code null}
     */
    public String longestPrefixOf(String w) {
        if (w == null) {
            throw new NullPointerException("w must not be null");
        }
        Node current = root;
        int matched = 0;
        while (matched < w.length()) {
            Node child = findChild(current, w.charAt(matched));
            if (child == null) {
                break;
            }
            current = child;
            matched++;
        }
        return w.substring(0, matched);
    }

    /**
     * Returns every stored string that starts with {@code prefix}, including
     * {@code prefix} itself when it is a stored string. Each string appears
     * exactly once.
     *
     * @param prefix the prefix to match; the empty string matches everything
     * @return a new list of matching strings, empty if there are none
     * @throws NullPointerException if {@code prefix} is {@code null}
     */
    public List<String> stringsWithPrefix(String prefix) {
        if (prefix == null) {
            throw new NullPointerException("prefix must not be null");
        }
        List<String> res = new ArrayList<>();
        Node current = root;
        for (int i = 0; i < prefix.length(); i++) {
            current = findChild(current, prefix.charAt(i));
            if (current == null) {
                // No path spells the prefix, so nothing can match.
                return res;
            }
        }
        collect(current, new StringBuilder(prefix), res);
        return res;
    }

    /** Returns the child of {@code parent} keyed by {@code c}, creating it if absent. */
    private static Node getOrCreateChild(Node parent, char c) {
        BST children = parent.getNext();
        if (children == null) {
            // setNext builds a new BST that already contains c.
            parent.setNext(c);
            children = parent.getNext();
        } else if (!children.contain(c)) {
            children.add(c);
        }
        return children.get(c);
    }

    /** Returns the child of {@code parent} keyed by {@code c}, or {@code null} if there is none. */
    private static Node findChild(Node parent, char c) {
        BST children = parent.getNext();
        return children == null ? null : children.get(c);
    }

    /**
     * Appends to {@code res} the string spelled by {@code path} if {@code node}
     * is an end node, then every stored string below {@code node}.
     * {@code path} is restored to its incoming content before returning.
     */
    private static void collect(Node node, StringBuilder path, List<String> res) {
        if (node.getIsEnd()) {
            res.add(path.toString());
        }
        List<Node> children = node.getNodes();
        if (children == null) {
            return;
        }
        for (Node child : children) {
            path.append(child.getKey());
            collect(child, path, res);
            path.setLength(path.length() - 1);
        }
    }
}

  Node.java

import java.util.List;

/**
 * A single trie node: one character, an end-of-string flag, and an optional
 * {@link BST} holding the node's children.
 */
public class Node {
    private char key;
    private boolean isEnd = false;
    private BST next;  // children are stored in a BST; null until the first child is set

    /**
     * @param key  the character held by this node
     * @param next the BST of children, or {@code null} for none
     */
    Node(char key, BST next) {
        this.next = next;
        this.key = key;
    }

    /** Returns the character held by this node. */
    public char getKey() {
        return this.key;
    }

    /** Returns the BST of children, or {@code null} if none has been set. */
    public BST getNext() {
        return this.next;
    }

    /**
     * Returns the child nodes as reported by {@link BST#getAllNodes()},
     * or {@code null} when this node has no child BST.
     */
    public List<Node> getNodes() {
        return (this.next != null) ? this.next.getAllNodes() : null;
    }

    /** Returns whether this node has been marked as the end of a string. */
    public boolean getIsEnd() {
        return isEnd;
    }

    /** Marks this node as the end of a string. */
    public void setEnd() {
        isEnd = true;
    }

    /** Replaces the child BST with a new one created from {@code c}. */
    public void setNext(char c) {
        this.next = new BST(c);
    }
}

  BST.java

import java.util.ArrayList;
import java.util.List;

/**
 * A binary search tree keyed by {@code char}, used as the child index of a
 * trie {@link Node}. Every tree entry owns one trie {@link Node} for its key.
 * Keys are unique: adding a key that is already present is a no-op.
 */
public class BST {
    private BSTNode root;

    /**
     * Creates a tree whose only entry is {@code key}.
     *
     * @param key the key of the initial entry
     */
    BST(char key) {
        root = new BSTNode(key);
    }

    /** One tree entry. Static because it needs no reference to the enclosing tree. */
    private static final class BSTNode {
        private final char key;
        private final Node childNode;
        private BSTNode left;
        private BSTNode right;

        BSTNode(char key) {
            this.key = key;
            this.childNode = new Node(key, null);
        }
    }

    /**
     * Returns the trie nodes of all entries, visiting each entry before its
     * left subtree and the left subtree before the right one.
     *
     * @return a new list of the stored nodes
     */
    public List<Node> getAllNodes() {
        List<Node> res = new ArrayList<>();
        getAllNodesHelper(root, res);
        return res;
    }

    private void getAllNodesHelper(BSTNode curNode, List<Node> res) {
        if (curNode == null) {
            return;
        }
        res.add(curNode.childNode);
        getAllNodesHelper(curNode.left, res);
        getAllNodesHelper(curNode.right, res);
    }

    /**
     * Returns the trie node stored under key {@code c}.
     *
     * @param c the key to look up
     * @return the matching node, or {@code null} if the key is absent
     */
    public Node get(char c) {
        BSTNode found = find(c);
        return found == null ? null : found.childNode;
    }

    /**
     * Inserts key {@code c}. If the key is already present the tree is left
     * unchanged, so an existing entry (and the subtrie hanging off its node)
     * is never shadowed by a duplicate.
     *
     * @param c the key to insert
     */
    public void add(char c) {
        root = addHelper(root, c);
    }

    private BSTNode addHelper(BSTNode curNode, char c) {
        if (curNode == null) {
            return new BSTNode(c);
        }
        if (curNode.key < c) {
            curNode.right = addHelper(curNode.right, c);
        } else if (curNode.key > c) {
            curNode.left = addHelper(curNode.left, c);
        }
        // Equal key: already present, nothing to insert.
        return curNode;
    }

    /**
     * Reports whether key {@code c} is present.
     *
     * @param c the key to look up
     * @return {@code true} if an entry with this key exists
     */
    public boolean contain(char c) {
        return find(c) != null;
    }

    /** Returns the entry with key {@code c}, or {@code null} if there is none. */
    private BSTNode find(char c) {
        BSTNode cur = root;
        while (cur != null && cur.key != c) {
            cur = (cur.key < c) ? cur.right : cur.left;
        }
        return cur;
    }
}

  Test.java

public class Test {
    /**
     * Smoke test for Trie: inserts a handful of words and prints the result
     * of each query. The sample set is small, so the Trie may still contain
     * problems that this run does not reveal.
     */
    public static void main(String[] args) {
        String[] words = {"same", "saple", "apple", "test", "saddddd"};
        Trie trie = new Trie();
        for (String word : words) {
            trie.add(word);
        }
        System.out.println(trie.allString());
        System.out.println(trie.longestPrefixOf("sample"));
        System.out.println(trie.stringsWithPrefix("s"));
    }
}



To be a sailor of the world bound for all ports.
  • 作者:carpe~diem
  • 原文链接:https://blog.csdn.net/weixin_43405649/article/details/124613440
    更新时间:2023-04-25 17:58:52