OpenJPH
Open-source implementation of JPEG2000 Part-15
ojph_transform_avx512.cpp
1//***************************************************************************/
2// This software is released under the 2-Clause BSD license, included
3// below.
4//
5// Copyright (c) 2019-2024, Aous Naman
6// Copyright (c) 2019-2024, Kakadu Software Pty Ltd, Australia
7// Copyright (c) 2019-2024, The University of New South Wales, Australia
8//
9// Redistribution and use in source and binary forms, with or without
10// modification, are permitted provided that the following conditions are
11// met:
12//
13// 1. Redistributions of source code must retain the above copyright
14// notice, this list of conditions and the following disclaimer.
15//
16// 2. Redistributions in binary form must reproduce the above copyright
17// notice, this list of conditions and the following disclaimer in the
18// documentation and/or other materials provided with the distribution.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
26// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31//***************************************************************************/
32// This file is part of the OpenJPH software implementation.
33// File: ojph_transform_avx512.cpp
34// Author: Aous Naman
35// Date: 13 April 2024
36//***************************************************************************/
37
38#include <cstdio>
39
40#include "ojph_defs.h"
41#include "ojph_arch.h"
42#include "ojph_mem.h"
43#include "ojph_params.h"
45
46#include "ojph_transform.h"
48
49#include <immintrin.h>
50
51namespace ojph {
52 namespace local {
53
55 // We process the width in multiples of 32 first, then in multiples
56 // of 16, because we assume byte_alignment == 64
57 static
58 void avx512_deinterleave32(float* dpl, float* dph, float* sp, int width)
59 {
60 __m512i idx1 = _mm512_set_epi32(
61 0x1E, 0x1C, 0x1A, 0x18, 0x16, 0x14, 0x12, 0x10,
62 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
63 );
64 __m512i idx2 = _mm512_set_epi32(
65 0x1F, 0x1D, 0x1B, 0x19, 0x17, 0x15, 0x13, 0x11,
66 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
67 );
68 for (; width > 16; width -= 32, sp += 32, dpl += 16, dph += 16)
69 {
70 __m512 a = _mm512_load_ps(sp);
71 __m512 b = _mm512_load_ps(sp + 16);
72 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
73 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
74 _mm512_store_ps(dpl, c);
75 _mm512_store_ps(dph, d);
76 }
77 for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8)
78 {
79 __m256 a = _mm256_load_ps(sp);
80 __m256 b = _mm256_load_ps(sp + 8);
81 __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0));
82 __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1));
83 __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0));
84 __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1));
85 _mm256_store_ps(dpl, e);
86 _mm256_store_ps(dph, f);
87 }
88 }
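// For reference, a scalar sketch of what the deinterleave above computes
// (assuming an even width): even-indexed samples go to dpl and
// odd-indexed samples to dph.
//
//   for (int i = 0; i < width / 2; ++i) {
//     dpl[i] = sp[2 * i];      // even samples
//     dph[i] = sp[2 * i + 1];  // odd samples
//   }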
89
91 // We process the width in multiples of 32 first, then in multiples
92 // of 16, because we assume byte_alignment == 64
93 static
94 void avx512_interleave32(float* dp, float* spl, float* sph, int width)
95 {
96 __m512i idx1 = _mm512_set_epi32(
97 0x17, 0x7, 0x16, 0x6, 0x15, 0x5, 0x14, 0x4,
98 0x13, 0x3, 0x12, 0x2, 0x11, 0x1, 0x10, 0x0
99 );
100 __m512i idx2 = _mm512_set_epi32(
101 0x1F, 0xF, 0x1E, 0xE, 0x1D, 0xD, 0x1C, 0xC,
102 0x1B, 0xB, 0x1A, 0xA, 0x19, 0x9, 0x18, 0x8
103 );
104 for (; width > 16; width -= 32, dp += 32, spl += 16, sph += 16)
105 {
106 __m512 a = _mm512_load_ps(spl);
107 __m512 b = _mm512_load_ps(sph);
108 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
109 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
110 _mm512_store_ps(dp, c);
111 _mm512_store_ps(dp + 16, d);
112 }
113 for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8)
114 {
115 __m256 a = _mm256_load_ps(spl);
116 __m256 b = _mm256_load_ps(sph);
117 __m256 c = _mm256_unpacklo_ps(a, b);
118 __m256 d = _mm256_unpackhi_ps(a, b);
119 __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0));
120 __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1));
121 _mm256_store_ps(dp, e);
122 _mm256_store_ps(dp + 8, f);
123 }
124 }
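// Conversely, a scalar sketch of the interleave above: the two bands
// are woven back into alternating samples.
//
//   for (int i = 0; i < width / 2; ++i) {
//     dp[2 * i]     = spl[i];
//     dp[2 * i + 1] = sph[i];
//   }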
125
127 // We process the width in multiples of 16 first, then in multiples
128 // of 8, because we assume byte_alignment == 64
129 static void avx512_deinterleave64(double* dpl, double* dph, double* sp,
130 int width)
131 {
132 __m512i idx1 = _mm512_set_epi64(
133 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
134 );
135 __m512i idx2 = _mm512_set_epi64(
136 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
137 );
138 for (; width > 8; width -= 16, sp += 16, dpl += 8, dph += 8)
139 {
140 __m512d a = _mm512_load_pd(sp);
141 __m512d b = _mm512_load_pd(sp + 8);
142 __m512d c = _mm512_permutex2var_pd(a, idx1, b);
143 __m512d d = _mm512_permutex2var_pd(a, idx2, b);
144 _mm512_store_pd(dpl, c);
145 _mm512_store_pd(dph, d);
146 }
147 for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4)
148 {
149 __m256d a = _mm256_load_pd(sp);
150 __m256d b = _mm256_load_pd(sp + 4);
151 __m256d c = _mm256_permute2f128_pd(a, b, (2 << 4) | (0));
152 __m256d d = _mm256_permute2f128_pd(a, b, (3 << 4) | (1));
153 __m256d e = _mm256_shuffle_pd(c, d, 0x0);
154 __m256d f = _mm256_shuffle_pd(c, d, 0xF);
155 _mm256_store_pd(dpl, e);
156 _mm256_store_pd(dph, f);
157 }
158 }
159
161 // We process the width in multiples of 16 first, then in multiples
162 // of 8, because we assume byte_alignment == 64
163 static void avx512_interleave64(double* dp, double* spl, double* sph,
164 int width)
165 {
166 __m512i idx1 = _mm512_set_epi64(
167 0xB, 0x3, 0xA, 0x2, 0x9, 0x1, 0x8, 0x0
168 );
169 __m512i idx2 = _mm512_set_epi64(
170 0xF, 0x7, 0xE, 0x6, 0xD, 0x5, 0xC, 0x4
171 );
172 for (; width > 8; width -= 16, dp += 16, spl += 8, sph += 8)
173 {
174 __m512d a = _mm512_load_pd(spl);
175 __m512d b = _mm512_load_pd(sph);
176 __m512d c = _mm512_permutex2var_pd(a, idx1, b);
177 __m512d d = _mm512_permutex2var_pd(a, idx2, b);
178 _mm512_store_pd(dp, c);
179 _mm512_store_pd(dp + 8, d);
180 }
181 for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4)
182 {
183 __m256d a = _mm256_load_pd(spl);
184 __m256d b = _mm256_load_pd(sph);
185 __m256d c = _mm256_unpacklo_pd(a, b);
186 __m256d d = _mm256_unpackhi_pd(a, b);
187 __m256d e = _mm256_permute2f128_pd(c, d, (2 << 4) | (0));
188 __m256d f = _mm256_permute2f128_pd(c, d, (3 << 4) | (1));
189 _mm256_store_pd(dp, e);
190 _mm256_store_pd(dp + 4, f);
191 }
192 }
193
195 static inline void avx512_multiply_const(float* p, float f, int width)
196 {
197 __m512 factor = _mm512_set1_ps(f);
198 for (; width > 0; width -= 16, p += 16)
199 {
200 __m512 s = _mm512_load_ps(p);
201 _mm512_store_ps(p, _mm512_mul_ps(factor, s));
202 }
203 }
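// This helper is used below to apply the K and 1/K gains of the
// irreversible transform to a whole line at a time.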
204
206 void avx512_irv_vert_step(const lifting_step* s, const line_buf* sig,
207 const line_buf* other, const line_buf* aug,
208 ui32 repeat, bool synthesis)
209 {
210 float a = s->irv.Aatk;
211 if (synthesis)
212 a = -a;
213
214 __m512 factor = _mm512_set1_ps(a);
215
216 float* dst = aug->f32;
217 const float* src1 = sig->f32, * src2 = other->f32;
218 int i = (int)repeat;
219 for ( ; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
220 {
221 __m512 s1 = _mm512_load_ps(src1);
222 __m512 s2 = _mm512_load_ps(src2);
223 __m512 d = _mm512_load_ps(dst);
224 d = _mm512_add_ps(d, _mm512_mul_ps(factor, _mm512_add_ps(s1, s2)));
225 _mm512_store_ps(dst, d);
226 }
227 }
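// In scalar form, the vertical lifting step above is simply
//
//   aug[i] += a * (sig[i] + other[i]);
//
// with the sign of a flipped for synthesis, so running the step again
// during decoding undoes it.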
228
230 void avx512_irv_vert_times_K(float K, const line_buf* aug, ui32 repeat)
231 {
232 avx512_multiply_const(aug->f32, K, (int)repeat);
233 }
234
236 void avx512_irv_horz_ana(const param_atk* atk, const line_buf* ldst,
237 const line_buf* hdst, const line_buf* src,
238 ui32 width, bool even)
239 {
240 if (width > 1)
241 {
242 // split src into ldst and hdst
243 {
244 float* dpl = even ? ldst->f32 : hdst->f32;
245 float* dph = even ? hdst->f32 : ldst->f32;
246 float* sp = src->f32;
247 int w = (int)width;
248 avx512_deinterleave32(dpl, dph, sp, w);
249 }
250
251 // the actual horizontal transform
252 float* hp = hdst->f32, * lp = ldst->f32;
253 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
254 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
255 ui32 num_steps = atk->get_num_steps();
256 for (ui32 j = num_steps; j > 0; --j)
257 {
258 const lifting_step* s = atk->get_step(j - 1);
259 const float a = s->irv.Aatk;
260
261 // extension
262 lp[-1] = lp[0];
263 lp[l_width] = lp[l_width - 1];
264 // lifting step
265 const float* sp = lp;
266 float* dp = hp;
267 int i = (int)h_width;
268 __m512 f = _mm512_set1_ps(a);
269 if (even)
270 {
271 for (; i > 0; i -= 16, sp += 16, dp += 16)
272 {
273 __m512 m = _mm512_load_ps(sp);
274 __m512 n = _mm512_loadu_ps(sp + 1);
275 __m512 p = _mm512_load_ps(dp);
276 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
277 _mm512_store_ps(dp, p);
278 }
279 }
280 else
281 {
282 for (; i > 0; i -= 16, sp += 16, dp += 16)
283 {
284 __m512 m = _mm512_load_ps(sp);
285 __m512 n = _mm512_loadu_ps(sp - 1);
286 __m512 p = _mm512_load_ps(dp);
287 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
288 _mm512_store_ps(dp, p);
289 }
290 }
291
292 // swap buffers
293 float* t = lp; lp = hp; hp = t;
294 even = !even;
295 ui32 w = l_width; l_width = h_width; h_width = w;
296 }
297
298 { // multiply by K or 1/K
299 float K = atk->get_K();
300 float K_inv = 1.0f / K;
301 avx512_multiply_const(lp, K_inv, (int)l_width);
302 avx512_multiply_const(hp, K, (int)h_width);
303 }
304 }
305 else {
306 if (even)
307 ldst->f32[0] = src->f32[0];
308 else
309 hdst->f32[0] = src->f32[0] * 2.0f;
310 }
311 }
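// In scalar form, each horizontal lifting step above is, for the even
// phase,
//
//   hp[i] += a * (lp[i] + lp[i + 1]);
//
// and, for the odd phase, hp[i] += a * (lp[i - 1] + lp[i]); the
// extension written before each step supplies the boundary neighbors.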
312
314 void avx512_irv_horz_syn(const param_atk* atk, const line_buf* dst,
315 const line_buf* lsrc, const line_buf* hsrc,
316 ui32 width, bool even)
317 {
318 if (width > 1)
319 {
320 bool ev = even;
321 float* oth = hsrc->f32, * aug = lsrc->f32;
322 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
323 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
324
325 { // multiply by K or 1/K
326 float K = atk->get_K();
327 float K_inv = 1.0f / K;
328 avx512_multiply_const(aug, K, (int)aug_width);
329 avx512_multiply_const(oth, K_inv, (int)oth_width);
330 }
331
332 // the actual horizontal transform
333 ui32 num_steps = atk->get_num_steps();
334 for (ui32 j = 0; j < num_steps; ++j)
335 {
336 const lifting_step* s = atk->get_step(j);
337 const float a = s->irv.Aatk;
338
339 // extension
340 oth[-1] = oth[0];
341 oth[oth_width] = oth[oth_width - 1];
342 // lifting step
343 const float* sp = oth;
344 float* dp = aug;
345 int i = (int)aug_width;
346 __m512 f = _mm512_set1_ps(a);
347 if (ev)
348 {
349 for (; i > 0; i -= 16, sp += 16, dp += 16)
350 {
351 __m512 m = _mm512_load_ps(sp);
352 __m512 n = _mm512_loadu_ps(sp - 1);
353 __m512 p = _mm512_load_ps(dp);
354 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
355 _mm512_store_ps(dp, p);
356 }
357 }
358 else
359 {
360 for (; i > 0; i -= 16, sp += 16, dp += 16)
361 {
362 __m512 m = _mm512_load_ps(sp);
363 __m512 n = _mm512_loadu_ps(sp + 1);
364 __m512 p = _mm512_load_ps(dp);
365 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
366 _mm512_store_ps(dp, p);
367 }
368 }
369
370 // swap buffers
371 float* t = aug; aug = oth; oth = t;
372 ev = !ev;
373 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
374 }
375
376 // combine both lsrc and hsrc into dst
377 {
378 float* dp = dst->f32;
379 float* spl = even ? lsrc->f32 : hsrc->f32;
380 float* sph = even ? hsrc->f32 : lsrc->f32;
381 int w = (int)width;
382 avx512_interleave32(dp, spl, sph, w);
383 }
384 }
385 else {
386 if (even)
387 dst->f32[0] = lsrc->f32[0];
388 else
389 dst->f32[0] = hsrc->f32[0] * 0.5f;
390 }
391 }
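// Synthesis mirrors analysis: the K and 1/K gains are applied first,
// then the lifting steps run in the opposite order, subtracting instead
// of adding, which undoes avx512_irv_horz_ana (up to float rounding).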
392
393
395 void avx512_rev_vert_step32(const lifting_step* s, const line_buf* sig,
396 const line_buf* other, const line_buf* aug,
397 ui32 repeat, bool synthesis)
398 {
399 const si32 a = s->rev.Aatk;
400 const si32 b = s->rev.Batk;
401 const ui8 e = s->rev.Eatk;
402 __m512i va = _mm512_set1_epi32(a);
403 __m512i vb = _mm512_set1_epi32(b);
404
405 si32* dst = aug->i32;
406 const si32* src1 = sig->i32, * src2 = other->i32;
407 // The general definition of the wavelet in Part 2 is slightly
408 // different from that in Part 1, although they are mathematically
409 // equivalent; here, we identify the simpler forms from Part 1 and use them.
410 if (a == 1)
411 { // 5/3 update and any case with a == 1
412 int i = (int)repeat;
413 if (synthesis)
414 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
415 {
416 __m512i s1 = _mm512_load_si512((__m512i*)src1);
417 __m512i s2 = _mm512_load_si512((__m512i*)src2);
418 __m512i d = _mm512_load_si512((__m512i*)dst);
419 __m512i t = _mm512_add_epi32(s1, s2);
420 __m512i v = _mm512_add_epi32(vb, t);
421 __m512i w = _mm512_srai_epi32(v, e);
422 d = _mm512_sub_epi32(d, w);
423 _mm512_store_si512((__m512i*)dst, d);
424 }
425 else
426 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
427 {
428 __m512i s1 = _mm512_load_si512((__m512i*)src1);
429 __m512i s2 = _mm512_load_si512((__m512i*)src2);
430 __m512i d = _mm512_load_si512((__m512i*)dst);
431 __m512i t = _mm512_add_epi32(s1, s2);
432 __m512i v = _mm512_add_epi32(vb, t);
433 __m512i w = _mm512_srai_epi32(v, e);
434 d = _mm512_add_epi32(d, w);
435 _mm512_store_si512((__m512i*)dst, d);
436 }
437 }
438 else if (a == -1 && b == 1 && e == 1)
439 { // 5/3 predict
440 int i = (int)repeat;
441 if (synthesis)
442 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
443 {
444 __m512i s1 = _mm512_load_si512((__m512i*)src1);
445 __m512i s2 = _mm512_load_si512((__m512i*)src2);
446 __m512i d = _mm512_load_si512((__m512i*)dst);
447 __m512i t = _mm512_add_epi32(s1, s2);
448 __m512i w = _mm512_srai_epi32(t, e);
449 d = _mm512_add_epi32(d, w);
450 _mm512_store_si512((__m512i*)dst, d);
451 }
452 else
453 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
454 {
455 __m512i s1 = _mm512_load_si512((__m512i*)src1);
456 __m512i s2 = _mm512_load_si512((__m512i*)src2);
457 __m512i d = _mm512_load_si512((__m512i*)dst);
458 __m512i t = _mm512_add_epi32(s1, s2);
459 __m512i w = _mm512_srai_epi32(t, e);
460 d = _mm512_sub_epi32(d, w);
461 _mm512_store_si512((__m512i*)dst, d);
462 }
463 }
464 else if (a == -1)
465 { // any case with a == -1, which is not 5/3 predict
466 int i = (int)repeat;
467 if (synthesis)
468 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
469 {
470 __m512i s1 = _mm512_load_si512((__m512i*)src1);
471 __m512i s2 = _mm512_load_si512((__m512i*)src2);
472 __m512i d = _mm512_load_si512((__m512i*)dst);
473 __m512i t = _mm512_add_epi32(s1, s2);
474 __m512i v = _mm512_sub_epi32(vb, t);
475 __m512i w = _mm512_srai_epi32(v, e);
476 d = _mm512_sub_epi32(d, w);
477 _mm512_store_si512((__m512i*)dst, d);
478 }
479 else
480 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
481 {
482 __m512i s1 = _mm512_load_si512((__m512i*)src1);
483 __m512i s2 = _mm512_load_si512((__m512i*)src2);
484 __m512i d = _mm512_load_si512((__m512i*)dst);
485 __m512i t = _mm512_add_epi32(s1, s2);
486 __m512i v = _mm512_sub_epi32(vb, t);
487 __m512i w = _mm512_srai_epi32(v, e);
488 d = _mm512_add_epi32(d, w);
489 _mm512_store_si512((__m512i*)dst, d);
490 }
491 }
492 else { // general case
493 int i = (int)repeat;
494 if (synthesis)
495 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
496 {
497 __m512i s1 = _mm512_load_si512((__m512i*)src1);
498 __m512i s2 = _mm512_load_si512((__m512i*)src2);
499 __m512i d = _mm512_load_si512((__m512i*)dst);
500 __m512i t = _mm512_add_epi32(s1, s2);
501 __m512i u = _mm512_mullo_epi32(va, t);
502 __m512i v = _mm512_add_epi32(vb, u);
503 __m512i w = _mm512_srai_epi32(v, e);
504 d = _mm512_sub_epi32(d, w);
505 _mm512_store_si512((__m512i*)dst, d);
506 }
507 else
508 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
509 {
510 __m512i s1 = _mm512_load_si512((__m512i*)src1);
511 __m512i s2 = _mm512_load_si512((__m512i*)src2);
512 __m512i d = _mm512_load_si512((__m512i*)dst);
513 __m512i t = _mm512_add_epi32(s1, s2);
514 __m512i u = _mm512_mullo_epi32(va, t);
515 __m512i v = _mm512_add_epi32(vb, u);
516 __m512i w = _mm512_srai_epi32(v, e);
517 d = _mm512_add_epi32(d, w);
518 _mm512_store_si512((__m512i*)dst, d);
519 }
520 }
521 }
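// All branches above specialize the generic reversible lifting step
//
//   aug[i] += (b + a * (sig[i] + other[i])) >> e;   // analysis
//
// (synthesis subtracts the same quantity). For the 5/3 wavelet, the
// predict step uses (a, b, e) = (-1, 1, 1) and the update step uses
// (a, b, e) = (1, 2, 2); the special cases avoid the multiplication.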
522
524 void avx512_rev_vert_step64(const lifting_step* s, const line_buf* sig,
525 const line_buf* other, const line_buf* aug,
526 ui32 repeat, bool synthesis)
527 {
528 const si32 a = s->rev.Aatk;
529 const si32 b = s->rev.Batk;
530 const ui8 e = s->rev.Eatk;
531 __m512i vb = _mm512_set1_epi64(b);
532
533 si64* dst = aug->i64;
534 const si64* src1 = sig->i64, * src2 = other->i64;
535 // The general definition of the wavelet in Part 2 is slightly
536 // different from that in Part 1, although they are mathematically
537 // equivalent; here, we identify the simpler forms from Part 1 and use them.
538 if (a == 1)
539 { // 5/3 update and any case with a == 1
540 int i = (int)repeat;
541 if (synthesis)
542 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
543 {
544 __m512i s1 = _mm512_load_si512((__m512i*)src1);
545 __m512i s2 = _mm512_load_si512((__m512i*)src2);
546 __m512i d = _mm512_load_si512((__m512i*)dst);
547 __m512i t = _mm512_add_epi64(s1, s2);
548 __m512i v = _mm512_add_epi64(vb, t);
549 __m512i w = _mm512_srai_epi64(v, e);
550 d = _mm512_sub_epi64(d, w);
551 _mm512_store_si512((__m512i*)dst, d);
552 }
553 else
554 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
555 {
556 __m512i s1 = _mm512_load_si512((__m512i*)src1);
557 __m512i s2 = _mm512_load_si512((__m512i*)src2);
558 __m512i d = _mm512_load_si512((__m512i*)dst);
559 __m512i t = _mm512_add_epi64(s1, s2);
560 __m512i v = _mm512_add_epi64(vb, t);
561 __m512i w = _mm512_srai_epi64(v, e);
562 d = _mm512_add_epi64(d, w);
563 _mm512_store_si512((__m512i*)dst, d);
564 }
565 }
566 else if (a == -1 && b == 1 && e == 1)
567 { // 5/3 predict
568 int i = (int)repeat;
569 if (synthesis)
570 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
571 {
572 __m512i s1 = _mm512_load_si512((__m512i*)src1);
573 __m512i s2 = _mm512_load_si512((__m512i*)src2);
574 __m512i d = _mm512_load_si512((__m512i*)dst);
575 __m512i t = _mm512_add_epi64(s1, s2);
576 __m512i w = _mm512_srai_epi64(t, e);
577 d = _mm512_add_epi64(d, w);
578 _mm512_store_si512((__m512i*)dst, d);
579 }
580 else
581 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
582 {
583 __m512i s1 = _mm512_load_si512((__m512i*)src1);
584 __m512i s2 = _mm512_load_si512((__m512i*)src2);
585 __m512i d = _mm512_load_si512((__m512i*)dst);
586 __m512i t = _mm512_add_epi64(s1, s2);
587 __m512i w = _mm512_srai_epi64(t, e);
588 d = _mm512_sub_epi64(d, w);
589 _mm512_store_si512((__m512i*)dst, d);
590 }
591 }
592 else if (a == -1)
593 { // any case with a == -1, which is not 5/3 predict
594 int i = (int)repeat;
595 if (synthesis)
596 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
597 {
598 __m512i s1 = _mm512_load_si512((__m512i*)src1);
599 __m512i s2 = _mm512_load_si512((__m512i*)src2);
600 __m512i d = _mm512_load_si512((__m512i*)dst);
601 __m512i t = _mm512_add_epi64(s1, s2);
602 __m512i v = _mm512_sub_epi64(vb, t);
603 __m512i w = _mm512_srai_epi64(v, e);
604 d = _mm512_sub_epi64(d, w);
605 _mm512_store_si512((__m512i*)dst, d);
606 }
607 else
608 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
609 {
610 __m512i s1 = _mm512_load_si512((__m512i*)src1);
611 __m512i s2 = _mm512_load_si512((__m512i*)src2);
612 __m512i d = _mm512_load_si512((__m512i*)dst);
613 __m512i t = _mm512_add_epi64(s1, s2);
614 __m512i v = _mm512_sub_epi64(vb, t);
615 __m512i w = _mm512_srai_epi64(v, e);
616 d = _mm512_add_epi64(d, w);
617 _mm512_store_si512((__m512i*)dst, d);
618 }
619 }
620 else {
621 // general case
622 // 64-bit multiplication is not supported by AVX512F + AVX512CD;
623 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
624 if (synthesis)
625 for (ui32 i = repeat; i > 0; --i)
626 *dst++ -= (b + a * (*src1++ + *src2++)) >> e;
627 else
628 for (ui32 i = repeat; i > 0; --i)
629 *dst++ += (b + a * (*src1++ + *src2++)) >> e;
630 }
631
632 // This can only be used if you have AVX512DQ
633 // { // general case
634 // __m512i va = _mm512_set1_epi64(a);
635 // int i = (int)repeat;
636 // if (synthesis)
637 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
638 // {
639 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
640 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
641 // __m512i d = _mm512_load_si512((__m512i*)dst);
642 // __m512i t = _mm512_add_epi64(s1, s2);
643 // __m512i u = _mm512_mullo_epi64(va, t);
644 // __m512i v = _mm512_add_epi64(vb, u);
645 // __m512i w = _mm512_srai_epi64(v, e);
646 // d = _mm512_sub_epi64(d, w);
647 // _mm512_store_si512((__m512i*)dst, d);
648 // }
649 // else
650 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
651 // {
652 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
653 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
654 // __m512i d = _mm512_load_si512((__m512i*)dst);
655 // __m512i t = _mm512_add_epi64(s1, s2);
656 // __m512i u = _mm512_mullo_epi64(va, t);
657 // __m512i v = _mm512_add_epi64(vb, u);
658 // __m512i w = _mm512_srai_epi64(v, e);
659 // d = _mm512_add_epi64(d, w);
660 // _mm512_store_si512((__m512i*)dst, d);
661 // }
662 // }
663 }
664
666 void avx512_rev_vert_step(const lifting_step* s, const line_buf* sig,
667 const line_buf* other, const line_buf* aug,
668 ui32 repeat, bool synthesis)
669 {
670 if (((sig != NULL) && (sig->flags & line_buf::LFT_32BIT)) ||
671 ((aug != NULL) && (aug->flags & line_buf::LFT_32BIT)) ||
672 ((other != NULL) && (other->flags & line_buf::LFT_32BIT)))
673 {
674 assert((sig == NULL || sig->flags & line_buf::LFT_32BIT) &&
675 (other == NULL || other->flags & line_buf::LFT_32BIT) &&
676 (aug == NULL || aug->flags & line_buf::LFT_32BIT));
677 avx512_rev_vert_step32(s, sig, other, aug, repeat, synthesis);
678 }
679 else
680 {
681 assert((sig == NULL || sig->flags & line_buf::LFT_64BIT) &&
682 (other == NULL || other->flags & line_buf::LFT_64BIT) &&
683 (aug == NULL || aug->flags & line_buf::LFT_64BIT));
684 avx512_rev_vert_step64(s, sig, other, aug, repeat, synthesis);
685 }
686 }
687
689 void avx512_rev_horz_ana32(const param_atk* atk, const line_buf* ldst,
690 const line_buf* hdst, const line_buf* src,
691 ui32 width, bool even)
692 {
693 if (width > 1)
694 {
695 // split src into ldst and hdst
696 {
697 float* dpl = even ? ldst->f32 : hdst->f32;
698 float* dph = even ? hdst->f32 : ldst->f32;
699 float* sp = src->f32;
700 int w = (int)width;
701 avx512_deinterleave32(dpl, dph, sp, w);
702 }
703
704 si32* hp = hdst->i32, * lp = ldst->i32;
705 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
706 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
707 ui32 num_steps = atk->get_num_steps();
708 for (ui32 j = num_steps; j > 0; --j)
709 {
710 // first lifting step
711 const lifting_step* s = atk->get_step(j - 1);
712 const si32 a = s->rev.Aatk;
713 const si32 b = s->rev.Batk;
714 const ui8 e = s->rev.Eatk;
715 __m512i va = _mm512_set1_epi32(a);
716 __m512i vb = _mm512_set1_epi32(b);
717
718 // extension
719 lp[-1] = lp[0];
720 lp[l_width] = lp[l_width - 1];
721 // lifting step
722 const si32* sp = lp;
723 si32* dp = hp;
724 if (a == 1)
725 { // 5/3 update and any case with a == 1
726 int i = (int)h_width;
727 if (even)
728 {
729 for (; i > 0; i -= 16, sp += 16, dp += 16)
730 {
731 __m512i s1 = _mm512_load_si512((__m512i*)sp);
732 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
733 __m512i d = _mm512_load_si512((__m512i*)dp);
734 __m512i t = _mm512_add_epi32(s1, s2);
735 __m512i v = _mm512_add_epi32(vb, t);
736 __m512i w = _mm512_srai_epi32(v, e);
737 d = _mm512_add_epi32(d, w);
738 _mm512_store_si512((__m512i*)dp, d);
739 }
740 }
741 else
742 {
743 for (; i > 0; i -= 16, sp += 16, dp += 16)
744 {
745 __m512i s1 = _mm512_load_si512((__m512i*)sp);
746 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
747 __m512i d = _mm512_load_si512((__m512i*)dp);
748 __m512i t = _mm512_add_epi32(s1, s2);
749 __m512i v = _mm512_add_epi32(vb, t);
750 __m512i w = _mm512_srai_epi32(v, e);
751 d = _mm512_add_epi32(d, w);
752 _mm512_store_si512((__m512i*)dp, d);
753 }
754 }
755 }
756 else if (a == -1 && b == 1 && e == 1)
757 { // 5/3 predict
758 int i = (int)h_width;
759 if (even)
760 for (; i > 0; i -= 16, sp += 16, dp += 16)
761 {
762 __m512i s1 = _mm512_load_si512((__m512i*)sp);
763 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
764 __m512i d = _mm512_load_si512((__m512i*)dp);
765 __m512i t = _mm512_add_epi32(s1, s2);
766 __m512i w = _mm512_srai_epi32(t, e);
767 d = _mm512_sub_epi32(d, w);
768 _mm512_store_si512((__m512i*)dp, d);
769 }
770 else
771 for (; i > 0; i -= 16, sp += 16, dp += 16)
772 {
773 __m512i s1 = _mm512_load_si512((__m512i*)sp);
774 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
775 __m512i d = _mm512_load_si512((__m512i*)dp);
776 __m512i t = _mm512_add_epi32(s1, s2);
777 __m512i w = _mm512_srai_epi32(t, e);
778 d = _mm512_sub_epi32(d, w);
779 _mm512_store_si512((__m512i*)dp, d);
780 }
781 }
782 else if (a == -1)
783 { // any case with a == -1, which is not 5/3 predict
784 int i = (int)h_width;
785 if (even)
786 for (; i > 0; i -= 16, sp += 16, dp += 16)
787 {
788 __m512i s1 = _mm512_load_si512((__m512i*)sp);
789 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
790 __m512i d = _mm512_load_si512((__m512i*)dp);
791 __m512i t = _mm512_add_epi32(s1, s2);
792 __m512i v = _mm512_sub_epi32(vb, t);
793 __m512i w = _mm512_srai_epi32(v, e);
794 d = _mm512_add_epi32(d, w);
795 _mm512_store_si512((__m512i*)dp, d);
796 }
797 else
798 for (; i > 0; i -= 16, sp += 16, dp += 16)
799 {
800 __m512i s1 = _mm512_load_si512((__m512i*)sp);
801 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
802 __m512i d = _mm512_load_si512((__m512i*)dp);
803 __m512i t = _mm512_add_epi32(s1, s2);
804 __m512i v = _mm512_sub_epi32(vb, t);
805 __m512i w = _mm512_srai_epi32(v, e);
806 d = _mm512_add_epi32(d, w);
807 _mm512_store_si512((__m512i*)dp, d);
808 }
809 }
810 else {
811 // general case
812 int i = (int)h_width;
813 if (even)
814 for (; i > 0; i -= 16, sp += 16, dp += 16)
815 {
816 __m512i s1 = _mm512_load_si512((__m512i*)sp);
817 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
818 __m512i d = _mm512_load_si512((__m512i*)dp);
819 __m512i t = _mm512_add_epi32(s1, s2);
820 __m512i u = _mm512_mullo_epi32(va, t);
821 __m512i v = _mm512_add_epi32(vb, u);
822 __m512i w = _mm512_srai_epi32(v, e);
823 d = _mm512_add_epi32(d, w);
824 _mm512_store_si512((__m512i*)dp, d);
825 }
826 else
827 for (; i > 0; i -= 16, sp += 16, dp += 16)
828 {
829 __m512i s1 = _mm512_load_si512((__m512i*)sp);
830 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
831 __m512i d = _mm512_load_si512((__m512i*)dp);
832 __m512i t = _mm512_add_epi32(s1, s2);
833 __m512i u = _mm512_mullo_epi32(va, t);
834 __m512i v = _mm512_add_epi32(vb, u);
835 __m512i w = _mm512_srai_epi32(v, e);
836 d = _mm512_add_epi32(d, w);
837 _mm512_store_si512((__m512i*)dp, d);
838 }
839 }
840
841 // swap buffers
842 si32* t = lp; lp = hp; hp = t;
843 even = !even;
844 ui32 w = l_width; l_width = h_width; h_width = w;
845 }
846 }
847 else {
848 if (even)
849 ldst->i32[0] = src->i32[0];
850 else
851 hdst->i32[0] = src->i32[0] << 1;
852 }
853 }
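// In scalar form, one even-phase lifting step above is, after the
// extension lp[-1] = lp[0] and lp[l_width] = lp[l_width - 1],
//
//   for (ui32 i = 0; i < h_width; ++i)
//     hp[i] += (b + a * (lp[i] + lp[i + 1])) >> e;
//
// which is exactly the scalar fallback used by the 64-bit path below.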
854
856 void avx512_rev_horz_ana64(const param_atk* atk, const line_buf* ldst,
857 const line_buf* hdst, const line_buf* src,
858 ui32 width, bool even)
859 {
860 if (width > 1)
861 {
862 // split src into ldst and hdst
863 {
864 double* dpl = (double*)(even ? ldst->p : hdst->p);
865 double* dph = (double*)(even ? hdst->p : ldst->p);
866 double* sp = (double*)(src->p);
867 int w = (int)width;
868 avx512_deinterleave64(dpl, dph, sp, w);
869 }
870
871 si64* hp = hdst->i64, * lp = ldst->i64;
872 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
873 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
874 ui32 num_steps = atk->get_num_steps();
875 for (ui32 j = num_steps; j > 0; --j)
876 {
877 // first lifting step
878 const lifting_step* s = atk->get_step(j - 1);
879 const si32 a = s->rev.Aatk;
880 const si32 b = s->rev.Batk;
881 const ui8 e = s->rev.Eatk;
882 __m512i vb = _mm512_set1_epi64(b);
883
884 // extension
885 lp[-1] = lp[0];
886 lp[l_width] = lp[l_width - 1];
887 // lifting step
888 const si64* sp = lp;
889 si64* dp = hp;
890 if (a == 1)
891 { // 5/3 update and any case with a == 1
892 int i = (int)h_width;
893 if (even)
894 {
895 for (; i > 0; i -= 8, sp += 8, dp += 8)
896 {
897 __m512i s1 = _mm512_load_si512((__m512i*)sp);
898 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
899 __m512i d = _mm512_load_si512((__m512i*)dp);
900 __m512i t = _mm512_add_epi64(s1, s2);
901 __m512i v = _mm512_add_epi64(vb, t);
902 __m512i w = _mm512_srai_epi64(v, e);
903 d = _mm512_add_epi64(d, w);
904 _mm512_store_si512((__m512i*)dp, d);
905 }
906 }
907 else
908 {
909 for (; i > 0; i -= 8, sp += 8, dp += 8)
910 {
911 __m512i s1 = _mm512_load_si512((__m512i*)sp);
912 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
913 __m512i d = _mm512_load_si512((__m512i*)dp);
914 __m512i t = _mm512_add_epi64(s1, s2);
915 __m512i v = _mm512_add_epi64(vb, t);
916 __m512i w = _mm512_srai_epi64(v, e);
917 d = _mm512_add_epi64(d, w);
918 _mm512_store_si512((__m512i*)dp, d);
919 }
920 }
921 }
922 else if (a == -1 && b == 1 && e == 1)
923 { // 5/3 predict
924 int i = (int)h_width;
925 if (even)
926 for (; i > 0; i -= 8, sp += 8, dp += 8)
927 {
928 __m512i s1 = _mm512_load_si512((__m512i*)sp);
929 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
930 __m512i d = _mm512_load_si512((__m512i*)dp);
931 __m512i t = _mm512_add_epi64(s1, s2);
932 __m512i w = _mm512_srai_epi64(t, e);
933 d = _mm512_sub_epi64(d, w);
934 _mm512_store_si512((__m512i*)dp, d);
935 }
936 else
937 for (; i > 0; i -= 8, sp += 8, dp += 8)
938 {
939 __m512i s1 = _mm512_load_si512((__m512i*)sp);
940 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
941 __m512i d = _mm512_load_si512((__m512i*)dp);
942 __m512i t = _mm512_add_epi64(s1, s2);
943 __m512i w = _mm512_srai_epi64(t, e);
944 d = _mm512_sub_epi64(d, w);
945 _mm512_store_si512((__m512i*)dp, d);
946 }
947 }
948 else if (a == -1)
949 { // any case with a == -1, which is not 5/3 predict
950 int i = (int)h_width;
951 if (even)
952 for (; i > 0; i -= 8, sp += 8, dp += 8)
953 {
954 __m512i s1 = _mm512_load_si512((__m512i*)sp);
955 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
956 __m512i d = _mm512_load_si512((__m512i*)dp);
957 __m512i t = _mm512_add_epi64(s1, s2);
958 __m512i v = _mm512_sub_epi64(vb, t);
959 __m512i w = _mm512_srai_epi64(v, e);
960 d = _mm512_add_epi64(d, w);
961 _mm512_store_si512((__m512i*)dp, d);
962 }
963 else
964 for (; i > 0; i -= 8, sp += 8, dp += 8)
965 {
966 __m512i s1 = _mm512_load_si512((__m512i*)sp);
967 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
968 __m512i d = _mm512_load_si512((__m512i*)dp);
969 __m512i t = _mm512_add_epi64(s1, s2);
970 __m512i v = _mm512_sub_epi64(vb, t);
971 __m512i w = _mm512_srai_epi64(v, e);
972 d = _mm512_add_epi64(d, w);
973 _mm512_store_si512((__m512i*)dp, d);
974 }
975 }
976 else
977 {
978 // general case
979 // 64-bit multiplication is not supported by AVX512F + AVX512CD;
980 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
981 if (even)
982 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
983 *dp += (b + a * (sp[0] + sp[1])) >> e;
984 else
985 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
986 *dp += (b + a * (sp[-1] + sp[0])) >> e;
987 }
988
989 // This can only be used if you have AVX512DQ
990 // {
991 // // general case
992 // __m512i va = _mm512_set1_epi64(a);
993 // int i = (int)h_width;
994 // if (even)
995 // for (; i > 0; i -= 8, sp += 8, dp += 8)
996 // {
997 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
998 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
999 // __m512i d = _mm512_load_si512((__m512i*)dp);
1000 // __m512i t = _mm512_add_epi64(s1, s2);
1001 // __m512i u = _mm512_mullo_epi64(va, t);
1002 // __m512i v = _mm512_add_epi64(vb, u);
1003 // __m512i w = _mm512_srai_epi64(v, e);
1004 // d = _mm512_add_epi64(d, w);
1005 // _mm512_store_si512((__m512i*)dp, d);
1006 // }
1007 // else
1008 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1009 // {
1010 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1011 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1012 // __m512i d = _mm512_load_si512((__m512i*)dp);
1013 // __m512i t = _mm512_add_epi64(s1, s2);
1014 // __m512i u = _mm512_mullo_epi64(va, t);
1015 // __m512i v = _mm512_add_epi64(vb, u);
1016 // __m512i w = _mm512_srai_epi64(v, e);
1017 // d = _mm512_add_epi64(d, w);
1018 // _mm512_store_si512((__m512i*)dp, d);
1019 // }
1020 // }
1021
1022 // swap buffers
1023 si64* t = lp; lp = hp; hp = t;
1024 even = !even;
1025 ui32 w = l_width; l_width = h_width; h_width = w;
1026 }
1027 }
1028 else {
1029 if (even)
1030 ldst->i64[0] = src->i64[0];
1031 else
1032 hdst->i64[0] = src->i64[0] << 1;
1033 }
1034 }
1035
1037 void avx512_rev_horz_ana(const param_atk* atk, const line_buf* ldst,
1038 const line_buf* hdst, const line_buf* src,
1039 ui32 width, bool even)
1040 {
1041 if (src->flags & line_buf::LFT_32BIT)
1042 {
1043 assert((ldst == NULL || ldst->flags & line_buf::LFT_32BIT) &&
1044 (hdst == NULL || hdst->flags & line_buf::LFT_32BIT));
1045 avx512_rev_horz_ana32(atk, ldst, hdst, src, width, even);
1046 }
1047 else
1048 {
1049 assert((ldst == NULL || ldst->flags & line_buf::LFT_64BIT) &&
1050 (hdst == NULL || hdst->flags & line_buf::LFT_64BIT) &&
1051 (src == NULL || src->flags & line_buf::LFT_64BIT));
1052 avx512_rev_horz_ana64(atk, ldst, hdst, src, width, even);
1053 }
1054 }
1055
1057 void avx512_rev_horz_syn32(const param_atk* atk, const line_buf* dst,
1058 const line_buf* lsrc, const line_buf* hsrc,
1059 ui32 width, bool even)
1060 {
1061 if (width > 1)
1062 {
1063 bool ev = even;
1064 si32* oth = hsrc->i32, * aug = lsrc->i32;
1065 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1066 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1067 ui32 num_steps = atk->get_num_steps();
1068 for (ui32 j = 0; j < num_steps; ++j)
1069 {
1070 const lifting_step* s = atk->get_step(j);
1071 const si32 a = s->rev.Aatk;
1072 const si32 b = s->rev.Batk;
1073 const ui8 e = s->rev.Eatk;
1074 __m512i va = _mm512_set1_epi32(a);
1075 __m512i vb = _mm512_set1_epi32(b);
1076
1077 // extension
1078 oth[-1] = oth[0];
1079 oth[oth_width] = oth[oth_width - 1];
1080 // lifting step
1081 const si32* sp = oth;
1082 si32* dp = aug;
1083 if (a == 1)
1084 { // 5/3 update and any case with a == 1
1085 int i = (int)aug_width;
1086 if (ev)
1087 {
1088 for (; i > 0; i -= 16, sp += 16, dp += 16)
1089 {
1090 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1091 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1092 __m512i d = _mm512_load_si512((__m512i*)dp);
1093 __m512i t = _mm512_add_epi32(s1, s2);
1094 __m512i v = _mm512_add_epi32(vb, t);
1095 __m512i w = _mm512_srai_epi32(v, e);
1096 d = _mm512_sub_epi32(d, w);
1097 _mm512_store_si512((__m512i*)dp, d);
1098 }
1099 }
1100 else
1101 {
1102 for (; i > 0; i -= 16, sp += 16, dp += 16)
1103 {
1104 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1105 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1106 __m512i d = _mm512_load_si512((__m512i*)dp);
1107 __m512i t = _mm512_add_epi32(s1, s2);
1108 __m512i v = _mm512_add_epi32(vb, t);
1109 __m512i w = _mm512_srai_epi32(v, e);
1110 d = _mm512_sub_epi32(d, w);
1111 _mm512_store_si512((__m512i*)dp, d);
1112 }
1113 }
1114 }
1115 else if (a == -1 && b == 1 && e == 1)
1116 { // 5/3 predict
1117 int i = (int)aug_width;
1118 if (ev)
1119 for (; i > 0; i -= 16, sp += 16, dp += 16)
1120 {
1121 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1122 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1123 __m512i d = _mm512_load_si512((__m512i*)dp);
1124 __m512i t = _mm512_add_epi32(s1, s2);
1125 __m512i w = _mm512_srai_epi32(t, e);
1126 d = _mm512_add_epi32(d, w);
1127 _mm512_store_si512((__m512i*)dp, d);
1128 }
1129 else
1130 for (; i > 0; i -= 16, sp += 16, dp += 16)
1131 {
1132 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1133 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1134 __m512i d = _mm512_load_si512((__m512i*)dp);
1135 __m512i t = _mm512_add_epi32(s1, s2);
1136 __m512i w = _mm512_srai_epi32(t, e);
1137 d = _mm512_add_epi32(d, w);
1138 _mm512_store_si512((__m512i*)dp, d);
1139 }
1140 }
1141 else if (a == -1)
1142 { // any case with a == -1, which is not 5/3 predict
1143 int i = (int)aug_width;
1144 if (ev)
1145 for (; i > 0; i -= 16, sp += 16, dp += 16)
1146 {
1147 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1148 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1149 __m512i d = _mm512_load_si512((__m512i*)dp);
1150 __m512i t = _mm512_add_epi32(s1, s2);
1151 __m512i v = _mm512_sub_epi32(vb, t);
1152 __m512i w = _mm512_srai_epi32(v, e);
1153 d = _mm512_sub_epi32(d, w);
1154 _mm512_store_si512((__m512i*)dp, d);
1155 }
1156 else
1157 for (; i > 0; i -= 16, sp += 16, dp += 16)
1158 {
1159 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1160 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1161 __m512i d = _mm512_load_si512((__m512i*)dp);
1162 __m512i t = _mm512_add_epi32(s1, s2);
1163 __m512i v = _mm512_sub_epi32(vb, t);
1164 __m512i w = _mm512_srai_epi32(v, e);
1165 d = _mm512_sub_epi32(d, w);
1166 _mm512_store_si512((__m512i*)dp, d);
1167 }
1168 }
1169 else {
1170 // general case
1171 int i = (int)aug_width;
1172 if (ev)
1173 for (; i > 0; i -= 16, sp += 16, dp += 16)
1174 {
1175 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1176 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1177 __m512i d = _mm512_load_si512((__m512i*)dp);
1178 __m512i t = _mm512_add_epi32(s1, s2);
1179 __m512i u = _mm512_mullo_epi32(va, t);
1180 __m512i v = _mm512_add_epi32(vb, u);
1181 __m512i w = _mm512_srai_epi32(v, e);
1182 d = _mm512_sub_epi32(d, w);
1183 _mm512_store_si512((__m512i*)dp, d);
1184 }
1185 else
1186 for (; i > 0; i -= 16, sp += 16, dp += 16)
1187 {
1188 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1189 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1190 __m512i d = _mm512_load_si512((__m512i*)dp);
1191 __m512i t = _mm512_add_epi32(s1, s2);
1192 __m512i u = _mm512_mullo_epi32(va, t);
1193 __m512i v = _mm512_add_epi32(vb, u);
1194 __m512i w = _mm512_srai_epi32(v, e);
1195 d = _mm512_sub_epi32(d, w);
1196 _mm512_store_si512((__m512i*)dp, d);
1197 }
1198 }
1199
1200 // swap buffers
1201 si32* t = aug; aug = oth; oth = t;
1202 ev = !ev;
1203 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1204 }
1205
1206 // combine both lsrc and hsrc into dst
1207 {
1208 float* dp = dst->f32;
1209 float* spl = even ? lsrc->f32 : hsrc->f32;
1210 float* sph = even ? hsrc->f32 : lsrc->f32;
1211 int w = (int)width;
1212 avx512_interleave32(dp, spl, sph, w);
1213 }
1214 }
1215 else {
1216 if (even)
1217 dst->i32[0] = lsrc->i32[0];
1218 else
1219 dst->i32[0] = hsrc->i32[0] >> 1;
1220 }
1221 }
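// As in the vertical case, synthesis applies the lifting steps in the
// reverse order and subtracts, e.g. for the even phase, in scalar form,
//
//   aug[i] -= (b + a * (oth[i - 1] + oth[i])) >> e;
//
// so the reversible transform reconstructs its input exactly.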
1222
1224 void avx512_rev_horz_syn64(const param_atk* atk, const line_buf* dst,
1225 const line_buf* lsrc, const line_buf* hsrc,
1226 ui32 width, bool even)
1227 {
1228 if (width > 1)
1229 {
1230 bool ev = even;
1231 si64* oth = hsrc->i64, * aug = lsrc->i64;
1232 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1233 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1234 ui32 num_steps = atk->get_num_steps();
1235 for (ui32 j = 0; j < num_steps; ++j)
1236 {
1237 const lifting_step* s = atk->get_step(j);
1238 const si32 a = s->rev.Aatk;
1239 const si32 b = s->rev.Batk;
1240 const ui8 e = s->rev.Eatk;
1241 __m512i vb = _mm512_set1_epi64(b);
1242
1243 // extension
1244 oth[-1] = oth[0];
1245 oth[oth_width] = oth[oth_width - 1];
1246 // lifting step
1247 const si64* sp = oth;
1248 si64* dp = aug;
1249 if (a == 1)
1250 { // 5/3 update and any case with a == 1
1251 int i = (int)aug_width;
1252 if (ev)
1253 {
1254 for (; i > 0; i -= 8, sp += 8, dp += 8)
1255 {
1256 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1257 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1258 __m512i d = _mm512_load_si512((__m512i*)dp);
1259 __m512i t = _mm512_add_epi64(s1, s2);
1260 __m512i v = _mm512_add_epi64(vb, t);
1261 __m512i w = _mm512_srai_epi64(v, e);
1262 d = _mm512_sub_epi64(d, w);
1263 _mm512_store_si512((__m512i*)dp, d);
1264 }
1265 }
1266 else
1267 {
1268 for (; i > 0; i -= 8, sp += 8, dp += 8)
1269 {
1270 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1271 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1272 __m512i d = _mm512_load_si512((__m512i*)dp);
1273 __m512i t = _mm512_add_epi64(s1, s2);
1274 __m512i v = _mm512_add_epi64(vb, t);
1275 __m512i w = _mm512_srai_epi64(v, e);
1276 d = _mm512_sub_epi64(d, w);
1277 _mm512_store_si512((__m512i*)dp, d);
1278 }
1279 }
1280 }
1281 else if (a == -1 && b == 1 && e == 1)
1282 { // 5/3 predict
1283 int i = (int)aug_width;
1284 if (ev)
1285 for (; i > 0; i -= 8, sp += 8, dp += 8)
1286 {
1287 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1288 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1289 __m512i d = _mm512_load_si512((__m512i*)dp);
1290 __m512i t = _mm512_add_epi64(s1, s2);
1291 __m512i w = _mm512_srai_epi64(t, e);
1292 d = _mm512_add_epi64(d, w);
1293 _mm512_store_si512((__m512i*)dp, d);
1294 }
1295 else
1296 for (; i > 0; i -= 8, sp += 8, dp += 8)
1297 {
1298 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1299 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1300 __m512i d = _mm512_load_si512((__m512i*)dp);
1301 __m512i t = _mm512_add_epi64(s1, s2);
1302 __m512i w = _mm512_srai_epi64(t, e);
1303 d = _mm512_add_epi64(d, w);
1304 _mm512_store_si512((__m512i*)dp, d);
1305 }
1306 }
1307 else if (a == -1)
1308 { // any case with a == -1, which is not 5/3 predict
1309 int i = (int)aug_width;
1310 if (ev)
1311 for (; i > 0; i -= 8, sp += 8, dp += 8)
1312 {
1313 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1314 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1315 __m512i d = _mm512_load_si512((__m512i*)dp);
1316 __m512i t = _mm512_add_epi64(s1, s2);
1317 __m512i v = _mm512_sub_epi64(vb, t);
1318 __m512i w = _mm512_srai_epi64(v, e);
1319 d = _mm512_sub_epi64(d, w);
1320 _mm512_store_si512((__m512i*)dp, d);
1321 }
1322 else
1323 for (; i > 0; i -= 8, sp += 8, dp += 8)
1324 {
1325 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1326 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1327 __m512i d = _mm512_load_si512((__m512i*)dp);
1328 __m512i t = _mm512_add_epi64(s1, s2);
1329 __m512i v = _mm512_sub_epi64(vb, t);
1330 __m512i w = _mm512_srai_epi64(v, e);
1331 d = _mm512_sub_epi64(d, w);
1332 _mm512_store_si512((__m512i*)dp, d);
1333 }
1334 }
1335 else
1336 {
1337 // general case
1338 // 64-bit multiplication is not supported by AVX512F + AVX512CD;
1339 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
1340 if (ev)
1341 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1342 *dp -= (b + a * (sp[-1] + sp[0])) >> e;
1343 else
1344 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1345 *dp -= (b + a * (sp[0] + sp[1])) >> e;
1346 }
1347
1348 // This can only be used if you have AVX512DQ
1349 // {
1350 // // general case
1351 // __m512i va = _mm512_set1_epi64(a);
1352 // int i = (int)aug_width;
1353 // if (ev)
1354 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1355 // {
1356 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1357 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1358 // __m512i d = _mm512_load_si512((__m512i*)dp);
1359 // __m512i t = _mm512_add_epi64(s1, s2);
1360 // __m512i u = _mm512_mullo_epi64(va, t);
1361 // __m512i v = _mm512_add_epi64(vb, u);
1362 // __m512i w = _mm512_srai_epi64(v, e);
1363 // d = _mm512_sub_epi64(d, w);
1364 // _mm512_store_si512((__m512i*)dp, d);
1365 // }
1366 // else
1367 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1368 // {
1369 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1370 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1371 // __m512i d = _mm512_load_si512((__m512i*)dp);
1372 // __m512i t = _mm512_add_epi64(s1, s2);
1373 // __m512i u = _mm512_mullo_epi64(va, t);
1374 // __m512i v = _mm512_add_epi64(vb, u);
1375 // __m512i w = _mm512_srai_epi64(v, e);
1376 // d = _mm512_sub_epi64(d, w);
1377 // _mm512_store_si512((__m512i*)dp, d);
1378 // }
1379 // }
1380
1381 // swap buffers
1382 si64* t = aug; aug = oth; oth = t;
1383 ev = !ev;
1384 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1385 }
1386
1387 // combine both lsrc and hsrc into dst
1388 {
1389 double* dp = (double*)(dst->p);
1390 double* spl = (double*)(even ? lsrc->p : hsrc->p);
1391 double* sph = (double*)(even ? hsrc->p : lsrc->p);
1392 int w = (int)width;
1393 avx512_interleave64(dp, spl, sph, w);
1394 }
1395 }
1396 else {
1397 if (even)
1398 dst->i64[0] = lsrc->i64[0];
1399 else
1400 dst->i64[0] = hsrc->i64[0] >> 1;
1401 }
1402 }
1403
1405 void avx512_rev_horz_syn(const param_atk* atk, const line_buf* dst,
1406 const line_buf* lsrc, const line_buf* hsrc,
1407 ui32 width, bool even)
1408 {
1409 if (dst->flags & line_buf::LFT_32BIT)
1410 {
1411 assert((lsrc == NULL || lsrc->flags & line_buf::LFT_32BIT) &&
1412 (hsrc == NULL || hsrc->flags & line_buf::LFT_32BIT));
1413 avx512_rev_horz_syn32(atk, dst, lsrc, hsrc, width, even);
1414 }
1415 else
1416 {
1417 assert((dst == NULL || dst->flags & line_buf::LFT_64BIT) &&
1418 (lsrc == NULL || lsrc->flags & line_buf::LFT_64BIT) &&
1419 (hsrc == NULL || hsrc->flags & line_buf::LFT_64BIT));
1420 avx512_rev_horz_syn64(atk, dst, lsrc, hsrc, width, even);
1421 }
1422 }
1423
1424 } // !local
1425} // !ojph