分治算法的基本思想是: 分(divide):递归求解子问题,即:分解+求解,将问题分解为k个方便求解的小问题。 为什么说是递归求解呢,这里可以看作将一个问题分2个子问题,如果2个子问题还是大,再继续分成4个子问题,直到分解到能方便求解的小问题。也就是说分治算法是含有2个以上的递归运算,只有一个递归的例程不能算做分治算法。 治(conquer):从子问题构建原问题的解。 对于分治,最长用到的复杂度分析情况为: T(N) = aT(N/b) + O(N^k) 当a=b^k时, T(N)=O(N^k*logN) 比如,非常常见的二分法: T(N) = 2T(N/2) + O(N) ,此时a=2, b=2, k=1 即a=b^k,所以 该算法的复杂度为 O(NlogN) 从另一个角度,反向分析问题,如果我们希望得到一个O(NlogN)的算法,那么就需要保证附加工作为O(N),这是一个非常非常关键的利用分治算法解决问题的入手点!!! 而且这个复杂度也是大多数分治算法问题的情况。 当然还有另外两种>, < 情况,就参考书上的详细讲解吧。 下面以一个例子详细介绍如何应用分治算法。 最近点对问题:给定平面上的N个点,找出距离最近的两个点。 对于该问题,算法过程并不算复杂,但要想编程实现,需要克服不少细节问题。 首先应该实现Point类: Point.h: #include <iostream> class Point { public: Point(double x, double y); double getx() const { return m_x; } double gety() const { return m_y; } friend std::ostream& operator<<(std::ostream& os, const Point& p); private: double m_x; double m_y; };
Point.cpp: #include "Point.h"
Point::Point(double x, double y) : m_x(x), m_y(y) { }
std::ostream& operator<<(std::ostream& os, const Point& p) { os<<"P("<<p.getx()<<","<<p.gety()<<")"; return os; }
这里有几点要注意:重载<<,友元函数的应用,初始化的位置。不多介绍了,代码很简单,上点心看看就好。 接下来,是解决该问题的具体过程: 1. 最简单的解法就是蛮力法:把每两个点的距离求出来,然后找出最小值即可。 虽然,该算法很简单,但编程实现时,从该笨方法起步会有一个很好的过渡,不至于编码难度过陡。 于是我们需要:计算两点距离的函数Distance, 求最近点对函数FindShortPair, 一组点Point p[]用以测试, 打印点函数PrintPoints用于观察。 代码如下: 上面的代码完成了很多该问题算法的外围工作,最重要的是提供了测试环境。注意const的应用,double,int型的定义。 #include "Point.h" #include <cmath>
double Distance(const Point& s, const Point& t) { double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() ); double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() ); return sqrt( squarex + squarey ); }
void FindShortPair(const Point * p, int num) {//can only find one of the shortest path double distance=Distance(p[0], p[1]); int start = 0; int end = 1; for (int i=0; i<num; i++) { for (int j=i+1; j<num; j++) { if ( Distance(p[i], p[j]) < distance ) { distance = Distance(p[i], p[j]); start = i; end = j; } } } std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl; std::cout << "the distance is: " << distance<<std::endl; }
void PrintPoints(const Point * p, int num) { for (int i=0; i<num; i++) { std::cout << p[i] << " "; } std::cout << std::endl; }
int main(int argc, const char** argv) { Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)}; int size = sizeof(p)/sizeof(p[0]); PrintPoints(p, size); FindShortPair(p, size);
return 0; }
2. 蛮力算法复杂度很明显为O(N^2)不理想。如果是O(N*logN)就好多了,下面介绍分治算法在解决该问题的具体应用过程。 假设平面上的点按x排序好了,这样最多增加O(N*logN),这再整个算法来看并没有增加复杂度级别。 排好序后,可以划一条垂线,把点集分成两半:PL和PR。于是最近点对或者在PL中,或者在PR中,或者PL,PR各有一点。 把三种距离情况定义为dL, dR, dC. 其中dL, dR可以递归求解,于是问题就变为计算dC。 根据上面红色字解释,由于我们希望得到O(N*logN)的解,因此必须能够仅仅多花O(N)的附加工作计算dC。 另s=min(dL, dR). 通过观察能得出结论:如果dC<s,即dC对s有所改进,则只需计算dC。如果dC满足这样的条件,则决定dC的两点必然在分割线的s距离之内,称之为带(strip) 否则不可能满足dC<s, 于是缩小了需要考虑的点的范围。 如果是均匀分布的点集,则能证明出在该带中平均只有O(sqrt(N))个点,(注:书上这么写的,我也不会证,先记下这个理论吧)。因此,对这些点运用蛮力法可以在O(N)时间内完成。 于是过程为: double FindShortPairDC(const Point* p, int num) //DC代表divide and conquer,分治 { if (num <= 3) //也许您认为,递归到2个点时,才应该返回距离。但如果为3个点,可能会出现PL有2个点,PR有1个点的情况,这时dR会无法计算,所以3个点就要蛮力计算返回。 return EnumShortestPair(p, num); mid = (num+1)/2; dL = FindShortPairDC(p, mid); dR = FindShortPairDC(p+mid, num-mid); s=min(dL, dR) for (i=0; i<stripPointNum; i++) for (j=i+1; j<stripPointNum; j++) if (dist(pi, pj) < s) s = dist(pi, pj); return s; } 代码实现:注意其中STL中Sort算法的应用方法。 #include "Point.h" #include <cmath> #include <algorithm>
double Distance(const Point& s, const Point& t) { double squarex = ( s.getx() - t.getx() ) * ( s.getx() - t.getx() ); double squarey = ( s.gety() - t.gety() ) * ( s.gety() - t.gety() ); return sqrt( squarex + squarey ); }
bool ComparePoint(const Point& p1, const Point& p2) { return (p1.getx() < p2.getx()); }
double EnumShortestPair(const Point * p, int num) {//can only find one of the shortest path double distance=Distance(p[0], p[1]); int start = 0; int end = 1; for (int i=0; i<num; i++) { for (int j=i+1; j<num; j++) { if ( Distance(p[i], p[j]) < distance ) { distance = Distance(p[i], p[j]); start = i; end = j; } } } return distance; }
double FindShortPairDC(const Point * p, int num) {//use divide and conquer algorithm to find the shortest path double dL, dR, d, midXVal; if (num < 2) { std::cout << "Need to input more than 2 points!"<< std::endl; exit(1); } if (num < 4) { return EnumShortestPair(p, num); }
int mid = 0; mid = (num+1)/2; dL = FindShortPairDC(p, mid); dR = FindShortPairDC(p+mid, (num-mid)); d = dL < dR ? dL : dR;
midXVal = p[mid].getx();
int stripStart = 0; int stripEnd = num-1; for (int i=0; i<num-1; i++) { if ( (p[i].getx() < midXVal-d) && (p[i+1].getx() >= midXVal-d) ) stripStart = i+1; if ( (p[i].getx() <= midXVal+d) && (p[i+1].getx() > midXVal+d) ) stripEnd = i; } int start = 0; int end = 1; for (int i=stripStart; i<stripEnd; i++) { for (int j=i+1; j<stripEnd; j++) { if ( Distance(p[i], p[j]) < d ) { d = Distance(p[i], p[j]); start = i; end = j; } } } if (start!=0 || end!=0) std::cout << "The shortest pair is: P" << start+1 << ", P" << end+1 <<""<<std::endl;
std::cout << "the distance is: " << d <<std::endl; return d; }
void PrintPoints(const Point * p, int num) { for (int i=0; i<num; i++) { std::cout << p[i] << " "; } std::cout << std::endl; }
int main(int argc, const char** argv) { Point p[] = {Point(2,3), Point(4,3), Point(4,6), Point(5,7), Point(4,3)}; int size = sizeof(p)/sizeof(p[0]); PrintPoints(p, size); std::sort(p, p+size, ComparePoint); PrintPoints(p, size); FindShortPairDC(p, size);
return 0; } 这里需要解释一下,对于那条垂线的选取,代码并没有按照x坐标取中值,而是取点集的中间位置点 mid 表示PL点集的个数(包括垂线上点),(num-mid)表示PR点集的个数(可能包括垂线上点)。midXVal 为垂线对应的x值。stripStart -- stripEnd为在带中的点集范围,即p[stripStart]到p[stripEnd] 3. 2中的解法最坏情况复杂度仍会上升至O(N*logN), 为了得到O(N*logN)解法,我们仍然需要进行优化。通过进一步观察,我们发现,在带中的点,若进行按y坐标排序后,如果两个点y坐标相差s,则一定 不是最短点对,所以只需求y相差不大于s的点对距离即可。这样得到优化的函数: double FindShortPairDC(const Point* p, int num) { if (num <= 3) return EnumShortestPair(p, num); mid = (num+1)/2; dL = FindShortPairDC(p, mid); dR = FindShortPairDC(p+mid, num-mid); s=min(dL, dR) for (i=0; i<stripPointNum; i++) for (j=i+1; j<stripPointNum; j++) { if (pj.y - pi.y > s) break; if (dist(pi, pj) < s) s = dist(pi, pj); } return s; } 代码实现:需要修改两处。 比较函数,增加y信息比较 bool ComparePoint(const Point& p1, const Point& p2) { if( fabs(p1.getx() - p2.getx()) < 0.0000000001) return (p1.gety() < p2.gety()); return (p1.getx() < p2.getx()); }
求最短距离点对,内部循环信息 for (int i=stripStart; i<stripEnd; i++) { for (int j=i+1; j<stripEnd; j++) { if (p[j].gety()-p[i].gety() > d) break; if ( Distance(p[i], p[j]) < d ) { d = Distance(p[i], p[j]); start = i; end = j; } } } 分析下该解法的复杂度: 对于在带内的构成dC两个点pi,pj,这两点一定在一个s*2s的矩形内。否则y距离差便>s ______ |__|__| s s s 在左右两个s*s方形区域内,最多有4个点,如果再多,则必然有两个点距离<s, 这与s=min(dL, dR)矛盾。 所以在这样一个矩形内,最多存在8个点,亦即对于一个点pi,最多计算7个点与其距离。所以上面的内层循环 for (int j=i+1; j<stripEnd; j++) 最多执行7次,即该内层循环复杂度为O(1), 所以上面双层循环为O(N),可以用O(N)完成带区域内的点集最近点对查找。这样便满足了整体算法复杂度为O(N*logN)