2008年08月01日

社内コードコンペ - お題:最速なCIDRブロックマッチ判定 〜 稲田の場合: hamanoが倒せない 〜

はてなブックマークに登録

おさらい


このコードのウリ

安井さんが2分探索で実装しているという話を聞いて、「それ、TRIE(トライ)で書いた方が速いしシンプルに 書けるんじゃね?」と思って、コードコンペに参加しました。

TRIEそのもの解説は、先日の濱野さんの物と同じなので省略します。

2分探索等だとO(log n) (nは登録されているcidrの数)の計算量になりますが、TRIEを使うと計算量はO(m) (mはアドレスの長さ) となり、登録するcidrの数が増えてもほとんど遅くなりません。 また、2分探索に比べると、探索部分のコードが非常にシンプルになるのもTRIEの魅力です。

基本的に濱野さんと比じ処理なのですが、性能で若干負ける代わりに、可読性や柔軟性はこちらの方が 高いと自負しています。

コード解説

まず、最初のバージョンはこんな感じになっていました。この頃は「アドレス帯がかぶっている時は 狭いアドレス帯を優先する」という要求がなかったこともあり、非常にシンプルです。

/* データ構造 */
typedef struct ADR_TRIE {
  const char *type;
  struct ADR_TRIE *child[TABLE_SIZE];
ADR_TRIE;

/* 葉からはchildを削ってメモリ節約する. */
typedef struct ADR_TRIE_LEAF {
  const char *type;
} ADR_TRIE_LEAF;
...

    /* 探索部分 */
    ADR_TRIE *pt = &trie_root;
    while (pt && (!pt->type)) {
      int b = addr >> 24;  /* このころはアドレスはuint32_tだった */
      pt = pt->child[b];
      addr <<= 8;
    }
    if (pt) return pt->type;

このコードを実装している間、IRCでは濱野さんが同じアイデアを提案していました。
hamano``> 256 分木つくって 1オクテットずつ読んで分木すれば、最低4回の遷移で判別出来る
hamano``> メモリ空間が大きくなるかもしれませんが、この程度のデータ量ならそれほど大きくならないと思います
hamano``> 僅か数命令で判別出来るのでこれ以上の最適化は無いかも
YasuiML> やってみて★ 
hamano``> あんまり、差が出ないのでイマイチ乗れないなぁ
hamano``> 問題を国判別に拡張しません?^^; 

そして、上の実装ができて、コードを安井さんに渡しました。

YasuiML> こみったす
katsumiD> 2段のテーブル引きにしたらどだろう
katsumiD> 上位16bit と 下位16bit にわけて
YasuiML> うほ
YasuiML> yasui-m@sag15:~$ ./cidrlookup 5000000 210.153.84.128 210.169.176.128 222.7.56.128 209.85.238.120 60.32.85.216
YasuiML> loop  : 5000000
YasuiML> elapse: 6.8481210
YasuiML> avg   : 0.0000003
YasuiML> 稲田さんのはやいすね!!
YasuiML> 一秒以上差がつきましたかあ

と褒められて、良い気になっていました。(この時点では、ベンチマークの内容が違い、経過時間がまったく異なります)

しかし、次の日、濱野さんがより高速なコードをコミットしました。その中身を読んでコードを読んで衝撃が走りました。

return type[tab[addr[3]][tab[addr[2]][tab[addr[1]][tab[addr[0]][1]]]]];

なんだこの[]だらけのコードは!分岐なしか!テーブル参照回数が同じである以上、 分岐を消さないと絶対勝てない!と思い、慌てて対抗して分岐を削除してみました。

  pt = pt->child[*addr++]; /* addrは、ネットワークバイトオーダーで格納されたアドレスの先頭アドレス */
  pt = pt->child[*addr++];
  pt = pt->child[*addr++];
  pt = pt->child[*addr++];
  return pt->type;

分岐を削除するためのhackとして、葉の形を変えました。分岐削除前は、pt->childを持たないようにしてメモリ消費を 抑えていたのですが、葉もpt->childを持ち、pt->child[n] == ptとしておくことで、短いサブネットアドレスでも4回の テーブルルックアップを実行するようにしました。分岐が減る代わりに、サブネットアドレスが短い場合はテーブルルック アップが増えるのですが、今回はマイクロベンチだからテーブルはほぼ確実にキャッシュに載っていますし、テストに使った アドレスもサブネットアドレスが長い物ばかりだったので、分岐を削除することでかなり速度を稼げました。

おわりに

ここに載せている以外にもいろいろと寄り道したのですが、最終的に、これ以上にできそうなアイデアは全て等価なものが 濱野さんのコードで実現済みという状況になってしまい、負けを認めました。 でも、メモリの確保の仕方が柔軟に対応できる点や、コードを読んで構造を把握しやすいかなどの面で、本採用を狙っています。

ベンチマーク

5つのとあるIPアドレスのそれぞれについてどのグループに属するか判定する、という処理を5,000,000回実行したときの総所要時間(elapsed)と1回の判定に要した時間の平均(average)です。

ベンチマークは、同じハードウエアのx86とx86_64の環境の2つでとりました。_8が、1テーブルでアドレス1byte分処理するバージョンで、 _16が1テーブルでアドレス2byte分処理するバージョンになります。16の方がテーブル参照が2回で済むので高速ですが、 メモリ使用量がバカみたいに増えるので、普通に使うなら8の方になります。

x64環境に置いてはかなり善戦していますが、hamanoさんにはあと一歩届きませんorz

x86

時間
name        elapsed[sec]  average[usec]
========================================
apr        59.379924      2.375197
ip-country  3.739187      0.149567
yasui-a     2.727045      0.109082   # インラインアセンブラ
yasui-c     0.975544      0.039022   # C言語
hamano-1    0.234664      0.009386
hamano-2    0.142496      0.005700
inada-n_16  0.175591      0.007023
inada-n_8   0.208570      0.008342

x86_64

name        elapsed[sec]  average[usec]
========================================
apr        52.340651      2.093626
ip-country  0.664034      0.026561
yasui-c     0.706095      0.028244
hamano-2    0.107557      0.004302
inada-n_16  0.116348      0.004653
inada-n_8   0.137042      0.005481

コード

#include <stdio.h>
#include <stdlib.h>

#include <unistd.h>
#include <stdint.h>
#include <dirent.h>
#include <string.h>
#include <fcntl.h>
#include <ctype.h>

#include <libgen.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#ifndef DBDIR
#  define DBDIR "../data"
#endif

#ifndef CIDR_TABLE_BITS
#  define CIDR_TABLE_BITS 8   /* 8 or 16: 16 is bit faster but use mooooooore memory. */
#endif
#define CIDR_TABLE_SIZE (1 << CIDR_TABLE_BITS)

/* データ構造について
 *
 * 葉以外の節
 *  node->child[n] はすべてのnにおいて別の節へのポインタになっている
 * 葉
 *  node->child[n] はすべてのnにおいてnodeになっている.
 */

typedef struct CIDR_TRIE {
  const char *name;
  uint32_t addr;
  uint32_t bits;
  struct CIDR_TRIE *child[CIDR_TABLE_SIZE];
} CIDR_TRIE;

static CIDR_TRIE trie_root;

static int getnetmask(int n, uint32_t *netmask)
{
  uint32_t m;
  if (n < 0 || 32 < n) return 0;

  m = 1UL << (32 - n);
  --m;
  *netmask = ~m;
  return 1;
}

static int is_leaf(const CIDR_TRIE *pt) {
  return pt->child[0] == pt;
}

static CIDR_TRIE* new_trie_node() {
  int i;
  CIDR_TRIE *pt = malloc(sizeof(CIDR_TRIE));
  pt->name = 0;
  pt->addr = 0;
  pt->bits = 0;
  for (i = 0; i < CIDR_TABLE_SIZE; ++i)
    pt->child[i] = pt;
  return pt;
}

static void init_root()
{
  int i;
  CIDR_TRIE *nullnode;
  nullnode = new_trie_node();
  for (i = 0; i < CIDR_TABLE_SIZE; ++i)
    trie_root.child[i] = nullnode;
}

/** 子node 'child' から親ノードを作る(親ノードの子はすべてchildになる) */
static CIDR_TRIE* digg_trie(CIDR_TRIE *child) {
  int i;
  CIDR_TRIE *parent = new_trie_node();
  for (i = 0; i < CIDR_TABLE_SIZE; ++i)
    parent->child[i] = child;
  return parent;
}

/* Trieを巡回して、与えられたleafよりもbitsが短い葉を全部leafで置き換える */
static int update_leaf(CIDR_TRIE *pt, CIDR_TRIE *leaf)
{
  int used = 0;
  int i;
  for (i = 0; i < CIDR_TABLE_SIZE; ++i) {
    CIDR_TRIE *next = pt->child[i];
    if (is_leaf(next)) {
      if (next->bits < leaf->bits) {
        pt->child[i] = leaf;
        used = 1;
      }
    }
    else {
      used |= update_leaf(next, leaf);
    }
  }
  return used;
}

int initread(int fd, char *type)
{
  int  r;
  int  l;
  char buff[256];
  char *p = buff;
  uint32_t ad;
  uint32_t nm = 0xffffffff;
  char ip[64];

  while((r=read(fd,p,1)) != 0){
    if(r == -1)
      return(-1);
    if(*p == '#')
      *p = 0;
    if(*p == '/')
      *p = ' ';
    if(*p == '\r')
      *p = 0;
    if(*p == '\n'){
      *p = 0;
      l = 32;
      if(sscanf(buff,"%s%d", ip, &l) < 1){
        /***** not ip addr ******/
        if(buff[0]){
          fprintf(stderr,"read error: %s(%s/%d)\n",buff,ip,l);
        }
      }else{
        if(!inet_aton(ip, (struct in_addr *)&ad)){
          fprintf(stderr,"ip addr error: %s\n", buff);
        }
        if(!getnetmask(l,&nm)){
          fprintf(stderr,"netmask error: %s\n", buff);
        }
        //printf("adding: %s %x %d\n", type, ad, nm);

        ad = ntohl(ad);
        ad &= nm;

        {
          CIDR_TRIE *pt = &trie_root;
          CIDR_TRIE *p_leaf = new_trie_node();

          p_leaf->name = type;
          p_leaf->addr = ad;
          p_leaf->bits = l;

          while (l > CIDR_TABLE_BITS) {
            int b = ad >> (32 - CIDR_TABLE_BITS);
            CIDR_TRIE *next = pt->child[b];
            if (is_leaf(next)) {
              pt->child[b] = next = digg_trie(next);
            }
            pt = next;
            ad <<= CIDR_TABLE_BITS;
            l -= CIDR_TABLE_BITS;
          }
          {
            int i;
            const int bmin = ad >> (32 - CIDR_TABLE_BITS);
            const int bmax = bmin + (1 << (CIDR_TABLE_BITS - l));
            int used = 0; // delete p_leaf if it is not used.
            for (i = bmin; i < bmax; ++i) {
              CIDR_TRIE *target = pt->child[i];
              if (is_leaf(target)) {
                if (target->bits < p_leaf->bits) {
                  pt->child[i] = p_leaf;
                  used = 1;
                }
              }
              else {
                int j;
                for (j = 0; j < CIDR_TABLE_SIZE; ++j) {
                  used |= update_leaf(target, p_leaf);
                }
              }
            }
            if (!used) {
              free(p_leaf);
            }
          }
        }
      }
      p = buff;
      continue;
    }
    p++;
  }
  return(0);  
}

#if 0
/* for debug. */
static void dump_trie(const CIDR_TRIE *pt, int indent) {
  int i, j;
  if (is_leaf(pt)) {
    for (j = 0; j < indent; ++j) putchar(' '); 
    printf("%s (%d/%d)\n", pt->name, pt->addr, pt->bits);
    return;
  }
  for (i = 0; i < CIDR_TABLE_SIZE; ++i) {
    CIDR_TRIE *next = pt->child[i];
    for (j = 0; j < indent; ++j) putchar(' '); 
    printf("child[%d]: %p", i, next);
    if (is_leaf(next)) {
      printf(" %s (%d/%d)\n", next->name, next->addr, next->bits);
    } else {
      putchar('\n');
      dump_trie(next, indent+2);
    }
  }
}
#endif

char *alloctype(char *name)
{
  char *base;
  char path[PATH_MAX];
  strcpy(path, name);
  base=basename(path);
  return strdup(base);
}

/*----- 初期化 -----*/
void cidr_initialize()
{
  int  f;
  char path[PATH_MAX];
  struct stat   st;
  struct dirent **namelist;
  int n;

  init_root();

  n = scandir(DBDIR, &namelist, 0, alphasort);
  if (n < 0)
    perror("scandir");
  else {
    int i;
    for (i=0; i<n; i++) {
      sprintf(path,"%s/%s", DBDIR, namelist[i]->d_name);
      //----- レギュラーファイル以外は見ない -----
      if(stat(path, &st) != 0) {
        free(namelist[i]);
        continue;
      }
      if(!S_ISREG(st.st_mode)) {
        free(namelist[i]);
        continue;
      }
      f = open(path, O_RDONLY);
      if(f == -1){
        fprintf(stderr,"file open error %s\n", path);
      }else{
        //printf("reading: %s\n", path);
        initread(f,alloctype(path));
        close(f);
      }
      free(namelist[i]);
    }
    free(namelist);
  }
  //dump_trie(&trie_root, 0);
}

const char *cidr_lookup(const uint8_t *addr, int len)
{
  CIDR_TRIE *pt = &trie_root;

#if CIDR_TABLE_BITS == 8
  pt = pt->child[*addr++];
  pt = pt->child[*addr++];
  pt = pt->child[*addr++];
  pt = pt->child[*addr++];
#elif CIDR_TABLE_BITS == 16
  pt = pt->child[addr[0] * 256 + addr[1]];
  pt = pt->child[addr[2] * 256 + addr[3]];
#else
#error CIDR_TABLE_BITS must be 8 or 16.
#endif

  return pt->name;
}

/* vim: set shiftwidth=2: */

klab_gijutsu2 at 08:00│Comments(0)TrackBack(0)開発 | codecompe

トラックバックURL

この記事にコメントする

名前:
URL:
  情報を記憶: 評価: 顔   
 
 
 
Blog内検索
Archives
このブログについて
DSASとは、KLab が構築し運用しているコンテンツサービス用のLinuxベースのインフラです。現在5ヶ所のデータセンタにて構築し、運用していますが、我々はDSASをより使いやすく、より安全に、そしてより省力で運用できることを目指して、日々改良に勤しんでいます。
このブログでは、そんな DSAS で使っている技術の紹介や、実験してみた結果の報告、トラブルに巻き込まれた時の経験談など、広く深く、色々な話題を織りまぜて紹介していきたいと思います。
最新コメント