找最近点对问题-分治算法的应用

    技术2022-05-19  25

        分治算法的基本思想是:     分(divide):递归求解子问题,即:分解+求解,将问题分解为k个方便求解的小问题。     为什么说是递归求解呢,这里可以看作将一个问题分2个子问题,如果2个子问题还是大,再继续分成4个子问题,直到分解到能方便求解的小问题。也就是说分治算法是含有2个以上的递归运算,只有一个递归的例程不能算做分治算法。     治(conquer):从子问题构建原问题的解。     对于分治,最长用到的复杂度分析情况为:     T(N) = aT(N/b) + O(N^k)     当a=b^k时, T(N)=O(N^k*logN)     比如,非常常见的二分法: T(N) = 2T(N/2) + O(N) ,此时a=2, b=2, k=1 即a=b^k,所以 该算法的复杂度为 O(NlogN)     从另一个角度,反向分析问题,如果我们希望得到一个O(NlogN)的算法,那么就需要保证附加工作为O(N),这是一个非常非常关键的利用分治算法解决问题的入手点!!!     而且这个复杂度也是大多数分治算法问题的情况。 当然还有另外两种>, < 情况,就参考书上的详细讲解吧。         下面以一个例子详细介绍如何应用分治算法。     最近点对问题:给定平面上的N个点,找出距离最近的两个点。     对于该问题,算法过程并不算复杂,但要想编程实现,需要克服不少细节问题。     首先应该实现Point类:      Point.h: #include <iostream> class Point { public:     Point(double x, double y);     double getx() const     {         return m_x;     }     double gety() const     {         return m_y;     }     friend std::ostream& operator<<(std::ostream& os, const Point& p);     private:     double m_x;     double m_y; };

        Point.cpp: #include "Point.h"

    Point::Point(double x, double y) : m_x(x), m_y(y) { }

    std::ostream& operator<<(std::ostream& os, const Point& p) {     os<<"P("<<p.getx()<<","<<p.gety()<<")";     return os; }

          这里有几点要注意:重载<<,友元函数的应用,初始化的位置。不多介绍了,代码很简单,上点心看看就好。     接下来,是解决该问题的具体过程:     1. 最简单的解法就是蛮力法:把每两个点的距离求出来,然后找出最小值即可。        虽然,该算法很简单,但编程实现时,从该笨方法起步会有一个很好的过渡,不至于编码难度过陡。        于是我们需要:计算两点距离的函数Distance, 求最近点对函数FindShortPair, 一组点Point p[]用以测试, 打印点函数PrintPoints用于观察。        代码如下:      上面的代码完成了很多该问题算法的外围工作,最重要的是提供了测试环境。注意const的应用,double,int型的定义。 #include "Point.h" #include <cmath>

    double Distance(const Point& s, const Point& t) {     double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() );     double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() );     return sqrt( squarex + squarey ); }

    void FindShortPair(const Point * p, int num) {//can only find one of the shortest path     double distance=Distance(p[0], p[1]);     int start = 0;     int end = 1;     for (int i=0; i<num; i++)     {         for (int j=i+1; j<num; j++)         {             if ( Distance(p[i], p[j]) < distance )             {                 distance = Distance(p[i], p[j]);                 start = i;                 end = j;             }         }     }     std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl;     std::cout << "the distance is: " << distance<<std::endl; }

    void PrintPoints(const Point * p, int num) {     for (int i=0; i<num; i++)     {         std::cout << p[i] << " ";     }     std::cout << std::endl; }

    int main(int argc, const char** argv) {     Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)};     int size = sizeof(p)/sizeof(p[0]);     PrintPoints(p, size);     FindShortPair(p, size);

        return 0; }

        2. 蛮力算法复杂度很明显为O(N^2)不理想。如果是O(N*logN)就好多了,下面介绍分治算法在解决该问题的具体应用过程。        假设平面上的点按x排序好了,这样最多增加O(N*logN),这再整个算法来看并没有增加复杂度级别。        排好序后,可以划一条垂线,把点集分成两半:PL和PR。于是最近点对或者在PL中,或者在PR中,或者PL,PR各有一点。        把三种距离情况定义为dL, dR, dC.                       其中dL, dR可以递归求解,于是问题就变为计算dC。 根据上面红色字解释,由于我们希望得到O(N*logN)的解,因此必须能够仅仅多花O(N)的附加工作计算dC。        另s=min(dL, dR). 通过观察能得出结论:如果dC<s,即dC对s有所改进,则只需计算dC。如果dC满足这样的条件,则决定dC的两点必然在分割线的s距离之内,称之为带(strip)        否则不可能满足dC<s, 于是缩小了需要考虑的点的范围。                      如果是均匀分布的点集,则能证明出在该带中平均只有O(sqrt(N))个点,(注:书上这么写的,我也不会证,先记下这个理论吧)。因此,对这些点运用蛮力法可以在O(N)时间内完成。        于是过程为:        double FindShortPairDC(const Point* p, int num)   //DC代表divide and conquer,分治        {            if (num <= 3) //也许您认为,递归到2个点时,才应该返回距离。但如果为3个点,可能会出现PL有2个点,PR有1个点的情况,这时dR会无法计算,所以3个点就要蛮力计算返回。                 return EnumShortestPair(p, num);            mid = (num+1)/2;            dL = FindShortPairDC(p, mid);            dR = FindShortPairDC(p+mid, num-mid);            s=min(dL, dR)             for (i=0; i<stripPointNum; i++)               for (j=i+1; j<stripPointNum; j++)                  if (dist(pi, pj) < s)                      s = dist(pi, pj);            return s;         }        代码实现:注意其中STL中Sort算法的应用方法。 #include "Point.h" #include <cmath> #include <algorithm>

    double Distance(const Point& s, const Point& t) {     double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() );     double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() );     return sqrt( squarex + squarey ); }

    bool ComparePoint(const Point& p1, const Point& p2) {     return (p1.getx() < p2.getx()); }

    double EnumShortestPair(const Point * p, int num) {//can only find one of the shortest path     double distance=Distance(p[0], p[1]);     int start = 0;     int end = 1;     for (int i=0; i<num; i++)     {         for (int j=i+1; j<num; j++)         {             if ( Distance(p[i], p[j]) < distance )             {                 distance = Distance(p[i], p[j]);                 start = i;                 end = j;             }         }     }     return distance; }

    double FindShortPairDC(const Point * p, int num) {//use divide and conquer algorithm to find the shortest path     double dL, dR, d, midXVal;     if (num < 2)     {         std::cout << "Need to input more than 2 points!"<< std::endl;         exit(1);     }     if (num < 4)     {         return EnumShortestPair(p, num);     }

        int mid = 0;     mid = (num+1)/2;     dL = FindShortPairDC(p, mid);     dR = FindShortPairDC(p+mid, (num-mid));     d = dL < dR ? dL : dR;

        midXVal = p[mid].getx();

        int stripStart = 0;     int stripEnd = num-1;     for (int i=0; i<num-1; i++)     {         if ( (p[i].getx() < midXVal-d) && (p[i+1].getx() >= midXVal-d) )             stripStart = i+1;         if ( (p[i].getx() <= midXVal+d) && (p[i+1].getx() > midXVal+d) )             stripEnd = i;     }        int start = 0;     int end = 1;     for (int i=stripStart; i<stripEnd; i++)     {         for (int j=i+1; j<stripEnd; j++)         {             if ( Distance(p[i], p[j]) < d )             {                 d = Distance(p[i], p[j]);                 start = i;                 end = j;             }         }     }     if (start!=0 || end!=0)         std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl;

        std::cout << "the distance is: " << d <<std::endl;     return d; }

    void PrintPoints(const Point * p, int num) {     for (int i=0; i<num; i++)     {         std::cout << p[i] << " ";     }     std::cout << std::endl; }

    int main(int argc, const char** argv) {     Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)};     int size = sizeof(p)/sizeof(p[0]);     PrintPoints(p, size);     std::sort(p, p+size, ComparePoint);     PrintPoints(p, size);     FindShortPairDC(p, size);

        return 0; } 这里需要解释一下,对于那条垂线的选取,代码并没有按照x坐标取中值,而是取点集的中间位置点 mid 表示PL点集的个数(包括垂线上点),(num-mid)表示PR点集的个数(可能包括垂线上点)。midXVal 为垂线对应的x值。stripStart -- stripEnd为在带中的点集范围,即p[stripStart]到p[stripEnd]      3. 2中的解法最坏情况复杂度仍会上升至O(N*logN), 为了得到O(N*logN)解法,我们仍然需要进行优化。通过进一步观察,我们发现,在带中的点,若进行按y坐标排序后,如果两个点y坐标相差s,则一定 不是最短点对,所以只需求y相差不大于s的点对距离即可。这样得到优化的函数:        double FindShortPairDC(const Point* p, int num)        {            if (num <= 3)                 return EnumShortestPair(p, num);            mid = (num+1)/2;            dL = FindShortPairDC(p, mid);            dR = FindShortPairDC(p+mid, num-mid);            s=min(dL, dR)             for (i=0; i<stripPointNum; i++)               for (j=i+1; j<stripPointNum; j++)               {                  if (pj.y - pi.y > s)                      break;                  if (dist(pi, pj) < s)                      s = dist(pi, pj);                }            return s;         } 代码实现:需要修改两处。 比较函数,增加y信息比较 bool ComparePoint(const Point& p1, const Point& p2) {     if( fabs(p1.getx() - p2.getx()) < 0.0000000001)         return (p1.gety() < p2.gety());     return (p1.getx() < p2.getx()); }

    求最短距离点对,内部循环信息 for (int i=stripStart; i<stripEnd; i++) {     for (int j=i+1; j<stripEnd; j++)     {         if (p[j].gety()-p[i].gety() > d)             break;         if ( Distance(p[i], p[j]) < d )         {             d = Distance(p[i], p[j]);             start = i;             end = j;         }     } } 分析下该解法的复杂度: 对于在带内的构成dC两个点pi,pj,这两点一定在一个s*2s的矩形内。否则y距离差便>s ______ |__|__| s   s    s 在左右两个s*s方形区域内,最多有4个点,如果再多,则必然有两个点距离<s, 这与s=min(dL, dR)矛盾。 所以在这样一个矩形内,最多存在8个点,亦即对于一个点pi,最多计算7个点与其距离。所以上面的内层循环 for (int j=i+1; j<stripEnd; j++) 最多执行7次,即该内层循环复杂度为O(1), 所以上面双层循环为O(N),可以用O(N)完成带区域内的点集最近点对查找。这样便满足了整体算法复杂度为O(N*logN)

     

     


    最新回复(0)