刷题_kmp算法字符串匹配(28.实现 strStr() )

/ 默认分类 / 0 条评论 / 960浏览

KMP算法学习笔记

1. kmp算法是干什么的?

KMP算法是一种改进的字符串匹配算法,由D.E.Knuth,J.H.Morris和V.R.Pratt提出的,因此人们称它为克努特—莫里斯—普拉特操作(简称KMP算法)。KMP算法的核心是利用匹配失败后的信息,尽量减少模式串与主串的匹配次数以达到快速匹配的目的。具体实现就是通过一个next()函数实现,函数本身包含了模式串的局部匹配信息。KMP算法的时间复杂度O(m+n)

2. 什么是字符串模式匹配

给定两个串S=“s1s2s3 …sn”和T=“t1t2t3 …tn”,在主串S中寻找子串T的过程叫做模式匹配,T称为模式。 简单点解释:

字符串模式匹配,也称子串的定位操作,通俗的说就是在一个主串中判断是否存在给定的子串(又称模式串),若存在,则返回匹配成功的索引。如:
主串:hizuohui
子串(模式):zuo
主串中包含子串"zuo",说明匹配成功,且返回的索引为:2

3. 最简单,最容易想到的解决模式匹配的方法-暴力破解

暴力破解即BF算法,是一种穷举的算法

    public int strStrBF(String haystack, String needle) {

        char[] origins = haystack.toCharArray();  //m
        char[] currents = needle.toCharArray();  //n

        int flag = 1;
        //总共需要比较的轮次数
        for (int i = 0; i < origins.length-currents.length+1; i++) {
            flag = 1;
            //每轮里面最多都只需要比较模式串的长度次
            for (int j = 0; j < currents.length; j++) {
                if(currents[j] != origins[j+i]){
                    flag = 0;
                    break;
                }
            }
            if(flag == 1) return i;
        }
        return -1;
    }

上面使用的是for循环实现,如果你在百度上查找BF字符串匹配,可能最多的会看到使用while来实现的,如下:

 public int strStrBF1(String haystack, String needle) {
        char[] origins = haystack.toCharArray();  //m
        char[] currents = needle.toCharArray();  //n

        //开始时指针都在起始位置
        int oi =0, ci =0;
        //只有当父串和模式串的指针所指都不是空的时候才会继续比较或回溯
        while (oi <= origins.length - 1 && ci <= currents.length - 1){
            if(origins[oi] == currents[ci]){
                oi++;
                ci++;
            }else {
                //回溯后加1,表示比原来进1位了,因为oi和ci都是同时进1指针的,所以相减后直接回到原位
                oi = oi - ci + 1;
                ci = 0;
            }
        }

        if(ci >= currents.length){
            return oi - ci;
        }else {
            return -1;
        }
    }


符串 abaabcacbc
模式 abaabm

其实KMP不是一种解决模式匹配的新算法,它只是对BF算法的超级优化,消除了BF算法中的冗余比较判断,可以理解为BF算法的跳级版,直接跳过了很多 比较阶段


Q1: KMP为什么需要使用到模式串的最长公共前后缀?

最长公共前后缀即为:字符串所有前缀的集合F和所有后缀的集合B的交集的最长字符串

子串 | 前缀 | 后缀 | 交集最长
---|---|---|--- a | 无 | 无 | 0 ab | a | b | 0 aba | a.ab | a,ba | 1 abaa | a,ab,aba | a,aa,baa | 1 abaab | a,ab,aba,abaa | b,ab,aab,baab | 2

按照上述关系我们可以得到本次模式匹配的部分匹配表(PMT):

0 | 1 | 2 | 3 | 4 | 5 ---|---|---|---|---|--- a | b | a | a | b | m -1|0|0|1|1|2

该模式串的PMT的第一行是模式串index,第二行是模式串,第三行是模式串每个位置的PMT中的值

ps:PMT中的对应的值数据是相应位置串的前一位的最长公共子串长度,所以可以确定的是每个长度大于等于2的模式串的前两个pmt的值一定是-1 , 0

当指针的位置到上图的位置时,如果使用的时BF算法,则i会直接回溯到1,


Q2: KMP相比于BF的时间复杂度

如果父串的长度为m,模式串的长度为n
KMP的时间复杂度为O(m+n) 而BF的时间复杂度为O(m*n)

    /**
     * 假设p长度大于等于2
     */
    public void getNextArr(String p, int[] next) {
        char[] pChars = p.toCharArray();
        next[0] = -1;
        next[1] = 0;
        for (int j = 2; j < pChars.length; j++) {

            int k = next[j - 1];
            while (k != -1 && pChars[j - 1] != pChars[k]) {
                k = next[k];
            }
            if (k == -1) {
                next[j] = 0;
            } else {
                next[j] = k + 1;
            }
        }
        System.out.println(Arrays.asList(next));
    }


    @Test
    public void test891089(){
        getNextArr("abcdabd",new int[7]);
    }
    @Test
    public void test89089(){
        int a = 1;
        int b;
        int c=1;
        if(a == 0 && ((c = 2) == 2)){
            System.out.println(1212);
        }
        System.out.println(c);
    }




1 1

leetcode运行
kmp

class Solution {
    public int strStr(String haystack, String needle) {
        
        if(haystack.length() == 0 && needle.length() == 0) return 0;
        if(haystack.length() == 0 && !(needle.length() == 0)) return -1;
        if(!(haystack.length() == 0) && needle.length() == 0) return 0;
        char[] superArr = haystack.toCharArray();
        char[] patternArr = needle.toCharArray();
        int[] next = getNextArr(needle);

        int i = 0,j = 0;
        while(i < superArr.length && j < patternArr.length){
            if(j == -1 || superArr[i] == patternArr[j]){
                i++;
                j++;
            }else{
                j = next[j];
            }
        }

        if(j >= patternArr.length){
            return i - j;
        }else{
            return -1;
        }
    }

    public int[] getNextArr(String patternStr){
        char[] patternArr = patternStr.toCharArray();
        int[] next = new int[patternStr.length()];
        int j = 0;
        int k = -1;
        next[j] = k;

        //开始循环计算每一个j处的k(注意这里是计算每次j+1处的k值)
        while (j < patternStr.length()-1){
            if(k == -1 || patternArr[j] == patternArr[k]){
                k++;
                j++;
                next[j] = k;
            }else{
                k = next[k];
            }
        }
        return next;
    }
}


或者
    public void getNextArr(String p, int[] next) {
        char[] pChars = p.toCharArray();
        next[0] = -1;
        next[1] = 0;
        for (int j = 2; j < pChars.length; j++) {

            int k = next[j - 1];
            while (k != -1 && pChars[j - 1] != pChars[k]) {
                k = next[k];
            }
            if (k == -1) {
                next[j] = 0;
            } else {
                next[j] = k + 1;
            }
        }
        System.out.println(Arrays.asList(next));
    }

bf

class Solution {
    public int strStr(String haystack, String needle) {
        char[] superArr = haystack.toCharArray();
        char[] patternArr = needle.toCharArray();
        int i = 0,j = 0;
        while (i < superArr.length && j < patternArr.length){
            //匹配上了
            if(superArr[i] == patternArr[j]){
                i++;
                j++;
            }
            //失配
            else{
                i = i - j + 1;
                j = 0;
            }
        }

        //匹配到出循环
        if(j >= patternArr.length){
            return i - j;
        }
        //最后仍未匹配上出循环
        else{
            return -1;
        }
        }
}

"""
获取next数组
"""


def get_next_arr(pattern_str):
    j, k, arrLen = 1, 0, len(pattern_str)
    print(j, arrLen)
    next_arr = [0] * arrLen
    # [0,0,0,0,,,,,]
    next_arr[0] = -1
    next_arr[1] = 0

    while j < arrLen - 1:
        if k == -1 or pattern_str[j] == pattern_str[k]:
            k += 1
            j += 1
            next_arr[j] = k
        else:
            k = next_arr[k]
    return next_arr


def kmp(super_str, pattern_str):
    if (super_str is "") & (pattern_str is ""):
        return 0
    if (super_str is "") & (super_str is not ""):
        return -1
    if (super_str is not "") & (pattern_str is ""):
        return 0

    i, j = 0, 0
    next_arr = get_next_arr(pattern_str)
    super_str_len, pattern_str_len = len(super_str), len(pattern_str)
    while i < super_str_len and j < pattern_str_len:
        if (j == -1) | (super_str[i] == pattern_str[j]):
            i += 1
            j += 1
        else:
            j = next_arr[j]
    return i-j if j >= pattern_str_len else -1



if __name__ == '__main__':
    print(kmp("","ll"))