CMSIS-NN: Add MVE support to int16 for fully connected (#1549)

diff --git a/ARM.CMSIS.pdsc b/ARM.CMSIS.pdsc
index a43c1ee..b8c6759 100644
--- a/ARM.CMSIS.pdsc
+++ b/ARM.CMSIS.pdsc
@@ -17,6 +17,7 @@
        - Support for DSP extension optimization for int16 depthwise_conv
        - Support for MVEI extension optimization for int16 depthwise_conv
        - Support for MVEI extension optimization for int16 max and average pooling
+       - Support for MVEI extension optimization for int16 fully connected
     </release>
     <release version="5.9.0" date="2022-05-02">
       CMSIS-Core(M): 5.6.0
diff --git a/CMSIS/DoxyGen/NN/src/history.txt b/CMSIS/DoxyGen/NN/src/history.txt
index 5e1a38e..988397f 100644
--- a/CMSIS/DoxyGen/NN/src/history.txt
+++ b/CMSIS/DoxyGen/NN/src/history.txt
@@ -15,6 +15,7 @@
       <li> Added support for DSP extension optimization for int16 depthwise_conv </li>
       <li> Added support for MVEI extension optimization for int16 depthwise_conv </li>
       <li> Added support for MVEI extension optimization for int16 max & average pooling </li>
+      <li> Added support for MVEI extension optimization for int16 fully connected </li>
       </ul>
     </td>
   </tr>
diff --git a/CMSIS/NN/README.md b/CMSIS/NN/README.md
index 900094c..0d683dc 100644
--- a/CMSIS/NN/README.md
+++ b/CMSIS/NN/README.md
@@ -41,7 +41,7 @@
 ||arm_depthwise_conv_fast_s16() | DEPTHWISE_CONV | Yes | Yes. Refer to API for details | Yes | Yes ||
 |[Fully Connected](https://arm-software.github.io/CMSIS_5/NN/html/group__FC.html)||||| |  | |
 ||arm_fully_connected_s8() |FULLY CONNECTED & <br/> MAT MUL  | None | No | Yes | Yes | |
-||arm_fully_connected_s16() |FULLY CONNECTED & <br/> MAT MUL  | None | No | Yes | No | |
+||arm_fully_connected_s16() |FULLY CONNECTED & <br/> MAT MUL  | None | No | Yes | Yes | |
 |[Pooling](https://arm-software.github.io/CMSIS_5/NN/html/group__Pooling.html)||||| |  ||
 || arm_avgpool_s8() | AVERAGE POOL | None | input_ch * 4<br/>(DSP only) | Yes| Yes| Best case is when channels are multiple of 4 or <br/> at the least >= 4 |
 || arm_avgpool_s16() | AVERAGE POOL | None | input_ch * 4<br/>(DSP only) | Yes| Yes| Best case is when channels are multiple of 4 or <br/> at the least >= 4 |
diff --git a/CMSIS/NN/Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_s16.c b/CMSIS/NN/Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_s16.c
index 4ac9e45..64f61d1 100644
--- a/CMSIS/NN/Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_s16.c
+++ b/CMSIS/NN/Source/NNSupportFunctions/arm_nn_vec_mat_mult_t_s16.c
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2020-2022 Arm Limited or its affiliates.
+ * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,14 +21,16 @@
  * Title:        arm_nn_vec_mat_mult_t_s16
  * Description:  s16 vector by matrix (transposed) multiplication
  *
- * $Date:        19 April 2022
- * $Revision:    V.2.0.0
+ * $Date:        11 August 2022
+ * $Revision:    V.2.1.0
  *
  * Target Processor:  Cortex-M
  *
  * -------------------------------------------------------------------- */
 
 #include "arm_nnsupportfunctions.h"
+#define MAX_COL_COUNT (512)
+
 /**
  * @ingroup groupSupport
  */
@@ -55,18 +57,170 @@
                                               const int32_t activation_min,
                                               const int32_t activation_max)
 {
-#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
-    const int32_t row_loop_cnt = rhs_rows / 2;
+#if defined(ARM_MATH_DSP)
 
     int32_t rhs_cols_fast = rhs_cols;
 
-    if (rhs_cols > 512)
+    if (rhs_cols > MAX_COL_COUNT)
     {
-        rhs_cols_fast = 512;
+        rhs_cols_fast = MAX_COL_COUNT;
     }
 
+#if defined(ARM_MATH_MVEI)
+    int32_t row_loop_cnt = rhs_rows / 4;
+    int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;
+
+    for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
+    {
+        int32_t col_cnt = rhs_cols_fast;
+
+        const int16_t *lhs_ptr = lhs;
+        const int8_t *rhs_ptr_0 = rhs;
+        const int8_t *rhs_ptr_1 = rhs + rhs_cols;
+        const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
+        const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;
+
+        int32_t result_0 = 0;
+        int32_t result_1 = 0;
+        int32_t result_2 = 0;
+        int32_t result_3 = 0;
+
+        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
+        {
+            mve_pred16_t pred = vctp16q(col_cnt);
+            col_cnt -= 8;
+
+            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
+
+            int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
+            int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
+            int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
+            int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);
+
+            result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
+            result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
+            result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
+            result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);
+
+            lhs_ptr += 8;
+
+            rhs_ptr_0 += 8;
+            rhs_ptr_1 += 8;
+            rhs_ptr_2 += 8;
+            rhs_ptr_3 += 8;
+        }
+
+        int64_t result_64_0 = result_0;
+        int64_t result_64_1 = result_1;
+        int64_t result_64_2 = result_2;
+        int64_t result_64_3 = result_3;
+
+        if (rhs_cols > MAX_COL_COUNT)
+        {
+            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
+            {
+                const int16_t lhs_temp = *lhs_ptr++;
+
+                result_64_0 += *rhs_ptr_0++ * lhs_temp;
+                result_64_1 += *rhs_ptr_1++ * lhs_temp;
+                result_64_2 += *rhs_ptr_2++ * lhs_temp;
+                result_64_3 += *rhs_ptr_3++ * lhs_temp;
+            }
+        }
+
+        if (bias)
+        {
+            result_64_0 += *bias++;
+            result_64_1 += *bias++;
+            result_64_2 += *bias++;
+            result_64_3 += *bias++;
+        }
+
+        int32_t tmp;
+        tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
+        tmp = MAX(tmp, activation_min);
+        tmp = MIN(tmp, activation_max);
+        *dst++ = (q15_t)tmp;
+
+        tmp = 0;
+        tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
+        tmp = MAX(tmp, activation_min);
+        tmp = MIN(tmp, activation_max);
+        *dst++ = (q15_t)tmp;
+
+        tmp = 0;
+        tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
+        tmp = MAX(tmp, activation_min);
+        tmp = MIN(tmp, activation_max);
+        *dst++ = (q15_t)tmp;
+
+        tmp = 0;
+        tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
+        tmp = MAX(tmp, activation_min);
+        tmp = MIN(tmp, activation_max);
+        *dst++ = (q15_t)tmp;
+
+        rhs += 4 * rhs_cols;
+    }
+
+    for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
+    {
+        int32_t result = 0;
+
+        col_loop_cnt = (rhs_cols_fast + 7) / 8;
+
+        const int16_t *lhs_ptr = lhs;
+        const int8_t *rhs_ptr = rhs;
+
+        int32_t col_cnt = (int32_t)rhs_cols_fast;
+
+        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
+        {
+            mve_pred16_t pred = vctp16q(col_cnt);
+            col_cnt -= 8;
+
+            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
+            int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);
+
+            result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);
+
+            lhs_ptr += 8;
+            rhs_ptr += 8;
+        }
+
+        int64_t result_64 = result;
+
+        if (bias)
+        {
+            result_64 += *bias++;
+        }
+
+        if (rhs_cols > MAX_COL_COUNT)
+        {
+            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
+            {
+                const int16_t lhs_temp = *lhs_ptr++;
+
+                result_64 += *rhs_ptr++ * lhs_temp;
+            }
+        }
+
+        int32_t tmp = 0;
+        tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
+        tmp = MAX(tmp, activation_min);
+        tmp = MIN(tmp, activation_max);
+        *dst++ = (q15_t)tmp;
+
+        rhs += rhs_cols;
+    }
+
+#else // ARM_MATH_MVEI
+
+    const int32_t row_loop_cnt = rhs_rows / 2;
+
     for (int32_t i = 0; i < row_loop_cnt; i++)
     {
+
         q63_t acc_64_0 = 0;
         q63_t acc_64_1 = 0;
         int32_t acc_0 = 0;
@@ -82,6 +236,7 @@
         for (int j = col_loop_cnt; j != 0; j--)
         {
             int32_t ker_0, ker_1, vec_part_0, vec_part_1;
+
             vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
             vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);
 
@@ -115,6 +270,7 @@
             acc_64_1 += *bias++;
         }
         q31_t tmp;
+
         tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
         tmp = MAX(tmp, activation_min);
         tmp = MIN(tmp, activation_max);
@@ -168,7 +324,8 @@
         *dst++ = (q15_t)tmp;
     }
 
-#else
+#endif // ARM_MATH_MVEI
+#else  // ARM_MATH_DSP
     for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
     {
         const q15_t *lhs_ptr = lhs;
@@ -176,10 +333,6 @@
 
         q63_t result = 0;
 
-        if (bias)
-        {
-            result = *bias++;
-        }
         for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
         {
             const q63_t rhs_value0 = (int8_t)*rhs_ptr_0;
@@ -191,6 +344,10 @@
             ++lhs_ptr;
         }
 
+        if (bias)
+        {
+            result += *bias++;
+        }
         // Quantize down
         result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
 
@@ -201,7 +358,7 @@
         *dst++ = (q15_t)result;
         rhs += rhs_cols;
     }
-#endif
+#endif // ARM_MATH_DSP
 
     return ARM_CMSIS_NN_SUCCESS;
 }