Update Linux to v5.4.2

Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/arch/mips/net/ebpf_jit.c b/arch/mips/net/ebpf_jit.c
index aeb7b1b..46b7675 100644
--- a/arch/mips/net/ebpf_jit.c
+++ b/arch/mips/net/ebpf_jit.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * Just-In-Time compiler for eBPF filters on MIPS
  *
@@ -7,10 +8,6 @@
  *
  * Copyright (c) 2014 Imagination Technologies Ltd.
  * Author: Markos Chandras <markos.chandras@imgtec.com>
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; version 2 of the License.
  */
 
 #include <linux/bitops.h>
@@ -22,6 +19,7 @@
 #include <asm/byteorder.h>
 #include <asm/cacheflush.h>
 #include <asm/cpu-features.h>
+#include <asm/isa-rev.h>
 #include <asm/uasm.h>
 
 /* Registers used by JIT */
@@ -79,8 +77,6 @@
 	REG_64BIT_32BIT,
 	/* 32-bit compatible, need truncation for 64-bit ops. */
 	REG_32BIT,
-	/* 32-bit zero extended. */
-	REG_32BIT_ZERO_EX,
 	/* 32-bit no sign/zero extension needed. */
 	REG_32BIT_POS
 };
@@ -127,15 +123,21 @@
 }
 
 /* Simply emit the instruction if the JIT memory space has been allocated */
-#define emit_instr(ctx, func, ...)			\
-do {							\
-	if ((ctx)->target != NULL) {			\
-		u32 *p = &(ctx)->target[ctx->idx];	\
-		uasm_i_##func(&p, ##__VA_ARGS__);	\
-	}						\
-	(ctx)->idx++;					\
+#define emit_instr_long(ctx, func64, func32, ...)		\
+do {	/* emit func64 on CONFIG_64BIT kernels, else func32 */	\
+	if ((ctx)->target != NULL) {				\
+		u32 *p = &(ctx)->target[(ctx)->idx];		\
+		if (IS_ENABLED(CONFIG_64BIT))			\
+			uasm_i_##func64(&p, ##__VA_ARGS__);	\
+		else						\
+			uasm_i_##func32(&p, ##__VA_ARGS__);	\
+	}							\
+	(ctx)->idx++;	/* always advance the index */		\
 } while (0)
 
+#define emit_instr(ctx, func, ...)				\
+	emit_instr_long(ctx, func, func, ##__VA_ARGS__) /* same op for both widths */
+
 static unsigned int j_target(struct jit_ctx *ctx, int target_idx)
 {
 	unsigned long target_va, base_va;
@@ -188,8 +190,9 @@
  * separate frame pointer, so BPF_REG_10 relative accesses are
  * adjusted to be $sp relative.
  */
-int ebpf_to_mips_reg(struct jit_ctx *ctx, const struct bpf_insn *insn,
-		     enum which_ebpf_reg w)
+static int ebpf_to_mips_reg(struct jit_ctx *ctx,
+			    const struct bpf_insn *insn,
+			    enum which_ebpf_reg w)
 {
 	int ebpf_reg = (w == src_reg || w == src_reg_no_fp) ?
 		insn->src_reg : insn->dst_reg;
@@ -275,17 +278,17 @@
 		 * If RA we are doing a function call and may need
 		 * extra 8-byte tmp area.
 		 */
-		stack_adjust += 16;
+		stack_adjust += 2 * sizeof(long);
 	if (ctx->flags & EBPF_SAVE_S0)
-		stack_adjust += 8;
+		stack_adjust += sizeof(long);
 	if (ctx->flags & EBPF_SAVE_S1)
-		stack_adjust += 8;
+		stack_adjust += sizeof(long);
 	if (ctx->flags & EBPF_SAVE_S2)
-		stack_adjust += 8;
+		stack_adjust += sizeof(long);
 	if (ctx->flags & EBPF_SAVE_S3)
-		stack_adjust += 8;
+		stack_adjust += sizeof(long);
 	if (ctx->flags & EBPF_SAVE_S4)
-		stack_adjust += 8;
+		stack_adjust += sizeof(long);
 
 	BUILD_BUG_ON(MAX_BPF_STACK & 7);
 	locals_size = (ctx->flags & EBPF_SEEN_FP) ? MAX_BPF_STACK : 0;
@@ -299,41 +302,49 @@
 	 * On tail call we skip this instruction, and the TCC is
 	 * passed in $v1 from the caller.
 	 */
-	emit_instr(ctx, daddiu, MIPS_R_V1, MIPS_R_ZERO, MAX_TAIL_CALL_CNT);
+	emit_instr(ctx, addiu, MIPS_R_V1, MIPS_R_ZERO, MAX_TAIL_CALL_CNT);
 	if (stack_adjust)
-		emit_instr(ctx, daddiu, MIPS_R_SP, MIPS_R_SP, -stack_adjust);
+		emit_instr_long(ctx, daddiu, addiu,
+					MIPS_R_SP, MIPS_R_SP, -stack_adjust);
 	else
 		return 0;
 
-	store_offset = stack_adjust - 8;
+	store_offset = stack_adjust - sizeof(long);
 
 	if (ctx->flags & EBPF_SAVE_RA) {
-		emit_instr(ctx, sd, MIPS_R_RA, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_RA, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S0) {
-		emit_instr(ctx, sd, MIPS_R_S0, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_S0, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S1) {
-		emit_instr(ctx, sd, MIPS_R_S1, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_S1, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S2) {
-		emit_instr(ctx, sd, MIPS_R_S2, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_S2, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S3) {
-		emit_instr(ctx, sd, MIPS_R_S3, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_S3, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S4) {
-		emit_instr(ctx, sd, MIPS_R_S4, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, sd, sw,
+					MIPS_R_S4, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 
 	if ((ctx->flags & EBPF_SEEN_TC) && !(ctx->flags & EBPF_TCC_IN_V1))
-		emit_instr(ctx, daddu, MIPS_R_S4, MIPS_R_V1, MIPS_R_ZERO);
+		emit_instr_long(ctx, daddu, addu,
+					MIPS_R_S4, MIPS_R_V1, MIPS_R_ZERO);
 
 	return 0;
 }
@@ -342,42 +353,52 @@
 {
 	const struct bpf_prog *prog = ctx->skf;
 	int stack_adjust = ctx->stack_size;
-	int store_offset = stack_adjust - 8;
+	int store_offset = stack_adjust - sizeof(long);
+	enum reg_val_type td;
 	int r0 = MIPS_R_V0;
 
-	if (dest_reg == MIPS_R_RA &&
-	    get_reg_val_type(ctx, prog->len, BPF_REG_0) == REG_32BIT_ZERO_EX)
+	if (dest_reg == MIPS_R_RA) {
 		/* Don't let zero extended value escape. */
-		emit_instr(ctx, sll, r0, r0, 0);
+		td = get_reg_val_type(ctx, prog->len, BPF_REG_0);
+		if (td == REG_64BIT)
+			emit_instr(ctx, sll, r0, r0, 0);
+	}
 
 	if (ctx->flags & EBPF_SAVE_RA) {
-		emit_instr(ctx, ld, MIPS_R_RA, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+					MIPS_R_RA, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S0) {
-		emit_instr(ctx, ld, MIPS_R_S0, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+					MIPS_R_S0, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S1) {
-		emit_instr(ctx, ld, MIPS_R_S1, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+					MIPS_R_S1, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S2) {
-		emit_instr(ctx, ld, MIPS_R_S2, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+				MIPS_R_S2, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S3) {
-		emit_instr(ctx, ld, MIPS_R_S3, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+					MIPS_R_S3, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	if (ctx->flags & EBPF_SAVE_S4) {
-		emit_instr(ctx, ld, MIPS_R_S4, store_offset, MIPS_R_SP);
-		store_offset -= 8;
+		emit_instr_long(ctx, ld, lw,
+					MIPS_R_S4, store_offset, MIPS_R_SP);
+		store_offset -= sizeof(long);
 	}
 	emit_instr(ctx, jr, dest_reg);
 
 	if (stack_adjust)
-		emit_instr(ctx, daddiu, MIPS_R_SP, MIPS_R_SP, stack_adjust);
+		emit_instr_long(ctx, daddiu, addiu,
+					MIPS_R_SP, MIPS_R_SP, stack_adjust);
 	else
 		emit_instr(ctx, nop);
 
@@ -644,6 +665,10 @@
 	s64 t64s;
 	int bpf_op = BPF_OP(insn->code);
 
+	if (IS_ENABLED(CONFIG_32BIT) && ((BPF_CLASS(insn->code) == BPF_ALU64)
+						|| (bpf_op == BPF_DW)))
+		return -EINVAL;
+
 	switch (insn->code) {
 	case BPF_ALU64 | BPF_ADD | BPF_K: /* ALU64_IMM */
 	case BPF_ALU64 | BPF_SUB | BPF_K: /* ALU64_IMM */
@@ -676,8 +701,12 @@
 		if (insn->imm == 1) /* Mult by 1 is a nop */
 			break;
 		gen_imm_to_reg(insn, MIPS_R_AT, ctx);
-		emit_instr(ctx, dmultu, MIPS_R_AT, dst);
-		emit_instr(ctx, mflo, dst);
+		if (MIPS_ISA_REV >= 6) {
+			emit_instr(ctx, dmulu, dst, dst, MIPS_R_AT);
+		} else {
+			emit_instr(ctx, dmultu, MIPS_R_AT, dst);
+			emit_instr(ctx, mflo, dst);
+		}
 		break;
 	case BPF_ALU64 | BPF_NEG | BPF_K: /* ALU64_IMM */
 		dst = ebpf_to_mips_reg(ctx, insn, dst_reg);
@@ -692,22 +721,26 @@
 		if (dst < 0)
 			return dst;
 		td = get_reg_val_type(ctx, this_idx, insn->dst_reg);
-		if (td == REG_64BIT || td == REG_32BIT_ZERO_EX) {
+		if (td == REG_64BIT) {
 			/* sign extend */
 			emit_instr(ctx, sll, dst, dst, 0);
 		}
 		if (insn->imm == 1) /* Mult by 1 is a nop */
 			break;
 		gen_imm_to_reg(insn, MIPS_R_AT, ctx);
-		emit_instr(ctx, multu, dst, MIPS_R_AT);
-		emit_instr(ctx, mflo, dst);
+		if (MIPS_ISA_REV >= 6) {
+			emit_instr(ctx, mulu, dst, dst, MIPS_R_AT);
+		} else {
+			emit_instr(ctx, multu, dst, MIPS_R_AT);
+			emit_instr(ctx, mflo, dst);
+		}
 		break;
 	case BPF_ALU | BPF_NEG | BPF_K: /* ALU_IMM */
 		dst = ebpf_to_mips_reg(ctx, insn, dst_reg);
 		if (dst < 0)
 			return dst;
 		td = get_reg_val_type(ctx, this_idx, insn->dst_reg);
-		if (td == REG_64BIT || td == REG_32BIT_ZERO_EX) {
+		if (td == REG_64BIT) {
 			/* sign extend */
 			emit_instr(ctx, sll, dst, dst, 0);
 		}
@@ -721,7 +754,7 @@
 		if (dst < 0)
 			return dst;
 		td = get_reg_val_type(ctx, this_idx, insn->dst_reg);
-		if (td == REG_64BIT || td == REG_32BIT_ZERO_EX)
+		if (td == REG_64BIT)
 			/* sign extend */
 			emit_instr(ctx, sll, dst, dst, 0);
 		if (insn->imm == 1) {
@@ -731,6 +764,13 @@
 			break;
 		}
 		gen_imm_to_reg(insn, MIPS_R_AT, ctx);
+		if (MIPS_ISA_REV >= 6) {
+			if (bpf_op == BPF_DIV)
+				emit_instr(ctx, divu_r6, dst, dst, MIPS_R_AT);
+			else
+				emit_instr(ctx, modu, dst, dst, MIPS_R_AT);
+			break;
+		}
 		emit_instr(ctx, divu, dst, MIPS_R_AT);
 		if (bpf_op == BPF_DIV)
 			emit_instr(ctx, mflo, dst);
@@ -753,6 +793,13 @@
 			break;
 		}
 		gen_imm_to_reg(insn, MIPS_R_AT, ctx);
+		if (MIPS_ISA_REV >= 6) {
+			if (bpf_op == BPF_DIV)
+				emit_instr(ctx, ddivu_r6, dst, dst, MIPS_R_AT);
+			else	/* 64-bit remainder: dmodu, not 32-bit modu */
+				emit_instr(ctx, dmodu, dst, dst, MIPS_R_AT);
+			break;
+		}
 		emit_instr(ctx, ddivu, dst, MIPS_R_AT);
 		if (bpf_op == BPF_DIV)
 			emit_instr(ctx, mflo, dst);
@@ -818,11 +865,23 @@
 			emit_instr(ctx, and, dst, dst, src);
 			break;
 		case BPF_MUL:
-			emit_instr(ctx, dmultu, dst, src);
-			emit_instr(ctx, mflo, dst);
+			if (MIPS_ISA_REV >= 6) {
+				emit_instr(ctx, dmulu, dst, dst, src);
+			} else {
+				emit_instr(ctx, dmultu, dst, src);
+				emit_instr(ctx, mflo, dst);
+			}
 			break;
 		case BPF_DIV:
 		case BPF_MOD:
+			if (MIPS_ISA_REV >= 6) {
+				if (bpf_op == BPF_DIV)
+					emit_instr(ctx, ddivu_r6,
+							dst, dst, src);
+				else	/* 64-bit remainder, so dmodu */
+					emit_instr(ctx, dmodu, dst, dst, src);
+				break;
+			}
 			emit_instr(ctx, ddivu, dst, src);
 			if (bpf_op == BPF_DIV)
 				emit_instr(ctx, mflo, dst);
@@ -854,18 +913,19 @@
 	case BPF_ALU | BPF_MOD | BPF_X: /* ALU_REG */
 	case BPF_ALU | BPF_LSH | BPF_X: /* ALU_REG */
 	case BPF_ALU | BPF_RSH | BPF_X: /* ALU_REG */
+	case BPF_ALU | BPF_ARSH | BPF_X: /* ALU_REG */
 		src = ebpf_to_mips_reg(ctx, insn, src_reg_no_fp);
 		dst = ebpf_to_mips_reg(ctx, insn, dst_reg);
 		if (src < 0 || dst < 0)
 			return -EINVAL;
 		td = get_reg_val_type(ctx, this_idx, insn->dst_reg);
-		if (td == REG_64BIT || td == REG_32BIT_ZERO_EX) {
+		if (td == REG_64BIT) {
 			/* sign extend */
 			emit_instr(ctx, sll, dst, dst, 0);
 		}
 		did_move = false;
 		ts = get_reg_val_type(ctx, this_idx, insn->src_reg);
-		if (ts == REG_64BIT || ts == REG_32BIT_ZERO_EX) {
+		if (ts == REG_64BIT) {
 			int tmp_reg = MIPS_R_AT;
 
 			if (bpf_op == BPF_MOV) {
@@ -901,6 +961,13 @@
 			break;
 		case BPF_DIV:
 		case BPF_MOD:
+			if (MIPS_ISA_REV >= 6) {
+				if (bpf_op == BPF_DIV)
+					emit_instr(ctx, divu_r6, dst, dst, src);
+				else
+					emit_instr(ctx, modu, dst, dst, src);
+				break;
+			}
 			emit_instr(ctx, divu, dst, src);
 			if (bpf_op == BPF_DIV)
 				emit_instr(ctx, mflo, dst);
@@ -913,6 +980,9 @@
 		case BPF_RSH:
 			emit_instr(ctx, srlv, dst, dst, src);
 			break;
+		case BPF_ARSH:
+			emit_instr(ctx, srav, dst, dst, src);
+			break;
 		default:
 			pr_err("ALU_REG NOT HANDLED\n");
 			return -EINVAL;
@@ -1001,8 +1071,15 @@
 			emit_instr(ctx, dsubu, MIPS_R_T8, dst, src);
 			emit_instr(ctx, sltu, MIPS_R_AT, dst, src);
 			/* SP known to be non-zero, movz becomes boolean not */
-			emit_instr(ctx, movz, MIPS_R_T9, MIPS_R_SP, MIPS_R_T8);
-			emit_instr(ctx, movn, MIPS_R_T9, MIPS_R_ZERO, MIPS_R_T8);
+			if (MIPS_ISA_REV >= 6) {
+				emit_instr(ctx, seleqz, MIPS_R_T9,
+						MIPS_R_SP, MIPS_R_T8);
+			} else {
+				emit_instr(ctx, movz, MIPS_R_T9,
+						MIPS_R_SP, MIPS_R_T8);
+				emit_instr(ctx, movn, MIPS_R_T9,
+						MIPS_R_ZERO, MIPS_R_T8);
+			}
 			emit_instr(ctx, or, MIPS_R_AT, MIPS_R_T9, MIPS_R_AT);
 			cmp_eq = bpf_op == BPF_JGT;
 			dst = MIPS_R_AT;
@@ -1229,7 +1306,7 @@
 
 	case BPF_JMP | BPF_CALL:
 		ctx->flags |= EBPF_SAVE_RA;
-		t64s = (s64)insn->imm + (s64)__bpf_call_base;
+		t64s = (s64)insn->imm + (long)__bpf_call_base;
 		emit_const_to_reg(ctx, MIPS_R_T9, (u64)t64s);
 		emit_instr(ctx, jalr, MIPS_R_RA, MIPS_R_T9);
 		/* delay slot */
@@ -1250,8 +1327,7 @@
 		if (insn->imm == 64 && td == REG_32BIT)
 			emit_instr(ctx, dinsu, dst, MIPS_R_ZERO, 32, 32);
 
-		if (insn->imm != 64 &&
-		    (td == REG_64BIT || td == REG_32BIT_ZERO_EX)) {
+		if (insn->imm != 64 && td == REG_64BIT) {
 			/* sign extend */
 			emit_instr(ctx, sll, dst, dst, 0);
 		}
@@ -1362,6 +1438,17 @@
 		if (src < 0)
 			return src;
 		if (BPF_MODE(insn->code) == BPF_XADD) {
+			/*
+			 * Offsets outside the signed 9 bit ll/sc immediate go
+			 * via a temp reg; -BIT(8) is unsigned, so use literals.
+			 */
+			if (MIPS_ISA_REV >= 6 &&
+			    (mem_off >= 256 || mem_off < -256)) {
+				emit_instr(ctx, daddiu, MIPS_R_T6,
+						dst, mem_off);
+				mem_off = 0;
+				dst = MIPS_R_T6;
+			}
 			switch (BPF_SIZE(insn->code)) {
 			case BPF_W:
 				if (get_reg_val_type(ctx, this_idx, insn->src_reg) == REG_32BIT) {
@@ -1716,7 +1803,7 @@
 	unsigned int image_size;
 	u8 *image_ptr;
 
-	if (!prog->jit_requested || !cpu_has_mips64r2)
+	if (!prog->jit_requested || MIPS_ISA_REV < 2)
 		return prog;
 
 	tmp = bpf_jit_blind_constants(prog);
@@ -1815,7 +1902,7 @@
 
 	/* Update the icache */
 	flush_icache_range((unsigned long)ctx.target,
-			   (unsigned long)(ctx.target + ctx.idx * sizeof(u32)));
+			   (unsigned long)&ctx.target[ctx.idx]);
 
 	if (bpf_jit_enable > 1)
 		/* Dump JIT code */