diff --git a/sysdeps/aarch64/fpu/sv_math.h b/sysdeps/aarch64/fpu/sv_math.h index fc8e690e80..c0ce6bb6d6 100644 --- a/sysdeps/aarch64/fpu/sv_math.h +++ b/sysdeps/aarch64/fpu/sv_math.h @@ -24,11 +24,29 @@ #include "vecmath_config.h" +#if !defined(__ARM_FEATURE_SVE_BITS) || __ARM_FEATURE_SVE_BITS == 0 +/* If not specified by -msve-vector-bits, assume maximum vector length. */ +# define SVE_VECTOR_BYTES 256 +#else +# define SVE_VECTOR_BYTES (__ARM_FEATURE_SVE_BITS / 8) +#endif +#define SVE_NUM_FLTS (SVE_VECTOR_BYTES / sizeof (float)) +#define SVE_NUM_DBLS (SVE_VECTOR_BYTES / sizeof (double)) +/* Predicate is stored as one bit per byte of VL so requires VL / 64 bytes. */ +#define SVE_NUM_PG_BYTES (SVE_VECTOR_BYTES / sizeof (uint64_t)) + #define SV_NAME_F1(fun) _ZGVsMxv_##fun##f #define SV_NAME_D1(fun) _ZGVsMxv_##fun #define SV_NAME_F2(fun) _ZGVsMxvv_##fun##f #define SV_NAME_D2(fun) _ZGVsMxvv_##fun +static inline void +svstr_p (uint8_t *dst, svbool_t p) +{ + /* Predicate STR does not currently have an intrinsic. */ + __asm__("str %0, [%x1]\n" : : "Upa"(p), "r"(dst) : "memory"); +} + /* Double precision. */ static inline svint64_t sv_s64 (int64_t x) @@ -51,33 +69,35 @@ sv_f64 (double x) static inline svfloat64_t sv_call_f64 (double (*f) (double), svfloat64_t x, svfloat64_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + double tmp[SVE_NUM_DBLS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b64 (), tmp, svsel (cmp, x, y)); + + for (int i = 0; i < svcntd (); i++) { - double elem = svclastb_n_f64 (p, 0, x); - elem = (*f) (elem); - svfloat64_t y2 = svdup_n_f64 (elem); - y = svsel_f64 (p, y2, y); - p = svpnext_b64 (cmp, p); + if (pg_bits[i] & 1) + tmp[i] = f (tmp[i]); } - return y; + return svld1 (svptrue_b64 (), tmp); } static inline svfloat64_t sv_call2_f64 (double (*f) (double, double), svfloat64_t x1, svfloat64_t x2, svfloat64_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + double tmp1[SVE_NUM_DBLS], tmp2[SVE_NUM_DBLS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b64 (), tmp1, svsel (cmp, x1, y)); + svst1 (cmp, tmp2, x2); + + for (int i = 0; i < svcntd (); i++) { - double elem1 = svclastb_n_f64 (p, 0, x1); - double elem2 = svclastb_n_f64 (p, 0, x2); - double ret = (*f) (elem1, elem2); - svfloat64_t y2 = svdup_n_f64 (ret); - y = svsel_f64 (p, y2, y); - p = svpnext_b64 (cmp, p); + if (pg_bits[i] & 1) + tmp1[i] = f (tmp1[i], tmp2[i]); } - return y; + return svld1 (svptrue_b64 (), tmp1); } static inline svuint64_t @@ -109,33 +129,40 @@ sv_f32 (float x) static inline svfloat32_t sv_call_f32 (float (*f) (float), svfloat32_t x, svfloat32_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) + float tmp[SVE_NUM_FLTS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b32 (), tmp, svsel (cmp, x, y)); + + for (int i = 0; i < svcntd (); i++) { - float elem = svclastb_n_f32 (p, 0, x); - elem = f (elem); - svfloat32_t y2 = svdup_n_f32 (elem); - y = svsel_f32 (p, y2, y); - p = svpnext_b32 (cmp, p); + uint8_t p = pg_bits[i]; + if (p & 1) + tmp[i * 2] = f (tmp[i * 2]); + if (p & (1 << 4)) + tmp[i * 2 + 1] = f (tmp[i * 2 + 1]); } - return y; + return svld1 (svptrue_b32 (), tmp); } static inline svfloat32_t sv_call2_f32 (float (*f) (float, float), svfloat32_t x1, svfloat32_t x2, svfloat32_t y, svbool_t cmp) { - svbool_t p = svpfirst (cmp, svpfalse ()); - while (svptest_any (cmp, p)) - { - float elem1 = svclastb_n_f32 (p, 0, x1); - float elem2 = svclastb_n_f32 (p, 0, x2); - float ret = f (elem1, elem2); - svfloat32_t y2 = svdup_n_f32 (ret); - y = svsel_f32 (p, y2, y); - p = svpnext_b32 (cmp, p); - } - return y; -} + float tmp1[SVE_NUM_FLTS], tmp2[SVE_NUM_FLTS]; + uint8_t pg_bits[SVE_NUM_PG_BYTES]; + svstr_p (pg_bits, cmp); + svst1 (svptrue_b32 (), tmp1, svsel (cmp, x1, y)); + svst1 (cmp, tmp2, x2); + for (int i = 0; i < svcntd (); i++) + { + uint8_t p = pg_bits[i]; + if (p & 1) + tmp1[i * 2] = f (tmp1[i * 2], tmp2[i * 2]); + if (p & (1 << 4)) + tmp1[i * 2 + 1] = f (tmp1[i * 2 + 1], tmp2[i * 2 + 1]); + } + return svld1 (svptrue_b32 (), tmp1); +} #endif