2007-11-24

nth_element 算法注解

nth_element是STL提供的一个算法,用于找出序列中的第n大元素。这个算法涉及下面4个辅助函数:

  • _Nth_element
  • _Unguarded_partition
  • _Median
  • _Med3

下面是我对STL源代码的注释。

/********************************************************************************************
_Nth_element

使序列中第n大的元素位于第n个位子上,使用 < 比较元素。
  基本思想:
  (1) 找到一个包含n的足够小的区间[a,b),使得[a,b)作为一个大粒度的元素处于序列的有序位置。
  (2) 对[a,b)部分进行排序。
nth_element查找区间的方式体现了二分(折半)查找的思想。它的核心是ungarded_partition算法,ungarded_partition算法能够找出随机序列的近似的位置中间值。
*********************************************************************************************/

template<class _RanIt> inline
void _Nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)
{  // order Nth element, using operator<
  _DEBUG_RANGE(_First, _Last);

  /* 逐步缩小[_First, _Last)的区间,直至尺寸小到可以直接对局部进行排序。
  _ISORT_MAX 是一个阀值,在<algorithm>中被定义为32,表示适于插入法排序的序列的最大长度。*/

  for (; _ISORT_MAX < _Last - _First; )
  {  // divide and conquer, ordering partition containing Nth

    /* 把序列的当前部分划分为有序的三份。*/
    pair<_RanIt, _RanIt> _Mid =
      _Unguarded_partition(_First, _Last);

    /* 确定下一步搜索区间。*/
    if (_Mid.second <= _Nth)
      _First = _Mid.second;
    else if (_Mid.first <= _Nth)
      return;  // Nth inside fat pivot, done
    else
      _Last = _Mid.first;
  }

  /* 对[_First, _Last)排序,第n大元素就到了第n个位置上。*/
  _Insertion_sort(_First, _Last);  // sort any remainder
}

/********************************************************************************************
_Unguarded_partition

把序列划分为有序的三份:[_First, Mid.First), [Mid.First, Mid.Second), [Mid.Second, _Last),其中[Mid.First, Mid.Second)内的元素相等。
注意,不是等分:
(1) 不保证[Mid.First, Mid.Second)在位置上处于序列的中间。
(2) 不保证[Mid.First, Mid.Second)在元素值上处于序列的中间。
(3) [_First, Mid.First)和[Mid.Second, _Last)的长度有可能为0。
(4) [_First, Mid.First)中的所有元素(若有)小于[Mid.First, Mid.Second)中的任一元素。[Mid.Second, _Last)中的则大于。

********************************************************************************************/
template<class _RanIt> inline
pair<_RanIt, _RanIt> _Unguarded_partition(_RanIt _First, _RanIt _Last)
{  // partition [_First, _Last), using operator<
  _RanIt _Mid = _First + (_Last - _First) / 2;  // sort median to _Mid

  /* 选择序列的位置平均值(近似),作为枢轴值(pivot)。*/
  _Median(_First, _Mid, _Last - 1);

  /* [_Pfirst, _Plast) 就是要找的区间。*/
  _RanIt _Pfirst = _Mid;
  _RanIt _Plast = _Pfirst + 1;

  /* 以枢轴值为起点,向序列的两头扫描,把相邻的与枢轴值相等的元素合并到[_Pfirst, _Plast)中。
  宏 _DEBUG_LT(x,y) 定义为 ((x)<(y))。
  值得注意的是,下面的代码用两次 < 比较实现 == 比较。*/

  while (_First < _Pfirst
    && !_DEBUG_LT(*(_Pfirst - 1), *_Pfirst)
    && !(*_Pfirst < *(_Pfirst - 1)))
    --_Pfirst;
  while (_Plast < _Last
    && !_DEBUG_LT(*_Plast, *_Pfirst)
    && !(*_Pfirst < *_Plast))
    ++_Plast;

  /*  分别以[_Pfirst, _Plast)的起、止位置为起点,向序列的两头扫描。*/
  
  _RanIt _Gfirst = _Plast;
  _RanIt _Glast = _Pfirst;

  for (; ; )
  {  // partition
    for (; _Gfirst < _Last; ++_Gfirst)
      if (_DEBUG_LT(*_Pfirst, *_Gfirst))
        ;
      else if (*_Gfirst < *_Pfirst)
        break;
      else
        std::iter_swap(_Plast++, _Gfirst);
    /* 除非 _Gfirst == _Last,否则 *_Gfirst 应该调到[_Pfirst, _Plast)的前面。*/
      
    for (; _First < _Glast; --_Glast)
      if (_DEBUG_LT(*(_Glast - 1), *_Pfirst))
        ;
      else if (*_Pfirst < *(_Glast - 1))
        break;
      else
        std::iter_swap(--_Pfirst, _Glast - 1);
    /* 除非 _Glast == _First,否则 *(_Glast-1) 应该调到[_Pfirst, _Plast)的后面。*/

    if (_Glast == _First && _Gfirst == _Last)
      return (pair<_RanIt, _RanIt>(_Pfirst, _Plast));

    /* 如果 *_Gfirst 和 *(_Glast-1) 都需要调到[_Pfirst, _Plast)的对面,交换它俩就行了。
    如果只有其中一个(记为G)要调整,那就交换它和[_Pfirst, _Plast)——实际上要复杂些,有两种情况:
    (1) G与[_Pfirst, _Plast)相邻。这时只需要交换G和其对面的区间边界。
    (2) G与[_Pfirst, _Plast)之间有其它元素。这时需要三方交换(两次):第一次交换使枢轴区间移动一个单位(通过交换区间的内部边界和对面的外部边界来实现),第二次交换完成调整任务。*/

    if (_Glast == _First)
    {  // no room at bottom, rotate pivot upward
      if (_Plast != _Gfirst)
        std::iter_swap(_Pfirst, _Plast);
      ++_Plast;
      std::iter_swap(_Pfirst++, _Gfirst++);
    }
    else if (_Gfirst == _Last)
    {  // no room at top, rotate pivot downward
      if (--_Glast != --_Pfirst)
        std::iter_swap(_Glast, _Pfirst);
      std::iter_swap(_Pfirst, --_Plast);
    }
    else
      std::iter_swap(_Gfirst++, --_Glast);
  }
}

/*******************************************************************************************
_Median

找序列的近似位置平均值。
********************************************************************************************/

template<class _RanIt> inline
void _Median(_RanIt _First, _RanIt _Mid, _RanIt _Last)
{  // sort median element to middle
  if (40 < _Last - _First)
  {  // median of nine
    size_t _Step = (_Last - _First + 1) / 8;
    _Med3(_First, _First + _Step, _First + 2 * _Step);
    _Med3(_Mid - _Step, _Mid, _Mid + _Step);
    _Med3(_Last - 2 * _Step, _Last - _Step, _Last);
    _Med3(_First + _Step, _Mid, _Last - _Step);
  }
  else
    _Med3(_First, _Mid, _Last);
}

/********************************************************************************************
_Med3

对三个元素排序。
********************************************************************************************/

template<class _RanIt> inline
void _Med3(_RanIt _First, _RanIt _Mid, _RanIt _Last)
{  // sort median of three elements to middle
  if (_DEBUG_LT(*_Mid, *_First))
    std::iter_swap(_Mid, _First);
  if (_DEBUG_LT(*_Last, *_Mid))
    std::iter_swap(_Last, _Mid);
  if (_DEBUG_LT(*_Mid, *_First))
    std::iter_swap(_Mid, _First);
}





    评论

  • 学习了,不错,感谢
  • 学习了,不错,感谢
  • nth_element的直接意图是返回第n大元素的位置,它还有一个副作用,就是能保证nth前面的元素都小于nth,所以,很容易进一步处理得到\"前n大元素排名\"。