社内コードコンペ - お題:最速なCIDRブロックマッチ判定 〜 稲田の場合: hamanoが倒せない 〜
おさらい
- #1 ひろせの場合 - IP::CountryとAPRを使ってみた
- #2 安井の場合: バイナリサーチのあれとこれ
- #3 hamanoの場合: あ ありのまま 今 起こった事を話すぜ!『コードコンペだと思ったらゴルフコンペだった』な(ry
- #4 稲田の場合: hamanoが倒せない ← 今回
このコードのウリ
安井さんが2分探索で実装しているという話を聞いて、「それ、TRIE(トライ)で書いた方が速いしシンプルに 書けるんじゃね?」と思って、コードコンペに参加しました。
TRIEそのもの解説は、先日の濱野さんの物と同じなので省略します。
2分探索等だとO(log n) (nは登録されているcidrの数)の計算量になりますが、TRIEを使うと計算量はO(m) (mはアドレスの長さ) となり、登録するcidrの数が増えてもほとんど遅くなりません。 また、2分探索に比べると、探索部分のコードが非常にシンプルになるのもTRIEの魅力です。
基本的に濱野さんと比じ処理なのですが、性能で若干負ける代わりに、可読性や柔軟性はこちらの方が 高いと自負しています。
コード解説
まず、最初のバージョンはこんな感じになっていました。この頃は「アドレス帯がかぶっている時は 狭いアドレス帯を優先する」という要求がなかったこともあり、非常にシンプルです。
/* データ構造 */
typedef struct ADR_TRIE {
const char *type;
struct ADR_TRIE *child[TABLE_SIZE];
ADR_TRIE;
/* 葉からはchildを削ってメモリ節約する. */
typedef struct ADR_TRIE_LEAF {
const char *type;
} ADR_TRIE_LEAF;
...
/* 探索部分 */
ADR_TRIE *pt = &trie_root;
while (pt && (!pt->type)) {
int b = addr >> 24; /* このころはアドレスはuint32_tだった */
pt = pt->child[b];
addr <<= 8;
}
if (pt) return pt->type;
このコードを実装している間、IRCでは濱野さんが同じアイデアを提案していました。
hamano``> 256 分木つくって 1オクテットずつ読んで分木すれば、最低4回の遷移で判別出来る hamano``> メモリ空間が大きくなるかもしれませんが、この程度のデータ量ならそれほど大きくならないと思います hamano``> 僅か数命令で判別出来るのでこれ以上の最適化は無いかも YasuiML> やってみて★ hamano``> あんまり、差が出ないのでイマイチ乗れないなぁ hamano``> 問題を国判別に拡張しません?^^;
そして、上の実装ができて、コードを安井さんに渡しました。
YasuiML> こみったす katsumiD> 2段のテーブル引きにしたらどだろう katsumiD> 上位16bit と 下位16bit にわけて YasuiML> うほ YasuiML> yasui-m@sag15:~$ ./cidrlookup 5000000 210.153.84.128 210.169.176.128 222.7.56.128 209.85.238.120 60.32.85.216 YasuiML> loop : 5000000 YasuiML> elapse: 6.8481210 YasuiML> avg : 0.0000003 YasuiML> 稲田さんのはやいすね!! YasuiML> 一秒以上差がつきましたかあ
と褒められて、良い気になっていました。(この時点では、ベンチマークの内容が違い、経過時間がまったく異なります)
しかし、次の日、濱野さんがより高速なコードをコミットしました。その中身を読んでコードを読んで衝撃が走りました。
return type[tab[addr[3]][tab[addr[2]][tab[addr[1]][tab[addr[0]][1]]]]];
なんだこの[]だらけのコードは!分岐なしか!テーブル参照回数が同じである以上、 分岐を消さないと絶対勝てない!と思い、慌てて対抗して分岐を削除してみました。
pt = pt->child[*addr++]; /* addrは、ネットワークバイトオーダーで格納されたアドレスの先頭アドレス */ pt = pt->child[*addr++]; pt = pt->child[*addr++]; pt = pt->child[*addr++]; return pt->type;
分岐を削除するためのhackとして、葉の形を変えました。分岐削除前は、pt->childを持たないようにしてメモリ消費を 抑えていたのですが、葉もpt->childを持ち、pt->child[n] == ptとしておくことで、短いサブネットアドレスでも4回の テーブルルックアップを実行するようにしました。分岐が減る代わりに、サブネットアドレスが短い場合はテーブルルック アップが増えるのですが、今回はマイクロベンチだからテーブルはほぼ確実にキャッシュに載っていますし、テストに使った アドレスもサブネットアドレスが長い物ばかりだったので、分岐を削除することでかなり速度を稼げました。
おわりに
ここに載せている以外にもいろいろと寄り道したのですが、最終的に、これ以上にできそうなアイデアは全て等価なものが 濱野さんのコードで実現済みという状況になってしまい、負けを認めました。 でも、メモリの確保の仕方が柔軟に対応できる点や、コードを読んで構造を把握しやすいかなどの面で、本採用を狙っています。
ベンチマーク
5つのとあるIPアドレスのそれぞれについてどのグループに属するか判定する、という処理を5,000,000回実行したときの総所要時間(elapsed)と1回の判定に要した時間の平均(average)です。
ベンチマークは、同じハードウエアのx86とx86_64の環境の2つでとりました。_8が、1テーブルでアドレス1byte分処理するバージョンで、 _16が1テーブルでアドレス2byte分処理するバージョンになります。16の方がテーブル参照が2回で済むので高速ですが、 メモリ使用量がバカみたいに増えるので、普通に使うなら8の方になります。
x64環境に置いてはかなり善戦していますが、hamanoさんにはあと一歩届きませんorz
x86
時間name elapsed[sec] average[usec] ======================================== apr 59.379924 2.375197 ip-country 3.739187 0.149567 yasui-a 2.727045 0.109082 # インラインアセンブラ yasui-c 0.975544 0.039022 # C言語 hamano-1 0.234664 0.009386 hamano-2 0.142496 0.005700 inada-n_16 0.175591 0.007023 inada-n_8 0.208570 0.008342
x86_64
name elapsed[sec] average[usec] ======================================== apr 52.340651 2.093626 ip-country 0.664034 0.026561 yasui-c 0.706095 0.028244 hamano-2 0.107557 0.004302 inada-n_16 0.116348 0.004653 inada-n_8 0.137042 0.005481
コード
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <dirent.h>
#include <string.h>
#include <fcntl.h>
#include <ctype.h>
#include <libgen.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#ifndef DBDIR
# define DBDIR "../data"
#endif
#ifndef CIDR_TABLE_BITS
# define CIDR_TABLE_BITS 8 /* 8 or 16: 16 is bit faster but use mooooooore memory. */
#endif
#define CIDR_TABLE_SIZE (1 << CIDR_TABLE_BITS)
/* データ構造について
*
* 葉以外の節
* node->child[n] はすべてのnにおいて別の節へのポインタになっている
* 葉
* node->child[n] はすべてのnにおいてnodeになっている.
*/
typedef struct CIDR_TRIE {
const char *name;
uint32_t addr;
uint32_t bits;
struct CIDR_TRIE *child[CIDR_TABLE_SIZE];
} CIDR_TRIE;
static CIDR_TRIE trie_root;
static int getnetmask(int n, uint32_t *netmask)
{
uint32_t m;
if (n < 0 || 32 < n) return 0;
m = 1UL << (32 - n);
--m;
*netmask = ~m;
return 1;
}
static int is_leaf(const CIDR_TRIE *pt) {
return pt->child[0] == pt;
}
static CIDR_TRIE* new_trie_node() {
int i;
CIDR_TRIE *pt = malloc(sizeof(CIDR_TRIE));
pt->name = 0;
pt->addr = 0;
pt->bits = 0;
for (i = 0; i < CIDR_TABLE_SIZE; ++i)
pt->child[i] = pt;
return pt;
}
static void init_root()
{
int i;
CIDR_TRIE *nullnode;
nullnode = new_trie_node();
for (i = 0; i < CIDR_TABLE_SIZE; ++i)
trie_root.child[i] = nullnode;
}
/** 子node 'child' から親ノードを作る(親ノードの子はすべてchildになる) */
static CIDR_TRIE* digg_trie(CIDR_TRIE *child) {
int i;
CIDR_TRIE *parent = new_trie_node();
for (i = 0; i < CIDR_TABLE_SIZE; ++i)
parent->child[i] = child;
return parent;
}
/* Trieを巡回して、与えられたleafよりもbitsが短い葉を全部leafで置き換える */
static int update_leaf(CIDR_TRIE *pt, CIDR_TRIE *leaf)
{
int used = 0;
int i;
for (i = 0; i < CIDR_TABLE_SIZE; ++i) {
CIDR_TRIE *next = pt->child[i];
if (is_leaf(next)) {
if (next->bits < leaf->bits) {
pt->child[i] = leaf;
used = 1;
}
}
else {
used |= update_leaf(next, leaf);
}
}
return used;
}
int initread(int fd, char *type)
{
int r;
int l;
char buff[256];
char *p = buff;
uint32_t ad;
uint32_t nm = 0xffffffff;
char ip[64];
while((r=read(fd,p,1)) != 0){
if(r == -1)
return(-1);
if(*p == '#')
*p = 0;
if(*p == '/')
*p = ' ';
if(*p == '\r')
*p = 0;
if(*p == '\n'){
*p = 0;
l = 32;
if(sscanf(buff,"%s%d", ip, &l) < 1){
/***** not ip addr ******/
if(buff[0]){
fprintf(stderr,"read error: %s(%s/%d)\n",buff,ip,l);
}
}else{
if(!inet_aton(ip, (struct in_addr *)&ad)){
fprintf(stderr,"ip addr error: %s\n", buff);
}
if(!getnetmask(l,&nm)){
fprintf(stderr,"netmask error: %s\n", buff);
}
//printf("adding: %s %x %d\n", type, ad, nm);
ad = ntohl(ad);
ad &= nm;
{
CIDR_TRIE *pt = &trie_root;
CIDR_TRIE *p_leaf = new_trie_node();
p_leaf->name = type;
p_leaf->addr = ad;
p_leaf->bits = l;
while (l > CIDR_TABLE_BITS) {
int b = ad >> (32 - CIDR_TABLE_BITS);
CIDR_TRIE *next = pt->child[b];
if (is_leaf(next)) {
pt->child[b] = next = digg_trie(next);
}
pt = next;
ad <<= CIDR_TABLE_BITS;
l -= CIDR_TABLE_BITS;
}
{
int i;
const int bmin = ad >> (32 - CIDR_TABLE_BITS);
const int bmax = bmin + (1 << (CIDR_TABLE_BITS - l));
int used = 0; // delete p_leaf if it is not used.
for (i = bmin; i < bmax; ++i) {
CIDR_TRIE *target = pt->child[i];
if (is_leaf(target)) {
if (target->bits < p_leaf->bits) {
pt->child[i] = p_leaf;
used = 1;
}
}
else {
int j;
for (j = 0; j < CIDR_TABLE_SIZE; ++j) {
used |= update_leaf(target, p_leaf);
}
}
}
if (!used) {
free(p_leaf);
}
}
}
}
p = buff;
continue;
}
p++;
}
return(0);
}
#if 0
/* for debug. */
static void dump_trie(const CIDR_TRIE *pt, int indent) {
int i, j;
if (is_leaf(pt)) {
for (j = 0; j < indent; ++j) putchar(' ');
printf("%s (%d/%d)\n", pt->name, pt->addr, pt->bits);
return;
}
for (i = 0; i < CIDR_TABLE_SIZE; ++i) {
CIDR_TRIE *next = pt->child[i];
for (j = 0; j < indent; ++j) putchar(' ');
printf("child[%d]: %p", i, next);
if (is_leaf(next)) {
printf(" %s (%d/%d)\n", next->name, next->addr, next->bits);
} else {
putchar('\n');
dump_trie(next, indent+2);
}
}
}
#endif
char *alloctype(char *name)
{
char *base;
char path[PATH_MAX];
strcpy(path, name);
base=basename(path);
return strdup(base);
}
/*----- 初期化 -----*/
void cidr_initialize()
{
int f;
char path[PATH_MAX];
struct stat st;
struct dirent **namelist;
int n;
init_root();
n = scandir(DBDIR, &namelist, 0, alphasort);
if (n < 0)
perror("scandir");
else {
int i;
for (i=0; i<n; i++) {
sprintf(path,"%s/%s", DBDIR, namelist[i]->d_name);
//----- レギュラーファイル以外は見ない -----
if(stat(path, &st) != 0) {
free(namelist[i]);
continue;
}
if(!S_ISREG(st.st_mode)) {
free(namelist[i]);
continue;
}
f = open(path, O_RDONLY);
if(f == -1){
fprintf(stderr,"file open error %s\n", path);
}else{
//printf("reading: %s\n", path);
initread(f,alloctype(path));
close(f);
}
free(namelist[i]);
}
free(namelist);
}
//dump_trie(&trie_root, 0);
}
const char *cidr_lookup(const uint8_t *addr, int len)
{
CIDR_TRIE *pt = &trie_root;
#if CIDR_TABLE_BITS == 8
pt = pt->child[*addr++];
pt = pt->child[*addr++];
pt = pt->child[*addr++];
pt = pt->child[*addr++];
#elif CIDR_TABLE_BITS == 16
pt = pt->child[addr[0] * 256 + addr[1]];
pt = pt->child[addr[2] * 256 + addr[3]];
#else
#error CIDR_TABLE_BITS must be 8 or 16.
#endif
return pt->name;
}
/* vim: set shiftwidth=2: */
