root/lang/cplusplus/range_coder/range_coder.hpp

Revision 7158, 4.3 kB (checked in by kazuho, 11 months ago)

fix overrun

Line 
1#ifndef __RANGE_CODER_HPP__
2#define __RANGE_CODER_HPP__
3
#include <algorithm>

#ifdef RANGE_CODER_USE_SSE
#include <xmmintrin.h>
#include <emmintrin.h>
#endif
7
8// original work by Daisuke Okanohara 2006/06/16
9
// Constants and integer aliases shared by the range encoder and decoder,
// both of which derive from this struct.
struct rc_type_t {
  // Unit of coded output (one octet of the compressed stream).
  typedef unsigned char byte;
  // Word type of the coder state (L, R, D). The coders use constants such as
  // 0xFFFFFFFF and shifts by 24, so this is presumably expected to be 32 bits.
  typedef unsigned int uint;

  enum {
    // Renormalisation threshold: both coders shift one byte in or out for as
    // long as the range R is below this value.
    TOP = 1U << 24,
    // Every bit below TOP set.
    TOPMASK = TOP - 1
  };
};
18
19template <class Iter> class rc_encoder_t : public rc_type_t {
20public:
21  rc_encoder_t(const Iter &i) : iter(i) {
22    L      = 0;
23    R      = 0xFFFFFFFF;
24    buffer = 0;
25    carryN = 0;
26    counter = 0;
27    start  = true;
28  }
29  void encode(const uint low, const uint high, const uint total) {
30    uint r    = R / total;
31    if (high < total) {
32      R = r * (high-low);
33    } else {
34      R -= r * low;
35    }
36    uint newL = L + r*low;
37    if (newL < L) {
38      //overflow occured (newL >= 2^32)
39      //buffer FF FF .. FF -> buffer+1 00 00 .. 00
40      buffer++;
41      for (;carryN > 0; carryN--) {
42        *iter++ = buffer;
43        buffer = 0;
44      }
45    }
46    L = newL;
47    while (R < TOP) {
48      const byte newBuffer = (L >> 24);
49      if (start) {
50        buffer = newBuffer;
51        start  = false;
52      } else if (newBuffer == 0xFF) {
53        carryN++;
54      } else {
55        *iter++ = buffer;
56        for (; carryN != 0; carryN--) {
57          *iter++ = 0xFF;
58        }
59        buffer = newBuffer;
60      }
61      L <<= 8;
62      R <<= 8;
63    }
64    counter++;
65  }
66  void final() {
67    *iter++ = buffer;
68    for (; carryN != 0; carryN--) {
69      *iter++ = 0xFF;
70    }
71    uint t = L + R;
72    while (1) {
73      uint t8 = t >> 24, l8 = L >> 24;
74      *iter++ = l8;
75      if (t8 != l8) {
76        break;
77      }
78      t <<= 8;
79      L <<= 8;
80    }
81  }
82private:
83  uint R;
84  uint L;
85  bool start;
86  byte buffer;
87  uint carryN;
88  Iter iter;
89  uint counter;
90};
91
92template <typename FreqType, unsigned _N, int _BASE> struct rc_decoder_search_traits_t : public rc_type_t {
93  typedef FreqType freq_type;
94  enum {
95    N = _N,
96    BASE = _BASE
97  };
98};
99
100template <typename FreqType, unsigned _N, int _BASE = 0> struct rc_decoder_search_t : public rc_decoder_search_traits_t<FreqType, _N, _BASE> {
101  static uint get_index(const FreqType *freq, FreqType pos) {
102    uint left  = 0;
103    uint right = _N;
104    while(left < right) {
105      uint mid = (left+right)/2;
106      if (freq[mid+1] <= pos) left = mid+1;
107      else                    right = mid;
108    }
109    return left;
110  }
111};
112
113#ifdef RANGE_CODER_USE_SSE
114
115template<int _BASE> struct rc_decoder_search_t<short, 256, _BASE> : public rc_decoder_search_traits_t<short, 256, _BASE> {
116  static uint get_index(const short *freq, short pos) {
117    __m128i v = _mm_set1_epi16(pos);
118    unsigned i, mask = 0;
119    for (i = 0; i < 256; i += 16) {
120      __m128i x = *reinterpret_cast<const __m128i*>(freq + i);
121      __m128i y = *reinterpret_cast<const __m128i*>(freq + i + 8);
122      __m128i a = _mm_cmplt_epi16(v, x);
123      __m128i b = _mm_cmplt_epi16(v, y);
124      mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a);
125      if (mask) {
126        return i + (__builtin_ctz(mask) >> 1) - 1;
127      }
128    }
129    return 255;
130  }
131};
132
133#endif
134
135template <class Iterator, class SearchType> class rc_decoder_t : public rc_type_t {
136public:
137  typedef SearchType search_type;
138  typedef typename search_type::freq_type freq_type;
139  static const unsigned N = search_type::N;
140  rc_decoder_t(const Iterator& _i, const Iterator _e) : iter(_i), iter_end(_e) {
141    R = 0xFFFFFFFF;
142    D = 0;
143    for (int i = 0; i < 4; i++) {
144      D = (D << 8) | next();
145    }
146  }
147  uint decode(const uint total, const freq_type* cumFreq) {
148    const uint r = R / total;
149    const int targetPos = std::min(total-1, D / r);
150   
151    //find target s.t. cumFreq[target] <= targetPos < cumFreq[target+1]
152    const uint target =
153      search_type::get_index(cumFreq, targetPos + search_type::BASE);
154    const uint low  = cumFreq[target] - search_type::BASE;
155    const uint high = cumFreq[target+1] - search_type::BASE;
156   
157    D -= r * low;
158    if (high != total) {
159      R = r * (high-low);
160    } else {
161      R -= r * low;
162    }
163   
164    while (R < TOP) {
165      R <<= 8;
166      D = (D << 8) | next();
167    }
168   
169    return target;
170  }
171  byte next() {
172    return iter != iter_end ? (byte)*iter++ : 0xff;
173  }
174private:
175  uint R;
176  uint D;
177  Iterator iter, iter_end;
178};
179
180#endif
Note: See TracBrowser for help on using the browser.