]> xenbits.xensource.com Git - qemu-xen.git/commitdiff
target/arm: Implement bfloat16 matrix multiply accumulate
authorRichard Henderson <richard.henderson@linaro.org>
Tue, 25 May 2021 22:58:13 +0000 (15:58 -0700)
committerPeter Maydell <peter.maydell@linaro.org>
Thu, 3 Jun 2021 15:43:26 +0000 (16:43 +0100)
This is BFMMLA for both AArch64 AdvSIMD and SVE,
and VMMLA.BF16 for AArch32 NEON.

Reviewed-by: Peter Maydell <peter.maydell@linaro.org>
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
Message-id: 20210525225817.400336-9-richard.henderson@linaro.org
Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
target/arm/helper.h
target/arm/neon-shared.decode
target/arm/sve.decode
target/arm/translate-a64.c
target/arm/translate-neon.c
target/arm/translate-sve.c
target/arm/vec_helper.c

index 376c1cef0f685ebed53e530e2ffeea6f6e012109..af75d7f25f24891c76a74d7503d0d7d392ecf8d2 100644 (file)
@@ -1007,6 +1007,9 @@ DEF_HELPER_FLAGS_5(gvec_bfdot, TCG_CALL_NO_RWG,
 DEF_HELPER_FLAGS_5(gvec_bfdot_idx, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
 
+DEF_HELPER_FLAGS_5(gvec_bfmmla, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
+
 #ifdef TARGET_AARCH64
 #include "helper-a64.h"
 #include "helper-sve.h"
index fa3cf14e3a6a03365ebf6e2a68ce6a8980acde5e..4e0a25d27c1c08c45aa51e3addf5fa73c5d60117 100644 (file)
@@ -67,6 +67,8 @@ VUMMLA         1111 1100 0.10 .... .... 1100 .1.1 .... \
                vm=%vm_dp vn=%vn_dp vd=%vd_dp
 VUSMMLA        1111 1100 1.10 .... .... 1100 .1.0 .... \
                vm=%vm_dp vn=%vn_dp vd=%vd_dp
+VMMLA_b16      1111 1100 0.00 .... .... 1100 .1.0 .... \
+               vm=%vm_dp vn=%vn_dp vd=%vd_dp
 
 VCMLA_scalar   1111 1110 0 . rot:2 .... .... 1000 . q:1 index:1 0 vm:4 \
                vn=%vn_dp vd=%vd_dp size=1
index 51f87e8937efd42ec0dd964d28b60c9c67bdec41..6c17898deed926fce9522b4c4455abe989b8a868 100644 (file)
@@ -1568,8 +1568,10 @@ SQRDCMLAH_zzzz  01000100 esz:2 0 rm:5 0011 rot:2 rn:5 rd:5  ra=%reg_movprfx
 USDOT_zzzz      01000100 .. 0 ..... 011 110 ..... .....  @rda_rn_rm
 
 ### SVE2 floating point matrix multiply accumulate
-
-FMMLA           01100100 .. 1 ..... 111001 ..... .....  @rda_rn_rm
+{
+  BFMMLA        01100100 01 1 ..... 111 001 ..... .....  @rda_rn_rm_e0
+  FMMLA         01100100 .. 1 ..... 111 001 ..... .....  @rda_rn_rm
+}
 
 ### SVE2 Memory Gather Load Group
 
index 71de75e568babab91963dca05ea6e65e4d423c5b..9ce2f5a7d435724170538d2612b4a5f695e0186d 100644 (file)
@@ -12235,6 +12235,13 @@ static void disas_simd_three_reg_same_extra(DisasContext *s, uint32_t insn)
         }
         feature = dc_isar_feature(aa64_fcma, s);
         break;
+    case 0x1d: /* BFMMLA */
+        if (size != MO_16 || !is_q) {
+            unallocated_encoding(s);
+            return;
+        }
+        feature = dc_isar_feature(aa64_bf16, s);
+        break;
     case 0x1f: /* BFDOT */
         switch (size) {
         case 1:
@@ -12328,6 +12335,9 @@ static void disas_simd_three_reg_same_extra(DisasContext *s, uint32_t insn)
         }
         return;
 
+    case 0xd: /* BFMMLA */
+        gen_gvec_op4_ool(s, is_q, rd, rn, rm, rd, 0, gen_helper_gvec_bfmmla);
+        return;
     case 0xf: /* BFDOT */
         switch (size) {
         case 1:
index 8099767792b68348ac15e5b2cea320470480decc..9d227a1e13dc827a30d438e8ad3e068f48bd2128 100644 (file)
@@ -4126,3 +4126,12 @@ static bool trans_VUSMMLA(DisasContext *s, arg_VUSMMLA *a)
     return do_neon_ddda(s, 7, a->vd, a->vn, a->vm, 0,
                         gen_helper_gvec_usmmla_b);
 }
+
+static bool trans_VMMLA_b16(DisasContext *s, arg_VMMLA_b16 *a)
+{
+    if (!dc_isar_feature(aa32_bf16, s)) {
+        return false;
+    }
+    return do_neon_ddda(s, 7, a->vd, a->vn, a->vm, 0,
+                        gen_helper_gvec_bfmmla);
+}
index 6f02030635734b14e44362380085c30c2da3ed8b..4f575dc3343a1c37b0ea7288ebf58f8323161522 100644 (file)
@@ -8677,3 +8677,15 @@ static bool trans_BFDOT_zzxz(DisasContext *s, arg_rrxr_esz *a)
     }
     return true;
 }
+
+static bool trans_BFMMLA(DisasContext *s, arg_rrrr_esz *a)
+{
+    if (!dc_isar_feature(aa64_sve_bf16, s)) {
+        return false;
+    }
+    if (sve_access_check(s)) {
+        gen_gvec_ool_zzzz(s, gen_helper_gvec_bfmmla,
+                          a->rd, a->rn, a->rm, a->ra, 0);
+    }
+    return true;
+}
index 74a497f38ca219d22ab5df88f26c8c51be5e954a..27e9bdd3299993c1a4fdef6b53da14cdff798e1a 100644 (file)
@@ -2385,7 +2385,7 @@ static void do_mmla_b(void *vd, void *vn, void *vm, void *va, uint32_t desc,
          * Process the entire segment at once, writing back the
          * results only after we've consumed all of the inputs.
          *
-         * Key to indicies by column:
+         * Key to indices by column:
          *          i   j                  i             j
          */
         sum0 = a[H4(0 + 0)];
@@ -2472,3 +2472,43 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }
+
+void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va, uint32_t desc)
+{
+    intptr_t s, opr_sz = simd_oprsz(desc);
+    float32 *d = vd, *a = va;
+    uint32_t *n = vn, *m = vm;
+
+    for (s = 0; s < opr_sz / 4; s += 4) {
+        float32 sum00, sum01, sum10, sum11;
+
+        /*
+         * Process the entire segment at once, writing back the
+         * results only after we've consumed all of the inputs.
+         *
+         * Key to indicies by column:
+         *               i   j           i   k             j   k
+         */
+        sum00 = a[s + H4(0 + 0)];
+        sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
+        sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);
+
+        sum01 = a[s + H4(0 + 1)];
+        sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
+        sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
+
+        sum10 = a[s + H4(2 + 0)];
+        sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
+        sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
+
+        sum11 = a[s + H4(2 + 1)];
+        sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
+        sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
+
+        d[s + H4(0 + 0)] = sum00;
+        d[s + H4(0 + 1)] = sum01;
+        d[s + H4(2 + 0)] = sum10;
+        d[s + H4(2 + 1)] = sum11;
+    }
+    clear_tail(d, opr_sz, simd_maxsz(desc));
+}