首页 > C/C++语言 > C/C++数据结构 > 三路归并排序算法
2006
05-28

#include <stdlib.h>
int mergesortT(int p[], int n);
extern int insertsort(int p[], int n);
static int merge(int work[], int swap[], int l, int m, int n, int flag);
/*
* 归并排序算法在 1938 年由 IBM 发明并在电动整理机上实现。
* 在 1945 年由 J. von Neumann 首次在 EDVAC 计算机上实现。
* 稳定,需要与序列同等大小的辅助空间。这里实现的是三路归并算法。
*/

#define IN    1
#define OUT    0
#define M    9 /* 启始路段长度 */

int mergesortT(int p[], int n)
{
   int op=0;
   int * work=p;
   int * swap;
   int i,j,m;
   int flag=OUT; /* 对换标志 */

   if (n<=16)
       return insertsort(work,n);
   swap=(int*)calloc(n,sizeof(int));
   if (swap==NULL)
       return 0;

   /* i 是经过插入排序的元素个数和未排序元素的开始位置 */
   for(i=0;i+M<=n;i+=M)    
       op+=insertsort(work+i,M);
   if (i<n)
       op+=insertsort(work+i,n-i);
   for(i=M; i<n; i*=3,flag^=1) { /* i 为路段长度 */
       m=i*3; /* m 为路段长度乘以归并的路数 */
       /* j 是已经归并路段的元素个数和未归并路段元素的开始位置 */
       for(j=0;j+m<=n;j+=m)
           op+=merge(work+j,swap+j,i,i<<1,m,flag);
       if (j+(i<<1)<n)
           op+=merge(work+j,swap+j,i,i<<1,n-j,flag);
       else if (j+i<n)
           op+=merge(work+j,swap+j,i,n-j,n-j,flag);
       else if (j<n)
           op+=merge(work+j,swap+j,n-j,n-j,n-j,flag);
   }
   if (flag==IN)
       op+=merge(work,swap,n,n,n,flag);

   free(swap);
   return op;
}
/*
* 三路归并过程。
*/
static int merge(int work[],    /* 工作空间,就是要归并的列表 */
          int swap[],    /* 交换空间,不小于工作空间 */
          int l, /* 前面列表长度和中间列表的开始位置 */
          int m,    /* 前两个列表长度和后面列表的开始位置 */
          int n, /* 三个列表总长度 */
          int flag) /* 换入换出标志 */
{
   int *src, *dest;
   int i=0, j=l, k=m, t=0;

   if (flag==OUT) {
       src=work;
       dest=swap;
   } else { /* flag==IN */
       src=swap;
       dest=work;
   }

   while (i<l && j<m && k<n)
       if (src <= src[j] && src <= src[k])
           dest[t++] = src[i++];
       else if (src[j] <= src[k])
           dest[t++] = src[j++];
       else
           dest[t++] = src[k++];

   while (i<l && j<m)
       if (src <= src[j])
           dest[t++] = src[i++];
       else
           dest[t++] = src[j++];
   while (j<m && k<n)
       if (src[j] <= src[k])
           dest[t++] = src[j++];
       else
           dest[t++] = src[k++];
   while (i<l && k<n)
       if (src <= src[k])
           dest[t++] = src[i++];
       else
           dest[t++] = src[k++];

   while (i<l)
       dest[t++] = src[i++];
   while (j<m)
       dest[t++] = src[j++];
   while (k<n)
       dest[t++] = src[k++];

   return n;
}


留下一个回复