社内コードコンペ - お題:最速なCIDRブロックマッチ判定 〜 安井の場合: バイナリサーチのあれとこれ〜
おさらい
- #1 ひろせの場合 - IP::CountryとAPRを使ってみた
- #2 安井の場合: バイナリサーチのあれとこれ ←今回
- #3
- #4
前回に引き続きコードコンペのお話で、今回は安井の出番です。 前回は導入だったせいもあり、あまり速いコードはできてきませんでしたが、さてさて今回はどうでしょうか。
このコードのウリ
「IPアドレスのマッチング」とは、言い替えれば「32ビット値の検索」です。そこで私の中で真っ先に頭をよぎったのは、バイナリサーチ(二分探索)という手法でした。今回のネタはデータ数が300個弱ということなので、 IP::Country のようにビット単位で二分木検索するよりも、ソート済みのリストからバイナリサーチしたほうが計算量は少なくて済むだろうと考えました。
コード解説
泥臭い処理(ファイルからCIDRブロックを読み込んでソート済みの配列を生成する部分)は末尾に掲載します。
ここでは、ソート済みのCIDRブロックリストから、IPアドレスを検索する関数を紹介します。(とはいってもある種お決まりのコードですけどね(^^;
cide_lookup()
char *cidr_lookup(void *ipaddr, int len)
{
int half;
int cmin = 0;
int cmax = listcount;
uint32_t addr;
if(len != 4)
return(NULL);
addr = ntohl(*(uint32_t *)ipaddr);
/* バイナリサーチで範囲を絞りこむ */
while(cmax - cmin > 7){
half = (cmin + cmax)>>1;
if(addr < cidrlist[half]){
cmax = half;
}else{
cmin = half;
}
}
/* 絞りこんだ結果((最大7個)の中からマッチするものを探す */
while(cmin<(cmax--)){
if((addr & masklist[cmax]) == cidrlist[cmax]){
return(typelist[cmax]);
}
}
return(NULL);
}
通常のバイナリサーチの実装では、検索終了条件を以下のようにすると思います。
- 目的の値が見付かったとき
- (cmin == cmax)になっても目的の値が見付からなかったとき
- 配列に格納されているのはCIDRブロックのネットワークアドレスなので
- 検索対象のIPアドレスとの同値データは存在しない
- ネットマスクとのANDをとらなければマッチしているか判定できない
- ネットワークアドレスが同じでネットマスク長が異なるケースが考えられる
- 192.168.0.0/16
- 192.168.0.0/17
- 192.168.0.0/18
- このような場合、192.168.0.1は(3)にマッチしなければいけない
アセンブラでも書いてみた
もう、だいぶ昔の話になりますが、アセンブラ(6502,Z80,68000)で遊んでいた時期がありました。ちょうどそのころ、バイナリサーチ、バブルソート、クイックソートなどの「アルゴリズム」と呼ばれるものにはじめて遭遇し、「これはすごい!」と純粋に感動していたことを覚えています。
その記憶が甦ったのか、なにを血迷ったのかわかりませんが、なぜかふと、「cidr_lookupをアセンブラで書き直せばもっと速くなるんでね?」と思い、インラインアセンブラで書き直してみたのがこちらのコードです。処理の内容は上記のものとまったく同じです。
そして、それぞれでベンチマークをとってみたところ、このような結果になりました。
※gccの最適化オプションは-O3を指定しました。
結果をみてびっくりしました。Cで書いた方が圧倒的に速かったのです。 アセンブラで書いたコードは、そのままでは64bit環境で動かないですし、メンテナンス性もわるいですし、コーディングにも時間がかかります。それでも高速に動作すれば使いどころはあるかなあと思っていましたが、今回は少し残念な結果に終わってしまいました。コンパイラの最適化ってすごいです!さすがです。
おわりに
今回の件で、単純にアセンブラで書き直しただけでは速くならないことを体感しました。しかし、このアセンブラのコードも、まだまだ高速化できる余地がいっぱい残っています。(というか、ツッコミどころ満載かもしれません(^^;;
どれだけコンパイラに近付くことができるかわかりませんが、もっと速くなるようにもう少しいじってみたいと思います。「この辺をこうしたら速くなるぞ」みたいなアドバイスをいただけると、大変ありがたいです。
ベンチマーク
5つのとあるIPアドレスのそれぞれについてどのグループに属するか判定する、という処理を5,000,000回実行したときの総所要時間(elapsed)と1回の判定に要した時間の平均(average)です。
ベンチマークは、同じハードウエアのx86とx86_64の環境の2つでとりました。
x86
name elapsed[sec] average[usec] ======================================== apr 59.379924 2.375197 ip-country 3.739187 0.149567 yasui-a 2.727045 0.109082 # インラインアセンブラ yasui-c 0.975544 0.039022 # C言語
x86_64
name elapsed[sec] average[usec] ======================================== apr 52.340651 2.093626 ip-country 0.664034 0.026561 yasui-c 0.706095 0.028244
コード
uint32_t half=0;
uint32_t csub=0;
uint32_t cmin=0;
uint32_t cmax=0;
uint32_t addr=0;
char *cidr_lookup(void *ipaddr, int len)
{
if(len != 4)
return(0);
addr = ntohl(*(uint32_t *)ipaddr);
asm volatile(
"lookup_start:"
"push %eax;"
"push %ebx;"
"push %ecx;"
"push %edx;"
"xor %eax, %eax;"
"movl %eax, cmin;"
"movl listcount, %eax;"
"movl %eax, cmax;"
"movl %eax, csub;"
"lookup_loop:"
"shr %eax;"
"add cmin, %eax;"
"movl %eax, half;"
"movl csub, %eax;"
"cmp $7, %eax;"
"jb lookup_search;"
"movl half, %eax;"
"shl $2, %eax;"
"movl cidrlist, %edx;"
"add %eax, %edx;"
"movl addr, %eax;"
"cmp (%edx), %eax;"
"jc lookup_low;"
"lookup_high:" /* addr > cidr */
"movl half, %eax;"
"movl %eax, cmin;"
"jmp lookup_endif;"
"lookup_low:" /* addr < cidr */
"movl half, %eax;"
"movl %eax, cmax;"
"lookup_endif:"
"movl cmax, %eax;"
"sub cmin, %eax;"
"movl %eax, csub;"
"jmp lookup_loop;"
"lookup_search:"
"movl cmax, %ecx;"
"sub cmin, %ecx;"
"movl cmax, %eax;"
"dec %eax;"
"movl %eax, half;"
"shl $2, %eax;"
"movl masklist, %ebx;"
"movl cidrlist, %edx;"
"add %eax, %ebx;"
"add %eax, %edx;"
"lookup_search_loop:"
"movl (%ebx), %eax;"
"and addr, %eax;"
"cmp (%edx), %eax;"
"je lookup_end;" /* if(addr == cidr) */
"sub $4, %ebx;"
"sub $4, %edx;"
"movl half, %eax;"
"dec %eax;"
"mov %eax, half;"
"loop lookup_search_loop;"
"movl $0xffffffff, %eax;"
"movl %eax, half;"
"lookup_end:"
"pop %edx;"
"pop %ecx;"
"pop %ebx;"
"pop %eax;"
);
if(half == -1)
return(NULL);
return(typelist[half]);
}
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <dirent.h>
#include <string.h>
#include <fcntl.h>
#include <libgen.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#define DBDIR "../data"
uint32_t listsize = 0;
uint32_t listcount = 0;
uint32_t *cidrlist = NULL;
uint32_t *masklist = NULL;
char **typelist = NULL;
int getnetmask(int n, uint32_t *netmask)
{
uint32_t h = 1;
uint32_t m = 0xffffffff;
if(n>32)
return(0);
n = 32 - n;
while(n--){
m -= h;
h *= 2;
}
*netmask = m;
return(1);
}
int initsort()
{
int i;
int j;
uint32_t t;
char *tt;
for(i=0;i cidrlist[j]){
t = cidrlist[i];
cidrlist[i] = cidrlist[j];
cidrlist[j] = t;
t = masklist[i];
masklist[i] = masklist[j];
masklist[j] = t;
tt = typelist[i];
typelist[i] = typelist[j];
typelist[j] = tt;
}
}
}
for(i=1;i< masklist[j]){
t = cidrlist[i];
cidrlist[i] = cidrlist[j];
cidrlist[j] = t;
t = masklist[i];
masklist[i] = masklist[j];
masklist[j] = t;
tt = typelist[i];
typelist[i] = typelist[j];
typelist[j] = tt;
}
}
}
return(0);
}
int initread(int fd, char *type)
{
int r;
int l;
char buff[256];
char *p = buff;
uint32_t ad=0;
uint32_t nm=0;
char ip[64];
while((r=read(fd,p,1)) != 0){
if(r == -1)
return(-1);
if(*p == '#')
*p = 0;
if(*p == '/')
*p = ' ';
if(*p == '\r')
*p = 0;
if(*p == '\n'){
*p = 0;
l = 32;
if(sscanf(buff,"%s%d", ip, &l) < 1){
/***** not ip addr ******/
if(buff[0]){
fprintf(stderr,"read error: %s(%s/%d)\n",buff,ip,l);
}
}else{
/*----- alloc -----*/
if(listcount == listsize){
listsize += 512;
typelist = (char **)realloc(typelist, listsize * sizeof(char *));
cidrlist = (uint32_t *)realloc(cidrlist, listsize * sizeof(uint32_t));
masklist = (uint32_t *)realloc(masklist, listsize * sizeof(uint32_t));
}
if(!inet_aton(ip, (struct in_addr *)&ad)){
fprintf(stderr,"ip addr error: %s\n", buff);
}
if(!getnetmask(l,&nm)){
fprintf(stderr,"netmask error: %s\n", buff);
}
ad = ntohl(ad);
ad &= nm;
typelist[listcount] = type;
cidrlist[listcount] = ad;
masklist[listcount] = nm;
listcount++;
}
p = buff;
continue;
}
p++;
}
return(0);
}
char *alloctype(char *name)
{
char *type;
char *base;
char path[PATH_MAX];
strcpy(path, name);
base=basename(path);
type=malloc(strlen(base)+1);
strcpy(type,base);
return(type);
}
void cidr_initialize()
{
int f;
DIR *d;
char path[PATH_MAX];
struct dirent *e;
struct stat st;
if((d=opendir(DBDIR)) != 0){
while((e=readdir(d)) != 0){
sprintf(path,"%s/%s", DBDIR, e->d_name);
if(stat(path, &st) != 0)
continue;
if(!S_ISREG(st.st_mode))
continue;
f = open(path, O_RDONLY);
if(f == -1){
fprintf(stderr,"file open error %s\n", path);
}else{
initread(f,alloctype(path));
close(f);
}
}
}
initsort();
}
