File Coverage

src/simd/mds_simd_avx2.c
Criterion Covered Total %
statement 67 71 94.3
branch 33 36 91.6
condition n/a
subroutine n/a
pod n/a
total 100 107 93.4


line stmt bran cond sub pod time code
1             /* src/simd/mds_simd_avx2.c — AVX2 classifier.
2             *
3             * Processes 32 bytes per iteration. Each function carries
4             * __attribute__((target("avx2,bmi2"))) so the TU compiles even when the
5             * baseline does not enable AVX2. Runtime dispatch (CPUID in
6             * mds_simd_dispatch.c) only routes calls here on capable CPUs.
7             *
8             * MSVC has no per-function target attribute; for MSVC, the build system
9             * only adds this file when /arch:AVX2 is in effect.
10             */
11             #include "mds_simd.h"
12             #include "mds_classifier_lut.h"
13              
14             #ifdef MDS_HAVE_AVX2
15              
16             #include
17              
18             #if defined(__GNUC__) || defined(__clang__)
19             # define MDS_AVX2_FN __attribute__((target("avx2,bmi2")))
20             #else
21             # define MDS_AVX2_FN
22             #endif
23              
24 598           MDS_AVX2_FN static void classify_structural_avx2(const char* in, size_t len,
25             uint64_t* out)
26             {
27             /* Broadcast the 16-byte LUTs into both 128-bit lanes; _mm256_shuffle_epi8
28             * is per-lane. */
29 598           __m128i lo_tbl128 = _mm_loadu_si128((const __m128i*)MDS_CLASSIFIER_LO);
30 598           __m128i hi_tbl128 = _mm_loadu_si128((const __m128i*)MDS_CLASSIFIER_HI);
31 598           __m256i lo_tbl = _mm256_broadcastsi128_si256(lo_tbl128);
32 598           __m256i hi_tbl = _mm256_broadcastsi128_si256(hi_tbl128);
33 598           __m256i mask_lo = _mm256_set1_epi8(0x0F);
34 598           __m256i zero = _mm256_setzero_si256();
35              
36 598           size_t i = 0;
37 16136 100         while (i + 32 <= len) {
38 31076           __m256i v = _mm256_loadu_si256((const __m256i*)(in + i));
39 15538           __m256i lo = _mm256_and_si256(v, mask_lo);
40             /* high nibble: srli_epi16 by 4 then mask 0x0F */
41 31076           __m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), mask_lo);
42 15538           __m256i la = _mm256_shuffle_epi8(lo_tbl, lo);
43 15538           __m256i ha = _mm256_shuffle_epi8(hi_tbl, hi);
44 15538           __m256i m = _mm256_and_si256(la, ha);
45             /* Use cmpeq(m,0) then invert: cmpgt_epi8 is SIGNED, which would
46             * misclassify bytes whose LUT product is 0x80 (e.g. '|' = 0x7C,
47             * '~' = 0x7E) — they hit hi_tbl[7]=0x80 and would read as -128. */
48 15538           __m256i is_zero = _mm256_cmpeq_epi8(m, zero); /* 0xFF where m==0 */
49 15538           uint32_t bits = (uint32_t)(~_mm256_movemask_epi8(is_zero));
50              
51 15538           size_t word = i >> 6;
52 15538           size_t off = i & 63u;
53 15538           out[word] |= (uint64_t)bits << off;
54 15538           i += 32;
55             }
56 9948 100         for (; i < len; i++) {
57 9350           uint8_t b = (uint8_t)in[i];
58 9350 100         if (MDS_CLASSIFIER_LO[b & 0xF] & MDS_CLASSIFIER_HI[b >> 4])
59 967           out[i >> 6] |= (uint64_t)1 << (i & 63);
60             }
61 598           }
62              
63             #ifdef s_scalar
64             # undef s_scalar
65             #endif
66 295           static const mds_simd_ops* s_scalar_avx2(void) { return mds_simd_ops_scalar(); }
67             #define s_scalar s_scalar_avx2
68              
69             /* ASCII fast-path validator. Non-ASCII chunks delegate to scalar DFA
70             * (after extending forward across any in-flight continuation bytes). */
71 283           MDS_AVX2_FN static int validate_utf8_avx2(const char* in, size_t len)
72             {
73 283           const unsigned char* p = (const unsigned char*)in;
74 283           const unsigned char* end = p + len;
75              
76 6967 100         while ((size_t)(end - p) >= 32) {
77 6883           __m256i v = _mm256_loadu_si256((const __m256i*)p);
78             /* movemask of v: bit i = (v[i] >> 7). If zero, all ASCII. */
79 6883           int mask = _mm256_movemask_epi8(v);
80 6883 100         if (mask == 0) { p += 32; continue; }
81              
82 212           const unsigned char* tail = p + 32;
83 212 50         if (tail > end) tail = end;
84 212           int extend = 3;
85 288 100         while (extend-- > 0 && tail < end && (*tail & 0xC0) == 0x80) tail++;
    50          
    100          
86 212 100         if (!s_scalar()->validate_utf8((const char*)p, (size_t)(tail - p)))
87 199           return 0;
88 13           p = tail;
89             }
90 84 100         if (p < end) return s_scalar()->validate_utf8((const char*)p, (size_t)(end - p));
91 1           return 1;
92             }
93              
94 2053           MDS_AVX2_FN static size_t find_newlines_avx2(const char* in, size_t len,
95             uint32_t* out, size_t cap)
96             {
97 2053           const char* p = in;
98 2053           const char* end = in + len;
99 2053           __m256i needle = _mm256_set1_epi8('\n');
100 2053           size_t k = 0;
101              
102 18339 100         while ((size_t)(end - p) >= 32) {
103 16286           __m256i v = _mm256_loadu_si256((const __m256i*)p);
104 16286           __m256i cmp = _mm256_cmpeq_epi8(v, needle);
105 16286           uint32_t m = (uint32_t)_mm256_movemask_epi8(cmp);
106 16286 100         if (m) {
107 12619           uint32_t base = (uint32_t)(p - in);
108             do {
109 41229           unsigned bit = (unsigned)__builtin_ctz(m);
110 41229 50         if (k >= cap) return (size_t)-1;
111 41229           out[k++] = base + bit;
112 41229           m &= m - 1; /* clear lowest set bit */
113 41229 100         } while (m);
114             }
115 16286           p += 32;
116             }
117 34002 100         while (p < end) {
118 31950 100         if (*p == '\n') {
119 3981 100         if (k >= cap) return (size_t)-1;
120 3980           out[k++] = (uint32_t)(p - in);
121             }
122 31949           p++;
123             }
124 2052           return k;
125             }
126              
127 0           MDS_AVX2_FN static const char* next_structural_avx2(const char* p,
128             const char* end)
129 0           { return s_scalar()->next_structural(p, end); }
130              
131 0           MDS_AVX2_FN static const char* next_structural_bm_avx2(const char* base,
132             size_t bm_len,
133             const uint64_t* bm,
134             size_t p_off)
135 0           { return s_scalar()->next_structural_bm(base, bm_len, bm, p_off); }
136              
137             static const mds_simd_ops k_ops_avx2 = {
138             classify_structural_avx2,
139             validate_utf8_avx2,
140             find_newlines_avx2,
141             next_structural_bm_avx2,
142             next_structural_avx2,
143             };
144              
145 731           const mds_simd_ops* mds_simd_ops_avx2(void) { return &k_ops_avx2; }
146              
147             #endif /* MDS_HAVE_AVX2 */