Update Linux to v5.10.109

Sourced from [1]

[1] https://cdn.kernel.org/pub/linux/kernel/v5.x/linux-5.10.109.tar.xz

Change-Id: I19bca9fc6762d4e63bcf3e4cba88bbe560d9c76c
Signed-off-by: Olivier Deprez <olivier.deprez@arm.com>
diff --git a/tools/testing/selftests/vm/.gitignore b/tools/testing/selftests/vm/.gitignore
index 31b3c98..849e822 100644
--- a/tools/testing/selftests/vm/.gitignore
+++ b/tools/testing/selftests/vm/.gitignore
@@ -1,12 +1,16 @@
+# SPDX-License-Identifier: GPL-2.0-only
 hugepage-mmap
 hugepage-shm
+khugepaged
 map_hugetlb
 map_populate
 thuge-gen
 compaction_test
 mlock2-tests
+mremap_dontunmap
 on-fault-limit
 transhuge-stress
+protection_keys
 userfaultfd
 mlock-intersect-test
 mlock-random-test
@@ -14,3 +18,5 @@
 gup_benchmark
 va_128TBswitch
 map_fixed_noreplace
+write_to_hugetlbfs
+hmm-tests
diff --git a/tools/testing/selftests/vm/Makefile b/tools/testing/selftests/vm/Makefile
index 9534dc2..2cf32e6 100644
--- a/tools/testing/selftests/vm/Makefile
+++ b/tools/testing/selftests/vm/Makefile
@@ -1,10 +1,30 @@
 # SPDX-License-Identifier: GPL-2.0
 # Makefile for vm selftests
+uname_M := $(shell uname -m 2>/dev/null || echo not)
+MACHINE ?= $(shell echo $(uname_M) | sed -e 's/aarch64.*/arm64/' -e 's/ppc64.*/ppc64/')
+
+# Without this, failed build products remain, with up-to-date timestamps,
+# thus tricking Make (and you!) into believing that All Is Well, in subsequent
+# make invocations:
+.DELETE_ON_ERROR:
+
+# Avoid accidental wrong builds, due to built-in rules working just a little
+# bit too well--but not quite as well as required for our situation here.
+#
+# In other words, "make userfaultfd" is supposed to fail to build at all,
+# because this Makefile only supports either "make" (all), or "make /full/path".
+# However, the built-in rules, if not suppressed, will pick up CFLAGS and the
+# initial LDLIBS (but not the target-specific LDLIBS, because those are only
+# set for the full path target!). This causes it to get pretty far into building
+# things despite using incorrect values such as an *occasionally* incomplete
+# LDLIBS.
+MAKEFLAGS += --no-builtin-rules
 
 CFLAGS = -Wall -I ../../../../usr/include $(EXTRA_CFLAGS)
 LDLIBS = -lrt
 TEST_GEN_FILES = compaction_test
 TEST_GEN_FILES += gup_benchmark
+TEST_GEN_FILES += hmm-tests
 TEST_GEN_FILES += hugepage-mmap
 TEST_GEN_FILES += hugepage-shm
 TEST_GEN_FILES += map_hugetlb
@@ -12,12 +32,46 @@
 TEST_GEN_FILES += map_populate
 TEST_GEN_FILES += mlock-random-test
 TEST_GEN_FILES += mlock2-tests
+TEST_GEN_FILES += mremap_dontunmap
 TEST_GEN_FILES += on-fault-limit
 TEST_GEN_FILES += thuge-gen
 TEST_GEN_FILES += transhuge-stress
 TEST_GEN_FILES += userfaultfd
+TEST_GEN_FILES += khugepaged
+
+ifeq ($(MACHINE),x86_64)
+CAN_BUILD_I386 := $(shell ./../x86/check_cc.sh $(CC) ../x86/trivial_32bit_program.c -m32)
+CAN_BUILD_X86_64 := $(shell ./../x86/check_cc.sh $(CC) ../x86/trivial_64bit_program.c)
+CAN_BUILD_WITH_NOPIE := $(shell ./../x86/check_cc.sh $(CC) ../x86/trivial_program.c -no-pie)
+
+TARGETS := protection_keys
+BINARIES_32 := $(TARGETS:%=%_32)
+BINARIES_64 := $(TARGETS:%=%_64)
+
+ifeq ($(CAN_BUILD_WITH_NOPIE),1)
+CFLAGS += -no-pie
+endif
+
+ifeq ($(CAN_BUILD_I386),1)
+TEST_GEN_FILES += $(BINARIES_32)
+endif
+
+ifeq ($(CAN_BUILD_X86_64),1)
+TEST_GEN_FILES += $(BINARIES_64)
+endif
+else
+# Exact word match: findstring would also accept substrings of ppc64 (e.g. ppc).
+ifneq (,$(filter $(MACHINE),ppc64))
+TEST_GEN_FILES += protection_keys
+endif
+
+endif
+
+ifneq (,$(filter $(MACHINE),arm64 ia64 mips64 parisc64 ppc64 riscv64 s390x sh64 sparc64 x86_64))
 TEST_GEN_FILES += va_128TBswitch
 TEST_GEN_FILES += virtual_address_range
+TEST_GEN_FILES += write_to_hugetlbfs
+endif
 
 TEST_PROGS := run_vmtests
 
@@ -26,6 +80,57 @@
 KSFT_KHDR_INSTALL := 1
 include ../lib.mk
 
+$(OUTPUT)/hmm-tests: LDLIBS += -lhugetlbfs -lpthread
+
+ifeq ($(MACHINE),x86_64)
+BINARIES_32 := $(patsubst %,$(OUTPUT)/%,$(BINARIES_32))
+BINARIES_64 := $(patsubst %,$(OUTPUT)/%,$(BINARIES_64))
+
+define gen-target-rule-32
+$(1) $(1)_32: $(OUTPUT)/$(1)_32
+.PHONY: $(1) $(1)_32
+endef
+
+define gen-target-rule-64
+$(1) $(1)_64: $(OUTPUT)/$(1)_64
+.PHONY: $(1) $(1)_64
+endef
+
+ifeq ($(CAN_BUILD_I386),1)
+$(BINARIES_32): CFLAGS += -m32
+$(BINARIES_32): LDLIBS += -lrt -ldl -lm
+$(BINARIES_32): $(OUTPUT)/%_32: %.c
+	$(CC) $(CFLAGS) $(EXTRA_CFLAGS) $(notdir $^) $(LDLIBS) -o $@
+$(foreach t,$(TARGETS),$(eval $(call gen-target-rule-32,$(t))))
+endif
+
+ifeq ($(CAN_BUILD_X86_64),1)
+$(BINARIES_64): CFLAGS += -m64
+$(BINARIES_64): LDLIBS += -lrt -ldl
+$(BINARIES_64): $(OUTPUT)/%_64: %.c
+	$(CC) $(CFLAGS) $(EXTRA_CFLAGS) $(notdir $^) $(LDLIBS) -o $@
+$(foreach t,$(TARGETS),$(eval $(call gen-target-rule-64,$(t))))
+endif
+
+# x86_64 users should be encouraged to install 32-bit libraries (advisory only)
+ifeq ($(CAN_BUILD_I386)$(CAN_BUILD_X86_64),01)
+all: warn_32bit_failure
+.PHONY: warn_32bit_failure
+warn_32bit_failure:
+	@echo "Warning: you seem to have a broken 32-bit build" 2>&1;		\
+	echo  "environment. This will reduce test coverage of 64-bit" 2>&1;	\
+	echo  "kernels. If you are using a Debian-like distribution," 2>&1;	\
+	echo  "try:" 2>&1;							\
+	echo  "";								\
+	echo  "  apt-get install gcc-multilib libc6-i386 libc6-dev-i386";	\
+	echo  "";								\
+	echo  "If you are using a Fedora-like distribution, try:";		\
+	echo  "";								\
+	echo  "  yum install glibc-devel.*i686";				\
+	exit 0;
+endif
+endif
+
 $(OUTPUT)/userfaultfd: LDLIBS += -lpthread
 
 $(OUTPUT)/mlock-random-test: LDLIBS += -lcap
diff --git a/tools/testing/selftests/vm/charge_reserved_hugetlb.sh b/tools/testing/selftests/vm/charge_reserved_hugetlb.sh
new file mode 100644
index 0000000..18d3368
--- /dev/null
+++ b/tools/testing/selftests/vm/charge_reserved_hugetlb.sh
@@ -0,0 +1,575 @@
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+# Needs bash: relies on bash-only syntax such as [[ ]] and the function keyword.
+set -e
+
+if [[ $(id -u) -ne 0 ]]; then
+  echo "This test must be run as root. Skipping..."
+  exit 0
+fi
+
+fault_limit_file=limit_in_bytes
+reservation_limit_file=rsvd.limit_in_bytes
+fault_usage_file=usage_in_bytes
+reservation_usage_file=rsvd.usage_in_bytes
+
+if [[ "$1" == "-cgroup-v2" ]]; then
+  cgroup2=1
+  fault_limit_file=max
+  reservation_limit_file=rsvd.max
+  fault_usage_file=current
+  reservation_usage_file=rsvd.current
+fi
+
+cgroup_path=/dev/cgroup/memory
+if [[ ! -e $cgroup_path ]]; then
+  mkdir -p $cgroup_path
+  if [[ $cgroup2 ]]; then
+    mount -t cgroup2 none $cgroup_path
+  else
+    mount -t cgroup memory,hugetlb $cgroup_path
+  fi
+fi
+
+if [[ $cgroup2 ]]; then
+  echo "+hugetlb" >/dev/cgroup/memory/cgroup.subtree_control
+fi
+
+function cleanup() {
+  # Move this shell into the top-level cgroup before the test cgroups
+  # are removed further down.
+  local procs_file=tasks
+  if [[ $cgroup2 ]]; then
+    procs_file=cgroup.procs
+  fi
+  echo $$ >$cgroup_path/$procs_file
+
+  # Drop any leftover hugetlbfs files and the mount point itself.
+  if [[ -e /mnt/huge ]]; then
+    rm -rf /mnt/huge/*
+    umount /mnt/huge || echo error
+    rmdir /mnt/huge
+  fi
+  local name
+  for name in hugetlb_cgroup_test hugetlb_cgroup_test1 hugetlb_cgroup_test2; do
+    if [[ -e $cgroup_path/$name ]]; then
+      rmdir $cgroup_path/$name
+    fi
+  done
+  echo 0 >/proc/sys/vm/nr_hugepages
+  echo CLEANUP DONE
+}
+
+function expect_equal() {
+  local expected="$1"
+  local actual="$2"
+  local error="$3"
+  # On mismatch: report $error, run cleanup, and fail the whole script.
+  if [[ "$expected" != "$actual" ]]; then
+    echo "expected ($expected) != actual ($actual): $error"
+    cleanup
+    exit 1
+  fi
+}
+
+function get_machine_hugepage_size() {
+  # Print the huge page size in MB, read from the "Hugepagesize: N kB" line.
+  local kb
+  kb=$(awk '/^Hugepagesize:/ { print $2 }' /proc/meminfo)
+  echo $((kb / 1024))
+}
+
+MB=$(get_machine_hugepage_size)
+
+function setup_cgroup() {
+  local name="$1"
+  local cgroup_limit="$2"
+  local reservation_limit="$3"
+  # Create cgroup $name under $cgroup_path and write its two hugetlb limits.
+  mkdir $cgroup_path/$name
+
+  echo writing cgroup limit: "$cgroup_limit"
+  echo "$cgroup_limit" >$cgroup_path/$name/hugetlb.${MB}MB.$fault_limit_file
+
+  echo writing reservation limit: "$reservation_limit"
+  echo "$reservation_limit" > \
+    $cgroup_path/$name/hugetlb.${MB}MB.$reservation_limit_file
+  # When the new cgroup exposes cpuset files, write 0 to cpus and mems.
+  if [ -e "$cgroup_path/$name/cpuset.cpus" ]; then
+    echo 0 >$cgroup_path/$name/cpuset.cpus
+  fi
+  if [ -e "$cgroup_path/$name/cpuset.mems" ]; then
+    echo 0 >$cgroup_path/$name/cpuset.mems
+  fi
+}
+
+function wait_for_hugetlb_memory_to_get_depleted() {
+  local cgroup="$1"
+  local path="$cgroup_path/$cgroup/hugetlb.${MB}MB.$reservation_usage_file"
+  # Poll every 0.5s until the cgroup's reservation usage file reads 0.
+  while [ $(cat $path) != 0 ]; do
+    echo Waiting for hugetlb memory to get depleted.
+    cat $path
+    sleep 0.5
+  done
+}
+
+function wait_for_hugetlb_memory_to_get_reserved() {
+  local cgroup="$1"
+  local size="$2"
+
+  local path="$cgroup_path/$cgroup/hugetlb.${MB}MB.$reservation_usage_file"
+  # Poll every 0.5s until the cgroup's reservation usage file reads $size.
+  while [ $(cat $path) != $size ]; do
+    echo Waiting for hugetlb memory reservation to reach size $size.
+    cat $path
+    sleep 0.5
+  done
+}
+
+function wait_for_hugetlb_memory_to_get_written() {
+  local cgroup="$1"
+  local size="$2"
+
+  local path="$cgroup_path/$cgroup/hugetlb.${MB}MB.$fault_usage_file"
+  # Poll every 0.5s until the cgroup's fault usage file reads $size.
+  while [ $(cat $path) != $size ]; do
+    echo Waiting for hugetlb memory to reach size $size.
+    cat $path
+    sleep 0.5
+  done
+}
+
+function write_hugetlbfs_and_get_usage() {
+  local cgroup="$1"
+  local size="$2"
+  local populate="$3"
+  local write="$4"
+  local path="$5"
+  local method="$6"
+  local private="$7"
+  local expect_failure="$8"
+  local reserve="$9"
+  # Runs write_hugetlb_memory.sh for $cgroup and computes the usage deltas.
+  # Function return values (plain globals, read by the callers afterwards).
+  reservation_failed=0
+  oom_killed=0
+  hugetlb_difference=0
+  reserved_difference=0
+
+  local hugetlb_usage=$cgroup_path/$cgroup/hugetlb.${MB}MB.$fault_usage_file
+  local reserved_usage=$cgroup_path/$cgroup/hugetlb.${MB}MB.$reservation_usage_file
+
+  local hugetlb_before=$(cat $hugetlb_usage)
+  local reserved_before=$(cat $reserved_usage)
+
+  echo
+  echo Starting:
+  echo hugetlb_usage="$hugetlb_before"
+  echo reserved_usage="$reserved_before"
+  echo expect_failure is "$expect_failure"
+
+  output=$(mktemp) # NOTE(review): global, and the file is never removed here
+  set +e
+  if [[ "$method" == "1" ]] || [[ "$method" == 2 ]] ||
+    [[ "$private" == "-r" ]] && [[ "$expect_failure" != 1 ]]; then
+    # NOTE: || and && group left to right: (m==1 || m==2 || -r) && no failure.
+    bash write_hugetlb_memory.sh "$size" "$populate" "$write" \
+      "$cgroup" "$path" "$method" "$private" "-l" "$reserve" 2>&1 | tee $output &
+    # NOTE(review): after '&', $? is 0 and $! is presumably tee's PID -- confirm.
+    local write_result=$?
+    local write_pid=$!
+
+    until grep -q -i "DONE" $output; do
+      echo waiting for DONE signal.
+      if ! ps $write_pid > /dev/null
+      then
+        echo "FAIL: The write died"
+        cleanup
+        exit 1
+      fi
+      sleep 0.5
+    done
+
+    echo ================= write_hugetlb_memory.sh output is:
+    cat $output
+    echo ================= end output.
+
+    if [[ "$populate" == "-o" ]] || [[ "$write" == "-w" ]]; then
+      wait_for_hugetlb_memory_to_get_written "$cgroup" "$size"
+    elif [[ "$reserve" != "-n" ]]; then
+      wait_for_hugetlb_memory_to_get_reserved "$cgroup" "$size"
+    else
+      # This case doesn't produce visible effects, but we still have
+      # to wait for the async process to start and execute...
+      sleep 0.5
+    fi
+
+    echo write_result is $write_result
+  else
+    bash write_hugetlb_memory.sh "$size" "$populate" "$write" \
+      "$cgroup" "$path" "$method" "$private" "$reserve"
+    local write_result=$?
+
+    if [[ "$reserve" != "-n" ]]; then
+      wait_for_hugetlb_memory_to_get_reserved "$cgroup" "$size"
+    fi
+  fi
+  set -e
+
+  if [[ "$write_result" == 1 ]]; then
+    reservation_failed=1
+  fi
+
+  # On linus/master, the above process gets SIGBUS'd on oomkill, with
+  # return code 135. On earlier kernels, it gets actual oomkill, with return
+  # code 137, so just check for both conditions in case we're testing
+  # against an earlier kernel.
+  if [[ "$write_result" == 135 ]] || [[ "$write_result" == 137 ]]; then
+    oom_killed=1
+  fi
+
+  local hugetlb_after=$(cat $hugetlb_usage)
+  local reserved_after=$(cat $reserved_usage)
+
+  echo After write:
+  echo hugetlb_usage="$hugetlb_after"
+  echo reserved_usage="$reserved_after"
+
+  hugetlb_difference=$(($hugetlb_after - $hugetlb_before))
+  reserved_difference=$(($reserved_after - $reserved_before))
+}
+
+function cleanup_hugetlb_memory() {
+  local cgroup="$1"
+  # Interrupt any writer still running; errors in this part are tolerated.
+  set +e
+  if [[ -n "$(pgrep -f write_to_hugetlbfs)" ]]; then
+    echo killing write_to_hugetlbfs
+    killall -2 write_to_hugetlbfs
+    wait_for_hugetlb_memory_to_get_depleted $cgroup
+  fi
+  set -e
+
+  [[ -e /mnt/huge ]] || return 0
+  rm -rf /mnt/huge/*
+  umount /mnt/huge
+  rmdir /mnt/huge
+}
+
+function run_test() {
+  local size=$(($1 * ${MB} * 1024 * 1024))
+  local populate="$2"
+  local write="$3"
+  local cgroup_limit=$(($4 * ${MB} * 1024 * 1024))
+  local reservation_limit=$(($5 * ${MB} * 1024 * 1024))
+  local nr_hugepages="$6"
+  local method="$7"
+  local private="$8"
+  local expect_failure="$9"
+  local reserve="${10}"
+  # $1, $4 and $5 are counts of ${MB} MB units; they are converted to bytes above.
+  # Function return values.
+  hugetlb_difference=0
+  reserved_difference=0
+  reservation_failed=0
+  oom_killed=0
+
+  echo nr hugepages = "$nr_hugepages"
+  echo "$nr_hugepages" >/proc/sys/vm/nr_hugepages
+
+  setup_cgroup "hugetlb_cgroup_test" "$cgroup_limit" "$reservation_limit"
+  # Mount a fresh hugetlbfs at /mnt/huge for the writer to use.
+  mkdir -p /mnt/huge
+  mount -t hugetlbfs -o pagesize=${MB}M,size=256M none /mnt/huge
+
+  write_hugetlbfs_and_get_usage "hugetlb_cgroup_test" "$size" "$populate" \
+    "$write" "/mnt/huge/test" "$method" "$private" "$expect_failure" \
+    "$reserve"
+
+  cleanup_hugetlb_memory "hugetlb_cgroup_test"
+  # After cleanup both usage counters are expected to read zero again.
+  local final_hugetlb=$(cat $cgroup_path/hugetlb_cgroup_test/hugetlb.${MB}MB.$fault_usage_file)
+  local final_reservation=$(cat $cgroup_path/hugetlb_cgroup_test/hugetlb.${MB}MB.$reservation_usage_file)
+
+  echo $hugetlb_difference
+  echo $reserved_difference
+  expect_equal "0" "$final_hugetlb" "final hugetlb is not zero"
+  expect_equal "0" "$final_reservation" "final reservation is not zero"
+}
+
+function run_multiple_cgroup_test() {
+  local size1="$1"
+  local populate1="$2"
+  local write1="$3"
+  local cgroup_limit1="$4"
+  local reservation_limit1="$5"
+
+  local size2="$6"
+  local populate2="$7"
+  local write2="$8"
+  local cgroup_limit2="$9"
+  local reservation_limit2="${10}"
+
+  local nr_hugepages="${11}"
+  local method="${12}"
+  local private="${13}"
+  local expect_failure="${14}"
+  local reserve="${15}"
+  # Runs one writer per cgroup; the second must leave the first's usage unchanged.
+  # Function return values.
+  hugetlb_difference1=0
+  reserved_difference1=0
+  reservation_failed1=0
+  oom_killed1=0
+
+  hugetlb_difference2=0
+  reserved_difference2=0
+  reservation_failed2=0
+  oom_killed2=0
+
+  echo nr hugepages = "$nr_hugepages"
+  echo "$nr_hugepages" >/proc/sys/vm/nr_hugepages
+
+  setup_cgroup "hugetlb_cgroup_test1" "$cgroup_limit1" "$reservation_limit1"
+  setup_cgroup "hugetlb_cgroup_test2" "$cgroup_limit2" "$reservation_limit2"
+
+  mkdir -p /mnt/huge
+  mount -t hugetlbfs -o pagesize=${MB}M,size=256M none /mnt/huge
+
+  write_hugetlbfs_and_get_usage "hugetlb_cgroup_test1" "$size1" \
+    "$populate1" "$write1" "/mnt/huge/test1" "$method" "$private" \
+    "$expect_failure" "$reserve"
+
+  hugetlb_difference1=$hugetlb_difference
+  reserved_difference1=$reserved_difference
+  reservation_failed1=$reservation_failed
+  oom_killed1=$oom_killed
+
+  local cgroup1_hugetlb_usage=$cgroup_path/hugetlb_cgroup_test1/hugetlb.${MB}MB.$fault_usage_file
+  local cgroup1_reservation_usage=$cgroup_path/hugetlb_cgroup_test1/hugetlb.${MB}MB.$reservation_usage_file
+  local cgroup2_hugetlb_usage=$cgroup_path/hugetlb_cgroup_test2/hugetlb.${MB}MB.$fault_usage_file
+  local cgroup2_reservation_usage=$cgroup_path/hugetlb_cgroup_test2/hugetlb.${MB}MB.$reservation_usage_file
+
+  local usage_before_second_write=$(cat $cgroup1_hugetlb_usage)
+  local reservation_usage_before_second_write=$(cat $cgroup1_reservation_usage)
+
+  write_hugetlbfs_and_get_usage "hugetlb_cgroup_test2" "$size2" \
+    "$populate2" "$write2" "/mnt/huge/test2" "$method" "$private" \
+    "$expect_failure" "$reserve"
+
+  hugetlb_difference2=$hugetlb_difference
+  reserved_difference2=$reserved_difference
+  reservation_failed2=$reservation_failed
+  oom_killed2=$oom_killed
+
+  expect_equal "$usage_before_second_write" \
+    "$(cat $cgroup1_hugetlb_usage)" "Usage changed."
+  expect_equal "$reservation_usage_before_second_write" \
+    "$(cat $cgroup1_reservation_usage)" "Reservation usage changed."
+  # NOTE(review): called without a cgroup argument -- confirm this is intended.
+  cleanup_hugetlb_memory
+
+  local final_hugetlb=$(cat $cgroup1_hugetlb_usage)
+  local final_reservation=$(cat $cgroup1_reservation_usage)
+
+  expect_equal "0" "$final_hugetlb" \
+    "hugetlb_cgroup_test1 final hugetlb is not zero"
+  expect_equal "0" "$final_reservation" \
+    "hugetlb_cgroup_test1 final reservation is not zero"
+
+  local final_hugetlb=$(cat $cgroup2_hugetlb_usage)
+  local final_reservation=$(cat $cgroup2_reservation_usage)
+
+  expect_equal "0" "$final_hugetlb" \
+    "hugetlb_cgroup_test2 final hugetlb is not zero"
+  expect_equal "0" "$final_reservation" \
+    "hugetlb_cgroup_test2 final reservation is not zero"
+}
+
+cleanup
+
+for populate in "" "-o"; do
+  for method in 0 1 2; do
+    for private in "" "-r"; do
+      for reserve in "" "-n"; do
+
+        # Skip mmap(MAP_HUGETLB | MAP_SHARED). Doesn't seem to be supported.
+        if [[ "$method" == 1 ]] && [[ "$private" == "" ]]; then
+          continue
+        fi
+
+        # Skip populated shmem tests. Doesn't seem to be supported.
+        if [[ "$method" == 2"" ]] && [[ "$populate" == "-o" ]]; then
+          continue
+        fi
+
+        if [[ "$method" == 2"" ]] && [[ "$reserve" == "-n" ]]; then
+          continue
+        fi
+
+        cleanup
+        echo
+        echo
+        echo
+        echo Test normal case.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+        run_test 5 "$populate" "" 10 10 10 "$method" "$private" "0" "$reserve"
+
+        echo Memory charged to hugtlb=$hugetlb_difference
+        echo Memory charged to reservation=$reserved_difference
+
+        if [[ "$populate" == "-o" ]]; then
+          expect_equal "$((5 * $MB * 1024 * 1024))" "$hugetlb_difference" \
+            "Reserved memory charged to hugetlb cgroup."
+        else
+          expect_equal "0" "$hugetlb_difference" \
+            "Reserved memory charged to hugetlb cgroup."
+        fi
+
+        if [[ "$reserve" != "-n" ]] || [[ "$populate" == "-o" ]]; then
+          expect_equal "$((5 * $MB * 1024 * 1024))" "$reserved_difference" \
+            "Reserved memory not charged to reservation usage."
+        else
+          expect_equal "0" "$reserved_difference" \
+            "Reserved memory not charged to reservation usage."
+        fi
+
+        echo 'PASS'
+
+        cleanup
+        echo
+        echo
+        echo
+        echo Test normal case with write.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+        run_test 5 "$populate" '-w' 5 5 10 "$method" "$private" "0" "$reserve"
+
+        echo Memory charged to hugtlb=$hugetlb_difference
+        echo Memory charged to reservation=$reserved_difference
+
+        expect_equal "$((5 * $MB * 1024 * 1024))" "$hugetlb_difference" \
+          "Reserved memory charged to hugetlb cgroup."
+
+        expect_equal "$((5 * $MB * 1024 * 1024))" "$reserved_difference" \
+          "Reserved memory not charged to reservation usage."
+
+        echo 'PASS'
+
+        cleanup
+        continue
+        echo
+        echo
+        echo
+        echo Test more than reservation case.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+
+        if [ "$reserve" != "-n" ]; then
+          run_test "5" "$populate" '' "10" "2" "10" "$method" "$private" "1" \
+            "$reserve"
+
+          expect_equal "1" "$reservation_failed" "Reservation succeeded."
+        fi
+
+        echo 'PASS'
+
+        cleanup
+
+        echo
+        echo
+        echo
+        echo Test more than cgroup limit case.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+
+        # Not sure if shm memory can be cleaned up when the process gets sigbus'd.
+        if [[ "$method" != 2 ]]; then
+          run_test 5 "$populate" "-w" 2 10 10 "$method" "$private" "1" "$reserve"
+
+          expect_equal "1" "$oom_killed" "Not oom killed."
+        fi
+        echo 'PASS'
+
+        cleanup
+
+        echo
+        echo
+        echo
+        echo Test normal case, multiple cgroups.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+        run_multiple_cgroup_test "3" "$populate" "" "10" "10" "5" \
+          "$populate" "" "10" "10" "10" \
+          "$method" "$private" "0" "$reserve"
+
+        echo Memory charged to hugtlb1=$hugetlb_difference1
+        echo Memory charged to reservation1=$reserved_difference1
+        echo Memory charged to hugtlb2=$hugetlb_difference2
+        echo Memory charged to reservation2=$reserved_difference2
+
+        if [[ "$reserve" != "-n" ]] || [[ "$populate" == "-o" ]]; then
+          expect_equal "3" "$reserved_difference1" \
+            "Incorrect reservations charged to cgroup 1."
+
+          expect_equal "5" "$reserved_difference2" \
+            "Incorrect reservation charged to cgroup 2."
+
+        else
+          expect_equal "0" "$reserved_difference1" \
+            "Incorrect reservations charged to cgroup 1."
+
+          expect_equal "0" "$reserved_difference2" \
+            "Incorrect reservation charged to cgroup 2."
+        fi
+
+        if [[ "$populate" == "-o" ]]; then
+          expect_equal "3" "$hugetlb_difference1" \
+            "Incorrect hugetlb charged to cgroup 1."
+
+          expect_equal "5" "$hugetlb_difference2" \
+            "Incorrect hugetlb charged to cgroup 2."
+
+        else
+          expect_equal "0" "$hugetlb_difference1" \
+            "Incorrect hugetlb charged to cgroup 1."
+
+          expect_equal "0" "$hugetlb_difference2" \
+            "Incorrect hugetlb charged to cgroup 2."
+        fi
+        echo 'PASS'
+
+        cleanup
+        echo
+        echo
+        echo
+        echo Test normal case with write, multiple cgroups.
+        echo private=$private, populate=$populate, method=$method, reserve=$reserve
+        run_multiple_cgroup_test "3" "$populate" "-w" "10" "10" "5" \
+          "$populate" "-w" "10" "10" "10" \
+          "$method" "$private" "0" "$reserve"
+
+        echo Memory charged to hugtlb1=$hugetlb_difference1
+        echo Memory charged to reservation1=$reserved_difference1
+        echo Memory charged to hugtlb2=$hugetlb_difference2
+        echo Memory charged to reservation2=$reserved_difference2
+
+        expect_equal "3" "$hugetlb_difference1" \
+          "Incorrect hugetlb charged to cgroup 1."
+
+        expect_equal "3" "$reserved_difference1" \
+          "Incorrect reservation charged to cgroup 1."
+
+        expect_equal "5" "$hugetlb_difference2" \
+          "Incorrect hugetlb charged to cgroup 2."
+
+        expect_equal "5" "$reserved_difference2" \
+          "Incorrected reservation charged to cgroup 2."
+        echo 'PASS'
+
+        cleanup
+
+      done # reserve
+    done   # private
+  done     # method
+done       # populate
+
+umount $cgroup_path
+rmdir $cgroup_path
diff --git a/tools/testing/selftests/vm/compaction_test.c b/tools/testing/selftests/vm/compaction_test.c
index bcec712..9b42014 100644
--- a/tools/testing/selftests/vm/compaction_test.c
+++ b/tools/testing/selftests/vm/compaction_test.c
@@ -18,7 +18,8 @@
 
 #include "../kselftest.h"
 
-#define MAP_SIZE 1048576
+#define MAP_SIZE_MB	100
+#define MAP_SIZE	(MAP_SIZE_MB * 1024 * 1024)
 
 struct map_list {
 	void *map;
@@ -165,7 +166,7 @@
 	void *map = NULL;
 	unsigned long mem_free = 0;
 	unsigned long hugepage_size = 0;
-	unsigned long mem_fragmentable = 0;
+	long mem_fragmentable_MB = 0;
 
 	if (prereq() != 0) {
 		printf("Either the sysctl compact_unevictable_allowed is not\n"
@@ -190,9 +191,9 @@
 		return -1;
 	}
 
-	mem_fragmentable = mem_free * 0.8 / 1024;
+	mem_fragmentable_MB = mem_free * 0.8 / 1024;
 
-	while (mem_fragmentable > 0) {
+	while (mem_fragmentable_MB > 0) {
 		map = mmap(NULL, MAP_SIZE, PROT_READ | PROT_WRITE,
 			   MAP_ANONYMOUS | MAP_PRIVATE | MAP_LOCKED, -1, 0);
 		if (map == MAP_FAILED)
@@ -213,7 +214,7 @@
 		for (i = 0; i < MAP_SIZE; i += page_size)
 			*(unsigned long *)(map + i) = (unsigned long)map + i;
 
-		mem_fragmentable--;
+		mem_fragmentable_MB -= MAP_SIZE_MB;
 	}
 
 	for (entry = list; entry != NULL; entry = entry->next) {
diff --git a/tools/testing/selftests/vm/config b/tools/testing/selftests/vm/config
index 93b90a9..69dd0d1 100644
--- a/tools/testing/selftests/vm/config
+++ b/tools/testing/selftests/vm/config
@@ -1,3 +1,6 @@
 CONFIG_SYSVIPC=y
 CONFIG_USERFAULTFD=y
 CONFIG_TEST_VMALLOC=m
+CONFIG_DEVICE_PRIVATE=y
+CONFIG_TEST_HMM=m
+CONFIG_GUP_BENCHMARK=y
diff --git a/tools/testing/selftests/vm/gup_benchmark.c b/tools/testing/selftests/vm/gup_benchmark.c
index 485cf06..1d43593 100644
--- a/tools/testing/selftests/vm/gup_benchmark.c
+++ b/tools/testing/selftests/vm/gup_benchmark.c
@@ -15,8 +15,15 @@
 #define PAGE_SIZE sysconf(_SC_PAGESIZE)
 
 #define GUP_FAST_BENCHMARK	_IOWR('g', 1, struct gup_benchmark)
-#define GUP_LONGTERM_BENCHMARK	_IOWR('g', 2, struct gup_benchmark)
-#define GUP_BENCHMARK		_IOWR('g', 3, struct gup_benchmark)
+#define GUP_BENCHMARK		_IOWR('g', 2, struct gup_benchmark)
+
+/* Similar to above, but use FOLL_PIN instead of FOLL_GET. */
+#define PIN_FAST_BENCHMARK	_IOWR('g', 3, struct gup_benchmark)
+#define PIN_BENCHMARK		_IOWR('g', 4, struct gup_benchmark)
+#define PIN_LONGTERM_BENCHMARK	_IOWR('g', 5, struct gup_benchmark)
+
+/* Just the flags we need, copied from mm.h: */
+#define FOLL_WRITE	0x01	/* check pte is writable */
 
 struct gup_benchmark {
 	__u64 get_delta_usec;
@@ -37,8 +44,17 @@
 	char *file = "/dev/zero";
 	char *p;
 
-	while ((opt = getopt(argc, argv, "m:r:n:f:tTLUwSH")) != -1) {
+	while ((opt = getopt(argc, argv, "m:r:n:f:abtTLUuwSH")) != -1) {
 		switch (opt) {
+		case 'a':
+			cmd = PIN_FAST_BENCHMARK;
+			break;
+		case 'b':
+			cmd = PIN_BENCHMARK;
+			break;
+		case 'L':
+			cmd = PIN_LONGTERM_BENCHMARK;
+			break;
 		case 'm':
 			size = atoi(optarg) * MB;
 			break;
@@ -54,12 +70,12 @@
 		case 'T':
 			thp = 0;
 			break;
-		case 'L':
-			cmd = GUP_LONGTERM_BENCHMARK;
-			break;
 		case 'U':
 			cmd = GUP_BENCHMARK;
 			break;
+		case 'u':
+			cmd = GUP_FAST_BENCHMARK;
+			break;
 		case 'w':
 			write = 1;
 			break;
@@ -85,15 +101,20 @@
 	}
 
 	gup.nr_pages_per_call = nr_pages;
-	gup.flags = write;
+	if (write)
+		gup.flags |= FOLL_WRITE;
 
 	fd = open("/sys/kernel/debug/gup_benchmark", O_RDWR);
-	if (fd == -1)
-		perror("open"), exit(1);
+	if (fd == -1) {
+		perror("open");
+		exit(1);
+	}
 
 	p = mmap(NULL, size, PROT_READ | PROT_WRITE, flags, filed, 0);
-	if (p == MAP_FAILED)
-		perror("mmap"), exit(1);
+	if (p == MAP_FAILED) {
+		perror("mmap");
+		exit(1);
+	}
 	gup.addr = (unsigned long)p;
 
 	if (thp == 1)
@@ -106,8 +127,10 @@
 
 	for (i = 0; i < repeats; i++) {
 		gup.size = size;
-		if (ioctl(fd, cmd, &gup))
-			perror("ioctl"), exit(1);
+		if (ioctl(fd, cmd, &gup)) {
+			perror("ioctl");
+			exit(1);
+		}
 
 		printf("Time: get:%lld put:%lld us", gup.get_delta_usec,
 			gup.put_delta_usec);
diff --git a/tools/testing/selftests/vm/hmm-tests.c b/tools/testing/selftests/vm/hmm-tests.c
new file mode 100644
index 0000000..426dccc
--- /dev/null
+++ b/tools/testing/selftests/vm/hmm-tests.c
@@ -0,0 +1,1522 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * HMM stands for Heterogeneous Memory Management; it is a helper layer inside
+ * the Linux kernel to help device drivers mirror a process address space in
+ * the device. This allows the device to use the same address space, which
+ * makes communication and data exchange a lot easier.
+ *
+ * This framework's sole purpose is to exercise various code paths inside
+ * the kernel to make sure that HMM performs as expected and to flush out any
+ * bugs.
+ */
+
+#include "../kselftest_harness.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdint.h>
+#include <unistd.h>
+#include <strings.h>
+#include <time.h>
+#include <pthread.h>
+#include <hugetlbfs.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <sys/ioctl.h>
+
+/*
+ * This is a private UAPI to the kernel test module so it isn't exported
+ * in the usual include/uapi/... directory.
+ */
+#include "../../../../lib/test_hmm_uapi.h"
+
+struct hmm_buffer {
+	void		*ptr;		/* mapping sent to the driver as cmd.addr */
+	void		*mirror;	/* malloc'ed buffer sent as cmd.ptr */
+	unsigned long	size;		/* byte length used to munmap() ptr */
+	int		fd;		/* backing file, or -1 when none */
+	uint64_t	cpages;		/* cmd.cpages from the last ioctl */
+	uint64_t	faults;		/* cmd.faults from the last ioctl */
+};
+
+#define TWOMEG		(1 << 21)	/* 2 MiB */
+#define HMM_BUFFER_SIZE (1024 << 12)	/* 4 MiB */
+#define HMM_PATH_MAX    64		/* size of on-stack path buffers */
+#define NTIMES		10		/* iterations of the repeated tests */
+/* Round x up to a multiple of a (a power of two); both args parenthesized. */
+#define ALIGN(x, a) (((x) + ((a) - 1)) & (~((a) - 1)))
+
+FIXTURE(hmm)
+{
+	int		fd;		/* descriptor from hmm_open(0) */
+	unsigned int	page_size;	/* sysconf(_SC_PAGE_SIZE) */
+	unsigned int	page_shift;	/* ffs(page_size) - 1 */
+};
+
+FIXTURE(hmm2)
+{
+	int		fd0;		/* descriptor from hmm_open(0) */
+	int		fd1;		/* descriptor from hmm_open(1) */
+	unsigned int	page_size;	/* sysconf(_SC_PAGE_SIZE) */
+	unsigned int	page_shift;	/* ffs(page_size) - 1 */
+};
+
+static int hmm_open(int unit)
+{
+	char dev_path[HMM_PATH_MAX];	/* "/dev/hmm_dmirror<unit>" */
+	int dev_fd;
+
+	snprintf(dev_path, sizeof(dev_path), "/dev/hmm_dmirror%d", unit);
+	dev_fd = open(dev_path, O_RDWR, 0);
+	if (dev_fd >= 0)
+		return dev_fd;
+	fprintf(stderr, "could not open hmm dmirror driver (%s)\n", dev_path);
+	return dev_fd;
+}
+
+FIXTURE_SETUP(hmm)
+{
+	self->page_size = sysconf(_SC_PAGE_SIZE);
+	self->page_shift = ffs(self->page_size) - 1;	/* index of lowest set bit */
+
+	self->fd = hmm_open(0);	/* device unit 0 */
+	ASSERT_GE(self->fd, 0);
+}
+
+FIXTURE_SETUP(hmm2)
+{
+	self->page_size = sysconf(_SC_PAGE_SIZE);
+	self->page_shift = ffs(self->page_size) - 1;	/* index of lowest set bit */
+
+	self->fd0 = hmm_open(0);	/* device unit 0 */
+	ASSERT_GE(self->fd0, 0);
+	self->fd1 = hmm_open(1);	/* device unit 1 */
+	ASSERT_GE(self->fd1, 0);
+}
+
+FIXTURE_TEARDOWN(hmm)
+{
+	int ret = close(self->fd);	/* release the device opened in setup */
+
+	ASSERT_EQ(ret, 0);
+	self->fd = -1;	/* mark the descriptor as closed */
+}
+
+FIXTURE_TEARDOWN(hmm2)
+{
+	int ret = close(self->fd0);	/* release device unit 0 */
+
+	ASSERT_EQ(ret, 0);
+	self->fd0 = -1;	/* mark the descriptor as closed */
+
+	ret = close(self->fd1);	/* release device unit 1 */
+	ASSERT_EQ(ret, 0);
+	self->fd1 = -1;	/* mark the descriptor as closed */
+}
+
+static int hmm_dmirror_cmd(int fd,
+			   unsigned long request,
+			   struct hmm_buffer *buffer,
+			   unsigned long npages)
+{
+	struct hmm_dmirror_cmd args;
+
+	/*
+	 * Describe the range for any request: addr is the mapping under
+	 * test, ptr is the mirror buffer, npages is the page count.
+	 */
+	args.addr = (__u64)buffer->ptr;
+	args.ptr = (__u64)buffer->mirror;
+	args.npages = npages;
+
+	/* Reissue the ioctl while it fails with EINTR; return other errors. */
+	while (ioctl(fd, request, &args) != 0) {
+		if (errno != EINTR)
+			return -errno;
+	}
+
+	/* Copy the cpages and faults fields of the command back to the caller. */
+	buffer->cpages = args.cpages;
+	buffer->faults = args.faults;
+	return 0;
+}
+
+static void hmm_buffer_free(struct hmm_buffer *buffer)
+{
+	if (!buffer)
+		return;
+	/* Unmap ptr when set, then free the mirror and the struct; fd is kept. */
+	if (buffer->ptr != NULL)
+		munmap(buffer->ptr, buffer->size);
+	free(buffer->mirror);
+	free(buffer);
+}
+
+/*
+ * Create a temporary file that will be deleted on close.
+ */
+static int hmm_create_file(unsigned long size)
+{
+	int tmp_fd, rc;
+
+	/* Open an unnamed temporary file under /tmp, owner read/write only. */
+	tmp_fd = open("/tmp", O_TMPFILE | O_EXCL | O_RDWR, 0600);
+	if (tmp_fd < 0)
+		return -1;
+
+	/* Set the file length, retrying while ftruncate() reports EINTR. */
+	do {
+		rc = ftruncate(tmp_fd, size);
+	} while (rc == -1 && errno == EINTR);
+	if (rc != 0) {
+		close(tmp_fd);
+		return -1;
+	}
+	return tmp_fd;
+}
+
+/*
+ * Return a random unsigned number.
+ */
+static unsigned int hmm_random(void)
+{
+	static int fd = -1;	/* opened on first use and never closed */
+	unsigned int r;
+
+	if (fd < 0) {
+		fd = open("/dev/urandom", O_RDONLY);
+		if (fd < 0) {
+			fprintf(stderr, "%s:%d failed to open /dev/urandom\n",
+					__FILE__, __LINE__);
+			return ~0U;
+		}
+	}
+	/* A failed or short read would leave r indeterminate; return ~0U. */
+	return read(fd, &r, sizeof(r)) == (ssize_t)sizeof(r) ? r : ~0U;
+}
+
+static void hmm_nanosleep(unsigned int n)
+{
+	struct timespec delay = {
+		.tv_sec = 0,
+		.tv_nsec = n,	/* interval in nanoseconds */
+	};
+	nanosleep(&delay, NULL);	/* remainder and errors are ignored */
+}
+
+/*
+ * Empty test body: only the fixture's device open (setup) and close (teardown) run.
+ */
+TEST_F(hmm, open_close)
+{
+}
+
+/*
+ * Read private anonymous memory, including an untouched and a zero page.
+ */
+TEST_F(hmm, anon_read)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+	int val;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/*
+	 * Initialize buffer in system memory but leave the first two pages
+	 * zero (pte_none and pfn_zero).
+	 */
+	i = 2 * self->page_size / sizeof(*ptr);
+	for (ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Set buffer permission to read-only. */
+	ret = mprotect(buffer->ptr, size, PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Read the second page so the CPU page table maps the zero page there. */
+	val = *(int *)(buffer->ptr + self->page_size);
+	ASSERT_EQ(val, 0);
+
+	/* Simulate a device reading system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device read: two zero pages, then the index pattern. */
+	ptr = buffer->mirror;
+	for (i = 0; i < 2 * self->page_size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], 0);
+	for (; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Read private anonymous memory which has been protected with
+ * mprotect() PROT_NONE.
+ */
+TEST_F(hmm, anon_read_prot)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer in system memory. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Initialize mirror buffer so we can verify it isn't written. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = -i;
+
+	/* Protect buffer from reading. */
+	ret = mprotect(buffer->ptr, size, PROT_NONE);
+	ASSERT_EQ(ret, 0);
+
+	/* A device read of the PROT_NONE range is expected to fail with EFAULT. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer, npages);
+	ASSERT_EQ(ret, -EFAULT);
+
+	/* Allow CPU to read the buffer so we can check it. */
+	ret = mprotect(buffer->ptr, size, PROT_READ);
+	ASSERT_EQ(ret, 0);
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* The failed read must have left the mirror's -i pattern untouched. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], -i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Write private anonymous memory from the device and verify it on the CPU.
+ */
+TEST_F(hmm, anon_write)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check through the CPU mapping what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Write private anonymous memory which has been protected with
+ * mprotect() PROT_READ.
+ */
+TEST_F(hmm, anon_write_prot)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Simulate a device reading a zero page of memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer, 1);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, 1);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* A device write to the read-only range is expected to fail with EPERM. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, -EPERM);
+
+	/* The rejected write must have left the buffer reading as zero. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], 0);
+
+	/* Now allow writing and see that the zero page is replaced. */
+	ret = mprotect(buffer->ptr, size, PROT_WRITE | PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Check that a device writing an anonymous private mapping
+ * will copy-on-write if a child process inherits the mapping.
+ */
+TEST_F(hmm, anon_write_child)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	pid_t pid;
+	int child_fd;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer->ptr so we can tell if it is written. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = -i;
+
+	pid = fork();
+	ASSERT_NE(pid, -1);
+	if (pid != 0) {
+		waitpid(pid, &ret, 0);
+		ASSERT_EQ(WIFEXITED(ret), 1);
+
+		/* Check that the parent's buffer did not change. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ASSERT_EQ(ptr[i], i);
+		hmm_buffer_free(buffer);
+		return;
+	}
+
+	/* Check that we see the parent's values. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], -i);
+
+	/* The child process needs its own mirror to its own mm. */
+	child_fd = hmm_open(0);
+	ASSERT_GE(child_fd, 0);
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(child_fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], -i);
+
+	close(child_fd);
+	exit(0);	/* the child ends here; the parent checks WIFEXITED() */
+}
+
+/*
+ * Check that a device writing an anonymous shared mapping
+ * will not copy-on-write if a child process inherits the mapping.
+ */
+TEST_F(hmm, anon_write_child_shared)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	pid_t pid;
+	int child_fd;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_SHARED | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer->ptr so we can tell if it is written. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = -i;
+
+	pid = fork();
+	ASSERT_NE(pid, -1);
+	if (pid != 0) {
+		waitpid(pid, &ret, 0);
+		ASSERT_EQ(WIFEXITED(ret), 1);
+
+		/* Check that the parent's buffer did change. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ASSERT_EQ(ptr[i], -i);
+		hmm_buffer_free(buffer);
+		return;
+	}
+
+	/* Check that we see the parent's values. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], -i);
+
+	/* The child process needs its own mirror to its own mm. */
+	child_fd = hmm_open(0);
+	ASSERT_GE(child_fd, 0);
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(child_fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], -i);
+
+	close(child_fd);
+	exit(0);	/* the child ends here; the parent checks WIFEXITED() */
+}
+
+/*
+ * Write private anonymous memory advised with MADV_HUGEPAGE.
+ */
+TEST_F(hmm, anon_write_huge)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	void *old_ptr;
+	void *map;
+	int *ptr;
+	int ret;
+
+	size = 2 * TWOMEG;	/* twice TWOMEG so an aligned TWOMEG range fits */
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	size = TWOMEG;	/* from here on only the aligned TWOMEG range is used */
+	npages = size >> self->page_shift;
+	map = (void *)ALIGN((uintptr_t)buffer->ptr, size);
+	ret = madvise(map, size, MADV_HUGEPAGE);
+	ASSERT_EQ(ret, 0);
+	old_ptr = buffer->ptr;	/* kept so the whole mapping can be freed */
+	buffer->ptr = map;
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	buffer->ptr = old_ptr;	/* munmap() must start at the original base */
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Write a hugetlbfs-backed region obtained from libhugetlbfs.
+ */
+TEST_F(hmm, anon_write_hugetlbfs)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+	long pagesizes[4];
+	int n, idx;
+
+	/* Skip test if we can't allocate a hugetlbfs page. */
+
+	n = gethugepagesizes(pagesizes, 4);
+	if (n <= 0)
+		SKIP(return, "Huge page size could not be determined");
+	for (idx = 0; --n > 0; ) {	/* idx ends up at the smallest size */
+		if (pagesizes[n] < pagesizes[idx])
+			idx = n;
+	}
+	size = ALIGN(TWOMEG, pagesizes[idx]);
+	npages = size >> self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->ptr = get_hugepage_region(size, GHR_STRICT);
+	if (buffer->ptr == NULL) {
+		free(buffer);
+		SKIP(return, "Huge page could not be allocated");
+	}
+
+	buffer->fd = -1;	/* no file descriptor is tracked for the region */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	free_hugepage_region(buffer->ptr);
+	buffer->ptr = NULL;	/* keeps hmm_buffer_free() from calling munmap() */
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Read mmap'ed file memory.
+ */
+TEST_F(hmm, file_read)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret, fd;
+	ssize_t len;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	fd = hmm_create_file(size);
+	ASSERT_GE(fd, 0);
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = fd;
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Write initial contents of the file. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+	len = pwrite(fd, buffer->mirror, size, 0);
+	ASSERT_EQ(len, size);
+	memset(buffer->mirror, 0, size);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ,
+			   MAP_SHARED,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Simulate a device reading system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device read. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+	close(fd);	/* hmm_buffer_free() leaves the backing file open */
+}
+
+/*
+ * Write mmap'ed file memory.
+ */
+TEST_F(hmm, file_write)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret, fd;
+	ssize_t len;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	fd = hmm_create_file(size);
+	ASSERT_GE(fd, 0);
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = fd;
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_SHARED,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize data that the device will write to buffer->ptr. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Simulate a device writing system memory. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_WRITE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device wrote. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* Check that the device also wrote the file. */
+	len = pread(fd, buffer->mirror, size, 0);
+	ASSERT_EQ(len, size);
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+	close(fd);	/* hmm_buffer_free() leaves the backing file open */
+}
+
+/*
+ * Migrate anonymous memory to device private memory.
+ */
+TEST_F(hmm, migrate)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer in system memory. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Migrate memory to device; every page is expected to be counted. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_MIGRATE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* Check that the mirror now holds the pattern written above. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Migrate anonymous memory to device private memory and fault some of it back
+ * to system memory, then try migrating the resulting mix of system and device
+ * private memory to the device.
+ */
+TEST_F(hmm, migrate_fault)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer in system memory. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Migrate memory to device. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_MIGRATE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* Check that the mirror holds the pattern written above. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* Fault half the pages back to system memory and check them. */
+	for (i = 0, ptr = buffer->ptr; i < size / (2 * sizeof(*ptr)); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* Migrate memory to the device again; all pages must be counted. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_MIGRATE, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* Check that the mirror still holds the full pattern. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Migrating anonymous shared memory to the device must fail with -ENOENT.
+ */
+TEST_F(hmm, migrate_shared)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_SHARED | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Try to migrate the shared mapping; the driver is expected to refuse. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_MIGRATE, buffer, npages);
+	ASSERT_EQ(ret, -ENOENT);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Try to migrate various memory types to device private memory.
+ */
+TEST_F(hmm2, migrate_mixed)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	int *ptr;
+	unsigned char *p;
+	int ret;
+	int val;
+
+	npages = 6;
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* anonymous mapping: no backing file */
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Reserve a range of addresses. */
+	buffer->ptr = mmap(NULL, size,
+			   PROT_NONE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+	p = buffer->ptr;	/* base address, restored before freeing */
+
+	/* Migrating a protected area should be an error. */
+	ret = hmm_dmirror_cmd(self->fd1, HMM_DMIRROR_MIGRATE, buffer, npages);
+	ASSERT_EQ(ret, -EINVAL);
+
+	/* Punch a hole after the first page address. */
+	ret = munmap(buffer->ptr + self->page_size, self->page_size);
+	ASSERT_EQ(ret, 0);
+
+	/* We expect an error if the vma doesn't cover the range. */
+	ret = hmm_dmirror_cmd(self->fd1, HMM_DMIRROR_MIGRATE, buffer, 3);
+	ASSERT_EQ(ret, -EINVAL);
+
+	/* Page 2 will be a read-only zero page. */
+	ret = mprotect(buffer->ptr + 2 * self->page_size, self->page_size,
+				PROT_READ);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 2 * self->page_size);
+	val = *ptr + 3;
+	ASSERT_EQ(val, 3);
+
+	/* Page 3 will be read-only. */
+	ret = mprotect(buffer->ptr + 3 * self->page_size, self->page_size,
+				PROT_READ | PROT_WRITE);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 3 * self->page_size);
+	*ptr = val;
+	ret = mprotect(buffer->ptr + 3 * self->page_size, self->page_size,
+				PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Page 4-5 will be read-write. */
+	ret = mprotect(buffer->ptr + 4 * self->page_size, 2 * self->page_size,
+				PROT_READ | PROT_WRITE);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 4 * self->page_size);
+	*ptr = val;
+	ptr = (int *)(buffer->ptr + 5 * self->page_size);
+	*ptr = val;
+
+	/* Now try to migrate pages 2-5 to device 1. */
+	buffer->ptr = p + 2 * self->page_size;
+	ret = hmm_dmirror_cmd(self->fd1, HMM_DMIRROR_MIGRATE, buffer, 4);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, 4);
+
+	/* Page 5 won't be migrated to device 0 because it's on device 1. */
+	buffer->ptr = p + 5 * self->page_size;
+	ret = hmm_dmirror_cmd(self->fd0, HMM_DMIRROR_MIGRATE, buffer, 1);
+	ASSERT_EQ(ret, -ENOENT);
+
+	/* Restore the base address so the whole reservation is unmapped. */
+	buffer->ptr = p;
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Migrate anonymous memory to device private memory and fault it back to system
+ * memory multiple times.
+ */
+TEST_F(hmm, migrate_multiple)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	unsigned long c;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	for (c = 0; c < NTIMES; c++) {	/* fresh buffer on every pass */
+		buffer = malloc(sizeof(*buffer));
+		ASSERT_NE(buffer, NULL);
+
+		buffer->fd = -1;	/* anonymous mapping: no backing file */
+		buffer->size = size;
+		buffer->mirror = malloc(size);
+		ASSERT_NE(buffer->mirror, NULL);
+
+		buffer->ptr = mmap(NULL, size,
+				   PROT_READ | PROT_WRITE,
+				   MAP_PRIVATE | MAP_ANONYMOUS,
+				   buffer->fd, 0);
+		ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+		/* Initialize buffer in system memory. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ptr[i] = i;
+
+		/* Migrate memory to device. */
+		ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_MIGRATE, buffer,
+				      npages);
+		ASSERT_EQ(ret, 0);
+		ASSERT_EQ(buffer->cpages, npages);
+
+		/* Check that the mirror holds the pattern written above. */
+		for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+			ASSERT_EQ(ptr[i], i);
+
+		/* Fault pages back to system memory and check them. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ASSERT_EQ(ptr[i], i);
+
+		hmm_buffer_free(buffer);
+	}
+}
+
+/*
+ * Read anonymous memory multiple times.
+ */
+TEST_F(hmm, anon_read_multiple)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	unsigned long c;
+	int *ptr;
+	int ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	for (c = 0; c < NTIMES; c++) {	/* fresh buffer on every pass */
+		buffer = malloc(sizeof(*buffer));
+		ASSERT_NE(buffer, NULL);
+
+		buffer->fd = -1;	/* anonymous mapping: no backing file */
+		buffer->size = size;
+		buffer->mirror = malloc(size);
+		ASSERT_NE(buffer->mirror, NULL);
+
+		buffer->ptr = mmap(NULL, size,
+				   PROT_READ | PROT_WRITE,
+				   MAP_PRIVATE | MAP_ANONYMOUS,
+				   buffer->fd, 0);
+		ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+		/* Initialize buffer with a pattern that differs on every pass. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ptr[i] = i + c;
+
+		/* Simulate a device reading system memory. */
+		ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer,
+				      npages);
+		ASSERT_EQ(ret, 0);
+		ASSERT_EQ(buffer->cpages, npages);
+		ASSERT_EQ(buffer->faults, 1);
+
+		/* Check what the device read. */
+		for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+			ASSERT_EQ(ptr[i], i + c);
+
+		hmm_buffer_free(buffer);
+	}
+}
+
+void *unmap_buffer(void *p)
+{
+	struct hmm_buffer *buffer = p;
+
+	/* Sleep up to 32 us, then unmap the second half of the buffer. */
+	hmm_nanosleep(hmm_random() % 32000);
+	munmap(buffer->ptr + buffer->size / 2, buffer->size / 2);
+	buffer->ptr = NULL;
+	/* NOTE(review): with ptr NULL, hmm_buffer_free() skips the first half. */
+	return NULL;
+}
+
+/*
+ * Try reading anonymous memory while it is being unmapped.
+ */
+TEST_F(hmm, anon_teardown)
+{
+	unsigned long npages;
+	unsigned long size;
+	unsigned long c;
+	void *ret;
+
+	npages = ALIGN(HMM_BUFFER_SIZE, self->page_size) >> self->page_shift;
+	ASSERT_NE(npages, 0);
+	size = npages << self->page_shift;
+
+	for (c = 0; c < NTIMES; ++c) {
+		pthread_t thread;
+		struct hmm_buffer *buffer;
+		unsigned long i;
+		int *ptr;
+		int rc;
+
+		buffer = malloc(sizeof(*buffer));
+		ASSERT_NE(buffer, NULL);
+
+		buffer->fd = -1;	/* anonymous mapping: no backing file */
+		buffer->size = size;
+		buffer->mirror = malloc(size);
+		ASSERT_NE(buffer->mirror, NULL);
+
+		buffer->ptr = mmap(NULL, size,
+				   PROT_READ | PROT_WRITE,
+				   MAP_PRIVATE | MAP_ANONYMOUS,
+				   buffer->fd, 0);
+		ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+		/* Initialize buffer with a pattern that differs on every pass. */
+		for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+			ptr[i] = i + c;
+
+		rc = pthread_create(&thread, NULL, unmap_buffer, buffer);
+		ASSERT_EQ(rc, 0);
+
+		/* Device read concurrent with unmap_buffer(); failure is tolerated. */
+		rc = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_READ, buffer,
+				     npages);
+		if (rc == 0) {
+			ASSERT_EQ(buffer->cpages, npages);
+			ASSERT_EQ(buffer->faults, 1);
+
+			/* Check what the device read. */
+			for (i = 0, ptr = buffer->mirror;
+			     i < size / sizeof(*ptr);
+			     ++i)
+				ASSERT_EQ(ptr[i], i + c);
+		}
+
+		pthread_join(thread, &ret);	/* ret is not inspected */
+		hmm_buffer_free(buffer);
+	}
+}
+
+/*
+ * Snapshot a MAP_PRIVATE mapping of the dmirror device file itself.
+ */
+TEST_F(hmm, mixedmap)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned char *m;
+	int ret;
+
+	npages = 1;
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;	/* the mapping below uses self->fd instead */
+	buffer->size = size;
+	buffer->mirror = malloc(npages);	/* NOTE(review): one byte per page? */
+	ASSERT_NE(buffer->mirror, NULL);
+
+
+	/* Map one page of the dmirror device file (self->fd) privately. */
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE,
+			   self->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Simulate a device snapshotting CPU pagetables. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_SNAPSHOT, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* The first mirror byte is expected to be HMM_DMIRROR_PROT_READ. */
+	m = buffer->mirror;
+	ASSERT_EQ(m[0], HMM_DMIRROR_PROT_READ);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Test memory snapshot without faulting in pages accessed by the device.
+ */
+TEST_F(hmm2, snapshot)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	int *ptr;
+	unsigned char *p;
+	unsigned char *m;
+	int ret;
+	int val;
+
+	npages = 7;
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->fd = -1;
+	buffer->size = size;
+	buffer->mirror = malloc(npages);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Reserve a PROT_NONE range of addresses; page 0 is left like this. */
+	buffer->ptr = mmap(NULL, size,
+			   PROT_NONE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+	p = buffer->ptr;	/* keep the base; buffer->ptr is repointed below */
+
+	/* Punch a hole after the first page address (page 1 is unmapped). */
+	ret = munmap(buffer->ptr + self->page_size, self->page_size);
+	ASSERT_EQ(ret, 0);
+
+	/* Page 2 will be read-only zero page (it is read but never written). */
+	ret = mprotect(buffer->ptr + 2 * self->page_size, self->page_size,
+				PROT_READ);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 2 * self->page_size);
+	val = *ptr + 3;
+	ASSERT_EQ(val, 3);
+
+	/* Page 3 will be read-only, but only after one write populated it. */
+	ret = mprotect(buffer->ptr + 3 * self->page_size, self->page_size,
+				PROT_READ | PROT_WRITE);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 3 * self->page_size);
+	*ptr = val;
+	ret = mprotect(buffer->ptr + 3 * self->page_size, self->page_size,
+				PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Page 4-6 will be read-write; the CPU only writes to page 4 here. */
+	ret = mprotect(buffer->ptr + 4 * self->page_size, 3 * self->page_size,
+				PROT_READ | PROT_WRITE);
+	ASSERT_EQ(ret, 0);
+	ptr = (int *)(buffer->ptr + 4 * self->page_size);
+	*ptr = val;
+
+	/* Page 5 will be migrated to device 0. */
+	buffer->ptr = p + 5 * self->page_size;
+	ret = hmm_dmirror_cmd(self->fd0, HMM_DMIRROR_MIGRATE, buffer, 1);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, 1);
+
+	/* Page 6 will be migrated to device 1. */
+	buffer->ptr = p + 6 * self->page_size;
+	ret = hmm_dmirror_cmd(self->fd1, HMM_DMIRROR_MIGRATE, buffer, 1);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, 1);
+
+	/* Simulate device 0 snapshotting CPU pagetables for the whole range. */
+	buffer->ptr = p;
+	ret = hmm_dmirror_cmd(self->fd0, HMM_DMIRROR_SNAPSHOT, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* Check what device 0 saw: one status byte per page, pages 0 to 6. */
+	m = buffer->mirror;
+	ASSERT_EQ(m[0], HMM_DMIRROR_PROT_ERROR);
+	ASSERT_EQ(m[1], HMM_DMIRROR_PROT_ERROR);
+	ASSERT_EQ(m[2], HMM_DMIRROR_PROT_ZERO | HMM_DMIRROR_PROT_READ);
+	ASSERT_EQ(m[3], HMM_DMIRROR_PROT_READ);
+	ASSERT_EQ(m[4], HMM_DMIRROR_PROT_WRITE);
+	ASSERT_EQ(m[5], HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL |
+			HMM_DMIRROR_PROT_WRITE);
+	ASSERT_EQ(m[6], HMM_DMIRROR_PROT_NONE);
+
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Test the hmm_range_fault() HMM_PFN_PMD flag for large pages that
+ * should be mapped by a large page table entry.
+ */
+TEST_F(hmm, compound)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	int *ptr;
+	unsigned char *m;
+	int ret;
+	long pagesizes[4];
+	int n, idx;
+	unsigned long i;
+
+	/* Skip test if we can't allocate a hugetlbfs page. */
+
+	n = gethugepagesizes(pagesizes, 4);
+	if (n <= 0)
+		return;
+	for (idx = 0; --n > 0; ) {	/* idx ends up at the smallest size */
+		if (pagesizes[n] < pagesizes[idx])
+			idx = n;
+	}
+	size = ALIGN(TWOMEG, pagesizes[idx]);
+	npages = size >> self->page_shift;	/* counted in base pages */
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+
+	buffer->ptr = get_hugepage_region(size, GHR_STRICT);
+	if (buffer->ptr == NULL) {	/* no huge page available: skip quietly */
+		free(buffer);
+		return;
+	}
+
+	buffer->size = size;
+	buffer->mirror = malloc(npages);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Initialize the pages the device will snapshot in buffer->ptr. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Simulate a device snapshotting CPU pagetables. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_SNAPSHOT, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* Every base page must be reported as writable and PMD-mapped. */
+	m = buffer->mirror;
+	for (i = 0; i < npages; ++i)
+		ASSERT_EQ(m[i], HMM_DMIRROR_PROT_WRITE |
+				HMM_DMIRROR_PROT_PMD);
+
+	/* Make the region read-only. */
+	ret = mprotect(buffer->ptr, size, PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Simulate a device snapshotting CPU pagetables. */
+	ret = hmm_dmirror_cmd(self->fd, HMM_DMIRROR_SNAPSHOT, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+
+	/* The same pages must now be reported read-only, still PMD-mapped. */
+	m = buffer->mirror;
+	for (i = 0; i < npages; ++i)
+		ASSERT_EQ(m[i], HMM_DMIRROR_PROT_READ |
+				HMM_DMIRROR_PROT_PMD);
+
+	free_hugepage_region(buffer->ptr);
+	buffer->ptr = NULL;	/* already released just above */
+	hmm_buffer_free(buffer);
+}
+
+/*
+ * Test two devices reading the same memory (double mapped).
+ */
+TEST_F(hmm2, double_map)
+{
+	struct hmm_buffer *buffer;
+	unsigned long npages;
+	unsigned long size;
+	unsigned long i;
+	int *ptr;
+	int ret;
+
+	npages = 6;
+	size = npages << self->page_shift;
+
+	buffer = malloc(sizeof(*buffer));
+	ASSERT_NE(buffer, NULL);
+	/* The mirror is compared as page data below, so it needs size bytes. */
+	buffer->fd = -1;
+	buffer->size = size;
+	buffer->mirror = malloc(size);
+	ASSERT_NE(buffer->mirror, NULL);
+
+	/* Reserve a range of addresses. */
+	buffer->ptr = mmap(NULL, size,
+			   PROT_READ | PROT_WRITE,
+			   MAP_PRIVATE | MAP_ANONYMOUS,
+			   buffer->fd, 0);
+	ASSERT_NE(buffer->ptr, MAP_FAILED);
+
+	/* Initialize buffer in system memory. */
+	for (i = 0, ptr = buffer->ptr; i < size / sizeof(*ptr); ++i)
+		ptr[i] = i;
+
+	/* Make region read-only. */
+	ret = mprotect(buffer->ptr, size, PROT_READ);
+	ASSERT_EQ(ret, 0);
+
+	/* Simulate device 0 reading system memory. */
+	ret = hmm_dmirror_cmd(self->fd0, HMM_DMIRROR_READ, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device read. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* Simulate device 1 reading system memory. */
+	ret = hmm_dmirror_cmd(self->fd1, HMM_DMIRROR_READ, buffer, npages);
+	ASSERT_EQ(ret, 0);
+	ASSERT_EQ(buffer->cpages, npages);
+	ASSERT_EQ(buffer->faults, 1);
+
+	/* Check what the device read. */
+	for (i = 0, ptr = buffer->mirror; i < size / sizeof(*ptr); ++i)
+		ASSERT_EQ(ptr[i], i);
+
+	/* Punch a hole after the first page address. */
+	ret = munmap(buffer->ptr + self->page_size, self->page_size);
+	ASSERT_EQ(ret, 0);
+
+	hmm_buffer_free(buffer);
+}
+
+TEST_HARNESS_MAIN
diff --git a/tools/testing/selftests/vm/hugetlb_reparenting_test.sh b/tools/testing/selftests/vm/hugetlb_reparenting_test.sh
new file mode 100644
index 0000000..d11d1fe
--- /dev/null
+++ b/tools/testing/selftests/vm/hugetlb_reparenting_test.sh
@@ -0,0 +1,244 @@
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+
+set -e
+
+if [[ $(id -u) -ne 0 ]]; then
+  echo "This test must be run as root. Skipping..."
+  exit 0 # a skipped run exits with status 0
+fi
+
+usage_file=usage_in_bytes # suffix of the usage files read by assert_state
+
+if [[ "$1" == "-cgroup-v2" ]]; then
+  cgroup2=1
+  usage_file=current # the usage files are named *.current in this mode
+fi
+
+CGROUP_ROOT='/dev/cgroup/memory'
+MNT='/mnt/huge/' # hugetlbfs mount point; note the trailing slash
+
+if [[ ! -e $CGROUP_ROOT ]]; then # mount only if the path does not exist yet
+  mkdir -p $CGROUP_ROOT
+  if [[ $cgroup2 ]]; then
+    mount -t cgroup2 none $CGROUP_ROOT
+    sleep 1
+    echo "+hugetlb +memory" >$CGROUP_ROOT/cgroup.subtree_control
+  else
+    mount -t cgroup memory,hugetlb $CGROUP_ROOT
+  fi
+fi
+
+function get_machine_hugepage_size() { # prints the hugepage size in MB
+  hpz=$(grep -i hugepagesize /proc/meminfo)
+  kb=${hpz:14:-3} # drop the first 14 characters (label) and the last 3 (" kB")
+  mb=$(($kb / 1024))
+  echo $mb
+}
+
+MB=$(get_machine_hugepage_size)
+
+function cleanup() { # best-effort teardown of everything setup() creates
+  echo cleanup
+  set +e # the steps below may fail when there is nothing to clean up
+  rm -rf "$MNT"/* 2>/dev/null
+  umount "$MNT" 2>/dev/null
+  rmdir "$MNT" 2>/dev/null
+  rmdir "$CGROUP_ROOT"/a/b 2>/dev/null # child before parent
+  rmdir "$CGROUP_ROOT"/a 2>/dev/null
+  rmdir "$CGROUP_ROOT"/test1 2>/dev/null
+  echo 0 >/proc/sys/vm/nr_hugepages
+  set -e # back to abort-on-error for the caller
+}
+
+function assert_state() { # args: a_memory a_hugetlb [b_memory b_hugetlb], in bytes
+  local expected_a="$1"
+  local expected_a_hugetlb="$2"
+  local expected_b=""
+  local expected_b_hugetlb=""
+
+  if [ ! -z ${3:-} ] && [ ! -z ${4:-} ]; then # child values are optional
+    expected_b="$3"
+    expected_b_hugetlb="$4"
+  fi
+  local tolerance=$((5 * 1024 * 1024)) # allow 5 MB of slack either way
+
+  local actual_a
+  actual_a="$(cat "$CGROUP_ROOT"/a/memory.$usage_file)"
+  if [[ $actual_a -lt $(($expected_a - $tolerance)) ]] ||
+    [[ $actual_a -gt $(($expected_a + $tolerance)) ]]; then
+    echo actual a = $((${actual_a%% *} / 1024 / 1024)) MB
+    echo expected a = $((${expected_a%% *} / 1024 / 1024)) MB
+    echo fail
+
+    cleanup
+    exit 1
+  fi
+
+  local actual_a_hugetlb # hugetlb counter of cgroup a for the machine's page size
+  actual_a_hugetlb="$(cat "$CGROUP_ROOT"/a/hugetlb.${MB}MB.$usage_file)"
+  if [[ $actual_a_hugetlb -lt $(($expected_a_hugetlb - $tolerance)) ]] ||
+    [[ $actual_a_hugetlb -gt $(($expected_a_hugetlb + $tolerance)) ]]; then
+    echo actual a hugetlb = $((${actual_a_hugetlb%% *} / 1024 / 1024)) MB
+    echo expected a hugetlb = $((${expected_a_hugetlb%% *} / 1024 / 1024)) MB
+    echo fail
+
+    cleanup
+    exit 1
+  fi
+
+  if [[ -z "$expected_b" || -z "$expected_b_hugetlb" ]]; then
+    return # only cgroup a was requested
+  fi
+
+  local actual_b # same two checks again for the child cgroup a/b
+  actual_b="$(cat "$CGROUP_ROOT"/a/b/memory.$usage_file)"
+  if [[ $actual_b -lt $(($expected_b - $tolerance)) ]] ||
+    [[ $actual_b -gt $(($expected_b + $tolerance)) ]]; then
+    echo actual b = $((${actual_b%% *} / 1024 / 1024)) MB
+    echo expected b = $((${expected_b%% *} / 1024 / 1024)) MB
+    echo fail
+
+    cleanup
+    exit 1
+  fi
+
+  local actual_b_hugetlb
+  actual_b_hugetlb="$(cat "$CGROUP_ROOT"/a/b/hugetlb.${MB}MB.$usage_file)"
+  if [[ $actual_b_hugetlb -lt $(($expected_b_hugetlb - $tolerance)) ]] ||
+    [[ $actual_b_hugetlb -gt $(($expected_b_hugetlb + $tolerance)) ]]; then
+    echo actual b hugetlb = $((${actual_b_hugetlb%% *} / 1024 / 1024)) MB
+    echo expected b hugetlb = $((${expected_b_hugetlb%% *} / 1024 / 1024)) MB
+    echo fail
+
+    cleanup
+    exit 1
+  fi
+}
+
+function setup() { # creates cgroups a and a/b and mounts hugetlbfs on $MNT
+  echo 100 >/proc/sys/vm/nr_hugepages
+  mkdir "$CGROUP_ROOT"/a
+  sleep 1
+  if [[ $cgroup2 ]]; then
+    echo "+hugetlb +memory" >$CGROUP_ROOT/a/cgroup.subtree_control
+  else
+    echo 0 >$CGROUP_ROOT/a/cpuset.mems
+    echo 0 >$CGROUP_ROOT/a/cpuset.cpus
+  fi
+
+  mkdir "$CGROUP_ROOT"/a/b
+
+  if [[ ! $cgroup2 ]]; then # the child gets the same cpuset values as a
+    echo 0 >$CGROUP_ROOT/a/b/cpuset.mems
+    echo 0 >$CGROUP_ROOT/a/b/cpuset.cpus
+  fi
+
+  mkdir -p "$MNT"
+  mount -t hugetlbfs none "$MNT"
+}
+
+write_hugetlbfs() { # args: cgroup (relative to $CGROUP_ROOT), file path, size
+  local cgroup="$1"
+  local path="$2"
+  local size="$3"
+
+  if [[ $cgroup2 ]]; then # move this shell ($$) into the target cgroup
+    echo $$ >$CGROUP_ROOT/$cgroup/cgroup.procs
+  else
+    echo 0 >$CGROUP_ROOT/$cgroup/cpuset.mems
+    echo 0 >$CGROUP_ROOT/$cgroup/cpuset.cpus
+    echo $$ >"$CGROUP_ROOT/$cgroup/tasks"
+  fi
+  ./write_to_hugetlbfs -p "$path" -s "$size" -m 0 -o
+  if [[ $cgroup2 ]]; then # then move the shell back to the root cgroup
+    echo $$ >$CGROUP_ROOT/cgroup.procs
+  else
+    echo $$ >"$CGROUP_ROOT/tasks"
+  fi
+  echo
+}
+
+set -e
+
+size=$((${MB} * 1024 * 1024 * 25)) # 25 hugepages, i.e. 50MB with 2MB hugepages.
+
+cleanup
+
+echo
+echo
+echo Test charge, rmdir, uncharge
+setup
+echo mkdir
+mkdir $CGROUP_ROOT/test1
+
+echo write
+write_hugetlbfs test1 "$MNT"/test $size
+
+echo rmdir
+rmdir $CGROUP_ROOT/test1
+mkdir $CGROUP_ROOT/test1
+
+echo uncharge
+rm -rf "$MNT"/* # same mount point the file was written to, not a literal path
+
+cleanup
+
+echo done
+echo
+echo
+if [[ ! $cgroup2 ]]; then # this scenario is skipped with -cgroup-v2
+  echo "Test parent and child hugetlb usage"
+  setup
+
+  echo write
+  write_hugetlbfs a "$MNT"/test $size
+
+  echo Assert memory charged correctly for parent use.
+  assert_state 0 $size 0 0
+
+  write_hugetlbfs a/b "$MNT"/test2 $size
+
+  echo Assert memory charged correctly for child use.
+  assert_state 0 $(($size * 2)) 0 $size
+
+  rmdir "$CGROUP_ROOT"/a/b
+  sleep 5
+  echo Assert memory reparent correctly.
+  assert_state 0 $(($size * 2)) # a/b is gone, so only a is checked
+
+  rm -rf "$MNT"/*
+  umount "$MNT"
+  echo Assert memory uncharged correctly.
+  assert_state 0 0
+
+  cleanup
+fi
+
+echo
+echo
+echo "Test child only hugetlb usage"
+echo setup
+setup
+
+echo write
+write_hugetlbfs a/b "$MNT"/test2 $size
+
+echo Assert memory charged correctly for child only use.
+assert_state 0 $(($size)) 0 $size
+
+rmdir "$CGROUP_ROOT"/a/b # NOTE(review): no sleep here, unlike the parent/child test
+echo Assert memory reparent correctly.
+assert_state 0 $size
+
+rm -rf "$MNT"/*
+umount "$MNT"
+echo Assert memory uncharged correctly.
+assert_state 0 0
+
+cleanup
+
+echo ALL PASS
+
+umount $CGROUP_ROOT # NOTE(review): runs even if this script did not mount it
+rm -rf $CGROUP_ROOT
diff --git a/tools/testing/selftests/vm/khugepaged.c b/tools/testing/selftests/vm/khugepaged.c
new file mode 100644
index 0000000..8b75821
--- /dev/null
+++ b/tools/testing/selftests/vm/khugepaged.c
@@ -0,0 +1,1035 @@
+#define _GNU_SOURCE
+#include <fcntl.h>
+#include <limits.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <sys/mman.h>
+#include <sys/wait.h>
+
+#ifndef MADV_PAGEOUT
+#define MADV_PAGEOUT 21
+#endif
+
+#define BASE_ADDR ((void *)(1UL << 30))
+static unsigned long hpage_pmd_size;
+static unsigned long page_size;
+static int hpage_pmd_nr;
+
+#define THP_SYSFS "/sys/kernel/mm/transparent_hugepage/"
+#define PID_SMAPS "/proc/self/smaps"
+
+enum thp_enabled {
+	THP_ALWAYS,
+	THP_MADVISE,
+	THP_NEVER,
+};
+/* Strings for the "enabled" file, indexed by enum thp_enabled. */
+static const char *thp_enabled_strings[] = {
+	"always",
+	"madvise",
+	"never",
+	NULL
+};
+
+enum thp_defrag {
+	THP_DEFRAG_ALWAYS,
+	THP_DEFRAG_DEFER,
+	THP_DEFRAG_DEFER_MADVISE,
+	THP_DEFRAG_MADVISE,
+	THP_DEFRAG_NEVER,
+};
+/* Strings for the "defrag" file, indexed by enum thp_defrag. */
+static const char *thp_defrag_strings[] = {
+	"always",
+	"defer",
+	"defer+madvise",
+	"madvise",
+	"never",
+	NULL
+};
+
+enum shmem_enabled {
+	SHMEM_ALWAYS,
+	SHMEM_WITHIN_SIZE,
+	SHMEM_ADVISE,
+	SHMEM_NEVER,
+	SHMEM_DENY,
+	SHMEM_FORCE,
+};
+/* Strings for the "shmem_enabled" file, indexed by enum shmem_enabled. */
+static const char *shmem_enabled_strings[] = {
+	"always",
+	"within_size",
+	"advise",
+	"never",
+	"deny",
+	"force",
+	NULL
+};
+
+struct khugepaged_settings {	/* one field per file under khugepaged/ */
+	bool defrag;
+	unsigned int alloc_sleep_millisecs;
+	unsigned int scan_sleep_millisecs;
+	unsigned int max_ptes_none;
+	unsigned int max_ptes_swap;
+	unsigned int max_ptes_shared;
+	unsigned long pages_to_scan;
+};
+/* Everything under THP_SYSFS that write_settings() writes out. */
+struct settings {
+	enum thp_enabled thp_enabled;
+	enum thp_defrag thp_defrag;
+	enum shmem_enabled shmem_enabled;
+	bool debug_cow;
+	bool use_zero_page;
+	struct khugepaged_settings khugepaged;
+};
+/* Baseline used by the tests; members not listed here start out as zero. */
+static struct settings default_settings = {
+	.thp_enabled = THP_MADVISE,
+	.thp_defrag = THP_DEFRAG_ALWAYS,
+	.shmem_enabled = SHMEM_NEVER,
+	.debug_cow = 0,
+	.use_zero_page = 0,
+	.khugepaged = {
+		.defrag = 1,
+		.alloc_sleep_millisecs = 10,
+		.scan_sleep_millisecs = 10,
+	},
+};
+
+static struct settings saved_settings;
+static bool skip_settings_restore;
+
+static int exit_status;
+
+static void success(const char *msg)
+{
+	printf(" \e[32m%s\e[0m\n", msg);	/* msg in green, then reset */
+}
+
+static void fail(const char *msg)
+{
+	printf(" \e[31m%s\e[0m\n", msg);	/* msg in red, then reset */
+	exit_status++;	/* number of failed checks, used as the exit code */
+}
+
+static int read_file(const char *path, char *buf, size_t buflen)
+{
+	ssize_t nbytes;
+	int fd = open(path, O_RDONLY);
+
+	if (fd == -1)
+		return 0;
+
+	/* Leave room for the terminating NUL added below. */
+	nbytes = read(fd, buf, buflen - 1);
+	close(fd);
+
+	/* A read error and an empty file are both reported as 0. */
+	if (nbytes < 1)
+		return 0;
+
+	buf[nbytes] = '\0';
+
+	return (unsigned int) nbytes;
+}
+
+static int write_file(const char *path, const char *buf, size_t buflen)
+{
+	ssize_t nbytes;
+	int fd;
+
+	fd = open(path, O_WRONLY);
+	if (fd == -1)
+		return 0;
+
+	/* The last byte of buf (the callers' trailing NUL) is not written. */
+	nbytes = write(fd, buf, buflen - 1);
+	close(fd);
+
+	/* A failed or empty write is reported as 0, like a failed open. */
+	return nbytes < 1 ? 0 : (unsigned int) nbytes;
+}
+
+static int read_string(const char *name, const char *strings[])
+{
+	char path[PATH_MAX];
+	char buf[256];
+	char *c;
+	int ret;
+
+	ret = snprintf(path, PATH_MAX, THP_SYSFS "%s", name);
+	if (ret >= PATH_MAX) {
+		printf("%s: Pathname is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+
+	if (!read_file(path, buf, sizeof(buf))) {
+		perror(path);
+		exit(EXIT_FAILURE);
+	}
+
+	c = strchr(buf, '[');	/* the value of interest is the [bracketed] one */
+	if (!c) {
+		printf("%s: Parse failure\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+
+	c++;
+	memmove(buf, c, sizeof(buf) - (c - buf));	/* shift it to the front */
+
+	c = strchr(buf, ']');
+	if (!c) {
+		printf("%s: Parse failure\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	*c = '\0';
+
+	ret = 0;
+	while (strings[ret]) {	/* strings[] is NULL-terminated */
+		if (!strcmp(strings[ret], buf))
+			return ret;	/* index of the matching string */
+		ret++;
+	}
+
+	printf("Failed to parse %s\n", name);
+	exit(EXIT_FAILURE);
+}
+
+static void write_string(const char *name, const char *val)
+{
+	char path[PATH_MAX];
+	int ret;
+
+	ret = snprintf(path, PATH_MAX, THP_SYSFS "%s", name);
+	if (ret >= PATH_MAX) {
+		printf("%s: Pathname is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	/* write_file() writes buflen - 1 bytes, hence the + 1 below. */
+	if (!write_file(path, val, strlen(val) + 1)) {
+		perror(path);
+		exit(EXIT_FAILURE);
+	}
+}
+
+static const unsigned long read_num(const char *name)
+{
+	char path[PATH_MAX];
+	char buf[21];
+	int ret;
+
+	ret = snprintf(path, PATH_MAX, THP_SYSFS "%s", name);
+	if (ret >= PATH_MAX) {
+		printf("%s: Pathname is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	/* read_file() returns 0 (never a negative value) when it fails. */
+	ret = read_file(path, buf, sizeof(buf));
+	if (!ret) {
+		perror("read_file(read_num)");
+		exit(EXIT_FAILURE);
+	}
+
+	return strtoul(buf, NULL, 10);
+}
+
+static void write_num(const char *name, unsigned long num)
+{
+	char path[PATH_MAX];
+	char buf[21];
+	int ret;
+
+	ret = snprintf(path, PATH_MAX, THP_SYSFS "%s", name);
+	if (ret >= PATH_MAX) {
+		printf("%s: Pathname is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	/* num is unsigned, so it must be formatted with %lu rather than %ld. */
+	sprintf(buf, "%lu", num);
+	if (!write_file(path, buf, strlen(buf) + 1)) {
+		perror(path);
+		exit(EXIT_FAILURE);
+	}
+}
+
+static void write_settings(struct settings *settings)
+{
+	struct khugepaged_settings *khugepaged = &settings->khugepaged;
+	/* Enum members are written as strings, everything else as numbers. */
+	write_string("enabled", thp_enabled_strings[settings->thp_enabled]);
+	write_string("defrag", thp_defrag_strings[settings->thp_defrag]);
+	write_string("shmem_enabled",
+			shmem_enabled_strings[settings->shmem_enabled]);
+	write_num("debug_cow", settings->debug_cow);
+	write_num("use_zero_page", settings->use_zero_page);
+
+	write_num("khugepaged/defrag", khugepaged->defrag);
+	write_num("khugepaged/alloc_sleep_millisecs",
+			khugepaged->alloc_sleep_millisecs);
+	write_num("khugepaged/scan_sleep_millisecs",
+			khugepaged->scan_sleep_millisecs);
+	write_num("khugepaged/max_ptes_none", khugepaged->max_ptes_none);
+	write_num("khugepaged/max_ptes_swap", khugepaged->max_ptes_swap);
+	write_num("khugepaged/max_ptes_shared", khugepaged->max_ptes_shared);
+	write_num("khugepaged/pages_to_scan", khugepaged->pages_to_scan);
+}
+
+static void restore_settings(int sig)
+{
+	if (skip_settings_restore)	/* set in forked children */
+		goto out;
+
+	printf("Restore THP and khugepaged settings...");
+	write_settings(&saved_settings);
+	success("OK");
+	if (sig)	/* non-zero sig: fail regardless of exit_status */
+		exit(EXIT_FAILURE);
+out:
+	exit(exit_status);	/* this function never returns */
+}
+
+static void save_settings(void)
+{
+	printf("Save THP and khugepaged settings...");
+	saved_settings = (struct settings) {
+		.thp_enabled = read_string("enabled", thp_enabled_strings),
+		.thp_defrag = read_string("defrag", thp_defrag_strings),
+		.shmem_enabled =
+			read_string("shmem_enabled", shmem_enabled_strings),
+		.debug_cow = read_num("debug_cow"),
+		.use_zero_page = read_num("use_zero_page"),
+	};
+	saved_settings.khugepaged = (struct khugepaged_settings) {
+		.defrag = read_num("khugepaged/defrag"),
+		.alloc_sleep_millisecs =
+			read_num("khugepaged/alloc_sleep_millisecs"),
+		.scan_sleep_millisecs =
+			read_num("khugepaged/scan_sleep_millisecs"),
+		.max_ptes_none = read_num("khugepaged/max_ptes_none"),
+		.max_ptes_swap = read_num("khugepaged/max_ptes_swap"),
+		.max_ptes_shared = read_num("khugepaged/max_ptes_shared"),
+		.pages_to_scan = read_num("khugepaged/pages_to_scan"),
+	};
+	success("OK");
+	/* From here on these signals write the saved values back and exit. */
+	signal(SIGTERM, restore_settings);
+	signal(SIGINT, restore_settings);
+	signal(SIGHUP, restore_settings);
+	signal(SIGQUIT, restore_settings);
+}
+
+static void adjust_settings(void)
+{
+	/* Write default_settings to sysfs and report the step as done. */
+	printf("Adjust settings...");
+	write_settings(&default_settings);
+	success("OK");
+}
+
+#define MAX_LINE_LENGTH 500
+
+static bool check_for_pattern(FILE *fp, char *pattern, char *buf)
+{
+	bool found = false;
+
+	while (!found && fgets(buf, MAX_LINE_LENGTH, fp) != NULL)
+		found = !strncmp(buf, pattern, strlen(pattern));
+	return found;	/* on a match buf holds the line that matched */
+}
+
+static bool check_huge(void *addr)
+{
+	bool thp = false;
+	int ret;
+	FILE *fp;
+	char buffer[MAX_LINE_LENGTH];
+	char addr_pattern[MAX_LINE_LENGTH];
+
+	ret = snprintf(addr_pattern, MAX_LINE_LENGTH, "%08lx-",
+		       (unsigned long) addr);
+	if (ret >= MAX_LINE_LENGTH) {
+		printf("%s: Pattern is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+
+	/* Find the smaps line that begins with addr as a start address. */
+	fp = fopen(PID_SMAPS, "r");
+	if (!fp) {
+		printf("%s: Failed to open file %s\n", __func__, PID_SMAPS);
+		exit(EXIT_FAILURE);
+	}
+	if (!check_for_pattern(fp, addr_pattern, buffer))
+		goto err_out;
+
+	ret = snprintf(addr_pattern, MAX_LINE_LENGTH, "AnonHugePages:%10lu kB",
+		       hpage_pmd_size >> 10);
+	if (ret >= MAX_LINE_LENGTH) {
+		printf("%s: Pattern is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	/*
+	 * Fetch the AnonHugePages: in the same block and check whether it
+	 * reports exactly one PMD-sized huge page (hpage_pmd_size, in kB).
+	 */
+	if (!check_for_pattern(fp, "AnonHugePages:", buffer))
+		goto err_out;
+
+	if (strncmp(buffer, addr_pattern, strlen(addr_pattern)))
+		goto err_out;
+
+	thp = true;
+err_out:
+	fclose(fp);
+	return thp;
+}
+
+
+static bool check_swap(void *addr, unsigned long size)
+{
+	bool swap = false;
+	int ret;
+	FILE *fp;
+	char buffer[MAX_LINE_LENGTH];
+	char addr_pattern[MAX_LINE_LENGTH];
+
+	ret = snprintf(addr_pattern, MAX_LINE_LENGTH, "%08lx-",
+		       (unsigned long) addr);
+	if (ret >= MAX_LINE_LENGTH) {
+		printf("%s: Pattern is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+
+	/* Find the smaps line that begins with addr as a start address. */
+	fp = fopen(PID_SMAPS, "r");
+	if (!fp) {
+		printf("%s: Failed to open file %s\n", __func__, PID_SMAPS);
+		exit(EXIT_FAILURE);
+	}
+	if (!check_for_pattern(fp, addr_pattern, buffer))
+		goto err_out;
+
+	ret = snprintf(addr_pattern, MAX_LINE_LENGTH, "Swap:%19lu kB",
+		       size >> 10);
+	if (ret >= MAX_LINE_LENGTH) {
+		printf("%s: Pattern is too long\n", __func__);
+		exit(EXIT_FAILURE);
+	}
+	/*
+	 * Fetch the Swap: in the same block and check whether it reports
+	 * exactly the expected amount of swapped-out memory (size, in kB).
+	 */
+	if (!check_for_pattern(fp, "Swap:", buffer))
+		goto err_out;
+
+	if (strncmp(buffer, addr_pattern, strlen(addr_pattern)))
+		goto err_out;
+
+	swap = true;
+err_out:
+	fclose(fp);
+	return swap;
+}
+
+static void *alloc_mapping(void)
+{
+	const int prot = PROT_READ | PROT_WRITE;
+	const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
+	void *vma = mmap(BASE_ADDR, hpage_pmd_size, prot, flags, -1, 0);
+
+	/* Any result other than BASE_ADDR itself is treated as fatal. */
+	if (vma == BASE_ADDR)
+		return vma;
+
+	printf("Failed to allocate VMA at %p\n", BASE_ADDR);
+	exit(EXIT_FAILURE);
+}
+
+static void fill_memory(int *p, unsigned long start, unsigned long end)
+{
+	int i;	/* page index; start and end are byte offsets from p */
+
+	for (i = start / page_size; i < end / page_size; i++)
+		p[i * page_size / sizeof(*p)] = i + 0xdead0000;	/* first int of page i */
+}
+
+static void validate_memory(int *p, unsigned long start, unsigned long end)
+{
+	int pg, seen;
+
+	for (pg = start / page_size; pg < end / page_size; pg++) {
+		seen = p[pg * page_size / sizeof(*p)];
+		if (seen == pg + 0xdead0000)
+			continue;
+		printf("Page %d is corrupted: %#x\n", pg, seen);
+		exit(EXIT_FAILURE);
+	}
+}
+
+#define TICK 500000
+static bool wait_for_scan(const char *msg, char *p)
+{
+	int full_scans;
+	int timeout = 6; /* 6 polls, TICK microseconds apart: 3 seconds */
+
+	/* Sanity check */
+	if (check_huge(p)) {
+		printf("Unexpected huge page\n");
+		exit(EXIT_FAILURE);
+	}
+
+	madvise(p, hpage_pmd_size, MADV_HUGEPAGE);
+
+	/* Wait until the second full_scan completed */
+	full_scans = read_num("khugepaged/full_scans") + 2;
+
+	printf("%s...", msg);
+	while (timeout--) {
+		if (check_huge(p))
+			break;
+		if (read_num("khugepaged/full_scans") >= full_scans)
+			break;
+		printf(".");
+		usleep(TICK);
+	}
+
+	madvise(p, hpage_pmd_size, MADV_NOHUGEPAGE);
+
+	return timeout == -1;	/* true only if no break ended the loop */
+}
+
+static void alloc_at_fault(void)
+{
+	struct settings settings = default_settings;
+	char *p;
+
+	settings.thp_enabled = THP_ALWAYS;	/* only for the first check */
+	write_settings(&settings);
+
+	p = alloc_mapping();
+	*p = 1;	/* a single write to the fresh mapping */
+	printf("Allocate huge page on fault...");
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	write_settings(&default_settings);
+
+	madvise(p, page_size, MADV_DONTNEED);	/* first small page only */
+	printf("Split huge PMD on MADV_DONTNEED...");
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_full(void)
+{
+	void *p;
+
+	p = alloc_mapping();
+	fill_memory(p, 0, hpage_pmd_size);	/* write to every small page */
+	if (wait_for_scan("Collapse fully populated PTE table", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);	/* markers must be intact */
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_empty(void)
+{
+	void *p;
+
+	p = alloc_mapping();	/* mapped, but no page is ever written */
+	if (wait_for_scan("Do not collapse empty PTE table", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		fail("Fail");	/* a huge page here is the failure case */
+	else
+		success("OK");
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_single_pte_entry(void)
+{
+	void *p;
+
+	p = alloc_mapping();
+	fill_memory(p, 0, page_size);	/* write to the first small page only */
+	if (wait_for_scan("Collapse PTE table with single PTE entry present", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, page_size);
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_max_ptes_none(void)
+{
+	int max_ptes_none = hpage_pmd_nr / 2;
+	struct settings settings = default_settings;
+	void *p;
+
+	settings.khugepaged.max_ptes_none = max_ptes_none;
+	write_settings(&settings);
+
+	p = alloc_mapping();
+	/* First leave max_ptes_none + 1 pages unwritten: expect no collapse. */
+	fill_memory(p, 0, (hpage_pmd_nr - max_ptes_none - 1) * page_size);
+	if (wait_for_scan("Do not collapse with max_ptes_none exceeded", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		fail("Fail");
+	else
+		success("OK");
+	validate_memory(p, 0, (hpage_pmd_nr - max_ptes_none - 1) * page_size);
+	/* Then leave exactly max_ptes_none pages unwritten: expect collapse. */
+	fill_memory(p, 0, (hpage_pmd_nr - max_ptes_none) * page_size);
+	if (wait_for_scan("Collapse with max_ptes_none PTEs empty", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, (hpage_pmd_nr - max_ptes_none) * page_size);
+
+	munmap(p, hpage_pmd_size);
+	write_settings(&default_settings);
+}
+
+static void collapse_swapin_single_pte(void)
+{
+	void *p;
+	p = alloc_mapping();
+	fill_memory(p, 0, hpage_pmd_size);
+
+	printf("Swapout one page...");
+	if (madvise(p, page_size, MADV_PAGEOUT)) {
+		perror("madvise(MADV_PAGEOUT)");
+		exit(EXIT_FAILURE);
+	}
+	if (check_swap(p, page_size)) {
+		success("OK");
+	} else {
+		fail("Fail");
+		goto out;	/* nothing was swapped out, so skip the collapse check */
+	}
+
+	if (wait_for_scan("Collapse with swapping in single PTE entry", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);
+out:
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_max_ptes_swap(void)
+{
+	int max_ptes_swap = read_num("khugepaged/max_ptes_swap");
+	void *p;
+
+	p = alloc_mapping();
+	/* First page out one page more than the limit: expect no collapse. */
+	fill_memory(p, 0, hpage_pmd_size);
+	printf("Swapout %d of %d pages...", max_ptes_swap + 1, hpage_pmd_nr);
+	if (madvise(p, (max_ptes_swap + 1) * page_size, MADV_PAGEOUT)) {
+		perror("madvise(MADV_PAGEOUT)");
+		exit(EXIT_FAILURE);
+	}
+	if (check_swap(p, (max_ptes_swap + 1) * page_size)) {
+		success("OK");
+	} else {
+		fail("Fail");
+		goto out;
+	}
+
+	if (wait_for_scan("Do not collapse with max_ptes_swap exceeded", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		fail("Fail");
+	else
+		success("OK");
+	validate_memory(p, 0, hpage_pmd_size);
+	/* Then page out exactly the limit: expect the collapse to happen. */
+	fill_memory(p, 0, hpage_pmd_size);
+	printf("Swapout %d of %d pages...", max_ptes_swap, hpage_pmd_nr);
+	if (madvise(p, max_ptes_swap * page_size, MADV_PAGEOUT)) {
+		perror("madvise(MADV_PAGEOUT)");
+		exit(EXIT_FAILURE);
+	}
+	if (check_swap(p, max_ptes_swap * page_size)) {
+		success("OK");
+	} else {
+		fail("Fail");
+		goto out;
+	}
+
+	if (wait_for_scan("Collapse with max_ptes_swap pages swapped out", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);
+out:
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_single_pte_entry_compound(void)
+{
+	void *p;
+
+	p = alloc_mapping();
+
+	printf("Allocate huge page...");
+	madvise(p, hpage_pmd_size, MADV_HUGEPAGE);
+	fill_memory(p, 0, hpage_pmd_size);
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	madvise(p, hpage_pmd_size, MADV_NOHUGEPAGE);
+
+	printf("Split huge page leaving single PTE mapping compound page...");
+	madvise(p + page_size, hpage_pmd_size - page_size, MADV_DONTNEED);	/* all but page 0 */
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	if (wait_for_scan("Collapse PTE table with single PTE mapping compound page", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, page_size);	/* only page 0 still has its marker */
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_full_of_compound(void)
+{
+	void *p;
+
+	p = alloc_mapping();
+
+	printf("Allocate huge page...");
+	madvise(p, hpage_pmd_size, MADV_HUGEPAGE);
+	fill_memory(p, 0, hpage_pmd_size);
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	printf("Split huge page leaving single PTE page table full of compound pages...");
+	madvise(p, page_size, MADV_NOHUGEPAGE);	/* first one page, ... */
+	madvise(p, hpage_pmd_size, MADV_NOHUGEPAGE);	/* ... then the whole range */
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	if (wait_for_scan("Collapse PTE table full of compound pages", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_compound_extreme(void)
+{
+	void *p;
+	int i;
+
+	p = alloc_mapping();
+	for (i = 0; i < hpage_pmd_nr; i++) {
+		printf("\rConstruct PTE page table full of different PTE-mapped compound pages %3d/%d...",
+				i + 1, hpage_pmd_nr);
+
+		madvise(BASE_ADDR, hpage_pmd_size, MADV_HUGEPAGE);
+		fill_memory(BASE_ADDR, 0, hpage_pmd_size);
+		if (!check_huge(BASE_ADDR)) {
+			printf("Failed to allocate huge page\n");
+			exit(EXIT_FAILURE);
+		}
+		madvise(BASE_ADDR, hpage_pmd_size, MADV_NOHUGEPAGE);
+		/* Shrink to (i + 1) pages while moving them out of the way. */
+		p = mremap(BASE_ADDR - i * page_size,
+				i * page_size + hpage_pmd_size,
+				(i + 1) * page_size,
+				MREMAP_MAYMOVE | MREMAP_FIXED,
+				BASE_ADDR + 2 * hpage_pmd_size);
+		if (p == MAP_FAILED) {
+			perror("mremap+unmap");
+			exit(EXIT_FAILURE);
+		}
+		/* Move them back just below BASE_ADDR, grown by hpage_pmd_size. */
+		p = mremap(BASE_ADDR + 2 * hpage_pmd_size,
+				(i + 1) * page_size,
+				(i + 1) * page_size + hpage_pmd_size,
+				MREMAP_MAYMOVE | MREMAP_FIXED,
+				BASE_ADDR - (i + 1) * page_size);
+		if (p == MAP_FAILED) {
+			perror("mremap+alloc");
+			exit(EXIT_FAILURE);
+		}
+	}
+
+	munmap(BASE_ADDR, hpage_pmd_size);	/* drop the part grown last */
+	fill_memory(p, 0, hpage_pmd_size);
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	if (wait_for_scan("Collapse PTE table full of different compound pages", p))
+		fail("Timeout");
+	else if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	validate_memory(p, 0, hpage_pmd_size);
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_fork(void)
+{
+	int wstatus;
+	void *p;
+
+	p = alloc_mapping();
+
+	printf("Allocate small page...");
+	fill_memory(p, 0, page_size);
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	printf("Share small page over fork()...");
+	if (!fork()) {	/* NOTE(review): a fork() failure (-1) is not handled */
+		/* Do not touch settings on child exit */
+		skip_settings_restore = true;
+		exit_status = 0;	/* the child counts its own failures */
+
+		if (!check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		fill_memory(p, page_size, 2 * page_size);	/* second page only */
+
+		if (wait_for_scan("Collapse PTE table with single page shared with parent process", p))
+			fail("Timeout");
+		else if (check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		validate_memory(p, 0, page_size);
+		munmap(p, hpage_pmd_size);
+		exit(exit_status);
+	}
+
+	wait(&wstatus);
+	exit_status += WEXITSTATUS(wstatus);	/* add the child's failure count */
+
+	printf("Check if parent still has small page...");
+	if (!check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, page_size);
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_fork_compound(void)
+{
+	int wstatus;
+	void *p;
+
+	p = alloc_mapping();
+
+	printf("Allocate huge page...");
+	madvise(p, hpage_pmd_size, MADV_HUGEPAGE);
+	fill_memory(p, 0, hpage_pmd_size);
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	printf("Share huge page over fork()...");
+	if (!fork()) {	/* NOTE(review): a fork() failure (-1) is not handled */
+		/* Do not touch settings on child exit */
+		skip_settings_restore = true;
+		exit_status = 0;	/* the child counts its own failures */
+
+		if (check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		printf("Split huge page PMD in child process...");
+		madvise(p, page_size, MADV_NOHUGEPAGE);
+		madvise(p, hpage_pmd_size, MADV_NOHUGEPAGE);
+		if (!check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+		fill_memory(p, 0, page_size);	/* the child writes to page 0 only */
+
+		write_num("khugepaged/max_ptes_shared", hpage_pmd_nr - 1);
+		if (wait_for_scan("Collapse PTE table full of compound pages in child", p))
+			fail("Timeout");
+		else if (check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+		write_num("khugepaged/max_ptes_shared",
+				default_settings.khugepaged.max_ptes_shared);
+
+		validate_memory(p, 0, hpage_pmd_size);
+		munmap(p, hpage_pmd_size);
+		exit(exit_status);
+	}
+
+	wait(&wstatus);
+	exit_status += WEXITSTATUS(wstatus);	/* add the child's failure count */
+
+	printf("Check if parent still has huge page...");
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);
+	munmap(p, hpage_pmd_size);
+}
+
+static void collapse_max_ptes_shared(void)
+{
+	int max_ptes_shared = read_num("khugepaged/max_ptes_shared");
+	int wstatus;
+	void *p;
+
+	p = alloc_mapping();
+
+	printf("Allocate huge page...");
+	madvise(p, hpage_pmd_size, MADV_HUGEPAGE);
+	fill_memory(p, 0, hpage_pmd_size);
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+
+	printf("Share huge page over fork()...");
+	if (!fork()) {
+		/* Do not touch settings on child exit */
+		skip_settings_restore = true;
+		exit_status = 0;
+
+		if (check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		printf("Trigger CoW on page %d of %d...",
+				hpage_pmd_nr - max_ptes_shared - 1, hpage_pmd_nr);
+		fill_memory(p, 0, (hpage_pmd_nr - max_ptes_shared - 1) * page_size);
+		if (!check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		if (wait_for_scan("Do not collapse with max_ptes_shared exceeded", p))
+			fail("Timeout");
+		else if (!check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		printf("Trigger CoW on page %d of %d...",
+				hpage_pmd_nr - max_ptes_shared, hpage_pmd_nr);
+		fill_memory(p, 0, (hpage_pmd_nr - max_ptes_shared) * page_size);
+		if (!check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		/* Only max_ptes_shared pages should remain shared now: expect a collapse */
+		if (wait_for_scan("Collapse with max_ptes_shared PTEs shared", p))
+			fail("Timeout");
+		else if (check_huge(p))
+			success("OK");
+		else
+			fail("Fail");
+
+		validate_memory(p, 0, hpage_pmd_size);
+		munmap(p, hpage_pmd_size);
+		exit(exit_status);
+	}
+
+	wait(&wstatus);
+	exit_status += WEXITSTATUS(wstatus);
+
+	printf("Check if parent still has huge page...");
+	if (check_huge(p))
+		success("OK");
+	else
+		fail("Fail");
+	validate_memory(p, 0, hpage_pmd_size);
+	munmap(p, hpage_pmd_size);
+}
+
+int main(void)
+{
+	setbuf(stdout, NULL);
+
+	page_size = getpagesize();
+	hpage_pmd_size = read_num("hpage_pmd_size");
+	hpage_pmd_nr = hpage_pmd_size / page_size;
+
+	default_settings.khugepaged.max_ptes_none = hpage_pmd_nr - 1;
+	default_settings.khugepaged.max_ptes_swap = hpage_pmd_nr / 8;
+	default_settings.khugepaged.max_ptes_shared = hpage_pmd_nr / 2;
+	default_settings.khugepaged.pages_to_scan = hpage_pmd_nr * 8;
+
+	save_settings();
+	adjust_settings();
+
+	alloc_at_fault();
+	collapse_full();
+	collapse_empty();
+	collapse_single_pte_entry();
+	collapse_max_ptes_none();
+	collapse_swapin_single_pte();
+	collapse_max_ptes_swap();
+	collapse_single_pte_entry_compound();
+	collapse_full_of_compound();
+	collapse_compound_extreme();
+	collapse_fork();
+	collapse_fork_compound();
+	collapse_max_ptes_shared();
+
+	restore_settings(0);
+}
diff --git a/tools/testing/selftests/vm/map_fixed_noreplace.c b/tools/testing/selftests/vm/map_fixed_noreplace.c
index d91bde5..eed4432 100644
--- a/tools/testing/selftests/vm/map_fixed_noreplace.c
+++ b/tools/testing/selftests/vm/map_fixed_noreplace.c
@@ -17,9 +17,6 @@
 #define MAP_FIXED_NOREPLACE 0x100000
 #endif
 
-#define BASE_ADDRESS	(256ul * 1024 * 1024)
-
-
 static void dump_maps(void)
 {
 	char cmd[32];
@@ -28,18 +25,46 @@
 	system(cmd);
 }
 
+static unsigned long find_base_addr(unsigned long size)
+{
+	void *addr;
+	unsigned long flags;
+
+	flags = MAP_PRIVATE | MAP_ANONYMOUS;
+	addr = mmap(NULL, size, PROT_NONE, flags, -1, 0);	/* kernel picks a free range */
+	if (addr == MAP_FAILED) {
+		printf("Error: couldn't map the space we need for the test\n");
+		return 0;
+	}
+
+	if (munmap(addr, size) != 0) {	/* release it; only the address is returned */
+		printf("Error: couldn't unmap the space we need for the test\n");
+		return 0;
+	}
+	return (unsigned long)addr;
+}
+
 int main(void)
 {
+	unsigned long base_addr;
 	unsigned long flags, addr, size, page_size;
 	char *p;
 
 	page_size = sysconf(_SC_PAGE_SIZE);
 
+	//let's find a base addr that is free before we start the tests
+	size = 5 * page_size;
+	base_addr = find_base_addr(size);
+	if (!base_addr) {
+		printf("Error: couldn't map the space we need for the test\n");
+		return 1;
+	}
+
 	flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED_NOREPLACE;
 
 	// Check we can map all the areas we need below
 	errno = 0;
-	addr = BASE_ADDRESS;
+	addr = base_addr;
 	size = 5 * page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 
@@ -60,7 +85,7 @@
 	printf("unmap() successful\n");
 
 	errno = 0;
-	addr = BASE_ADDRESS + page_size;
+	addr = base_addr + page_size;
 	size = 3 * page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -80,7 +105,7 @@
 	 *     +4 |  free  | new
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS;
+	addr = base_addr;
 	size = 5 * page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -101,7 +126,7 @@
 	 *     +4 |  free  |
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS + (2 * page_size);
+	addr = base_addr + (2 * page_size);
 	size = page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -121,7 +146,7 @@
 	 *     +4 |  free  | new
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS + (3 * page_size);
+	addr = base_addr + (3 * page_size);
 	size = 2 * page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -141,7 +166,7 @@
 	 *     +4 |  free  |
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS;
+	addr = base_addr;
 	size = 2 * page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -161,7 +186,7 @@
 	 *     +4 |  free  |
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS;
+	addr = base_addr;
 	size = page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -181,7 +206,7 @@
 	 *     +4 |  free  |  new
 	 */
 	errno = 0;
-	addr = BASE_ADDRESS + (4 * page_size);
+	addr = base_addr + (4 * page_size);
 	size = page_size;
 	p = mmap((void *)addr, size, PROT_NONE, flags, -1, 0);
 	printf("mmap() @ 0x%lx-0x%lx p=%p result=%m\n", addr, addr + size, p);
@@ -192,7 +217,7 @@
 		return 1;
 	}
 
-	addr = BASE_ADDRESS;
+	addr = base_addr;
 	size = 5 * page_size;
 	if (munmap((void *)addr, size) != 0) {
 		dump_maps();
diff --git a/tools/testing/selftests/vm/mremap_dontunmap.c b/tools/testing/selftests/vm/mremap_dontunmap.c
new file mode 100644
index 0000000..3a7b5ef
--- /dev/null
+++ b/tools/testing/selftests/vm/mremap_dontunmap.c
@@ -0,0 +1,312 @@
+// SPDX-License-Identifier: GPL-2.0
+
+/*
+ * Tests for mremap w/ MREMAP_DONTUNMAP.
+ *
+ * Copyright 2020, Brian Geffon <bgeffon@google.com>
+ */
+#define _GNU_SOURCE
+#include <sys/mman.h>
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "../kselftest.h"
+
+#ifndef MREMAP_DONTUNMAP
+#define MREMAP_DONTUNMAP 4
+#endif
+
+unsigned long page_size;
+char *page_buffer;
+
+static void dump_maps(void)
+{
+	char cmd[32];
+
+	snprintf(cmd, sizeof(cmd), "cat /proc/%d/maps", getpid());
+	system(cmd);
+}
+
+#define BUG_ON(condition, description)					      \
+	do {								      \
+		if (condition) {					      \
+			fprintf(stderr, "[FAIL]\t%s():%d\t%s:%s\n", __func__, \
+				__LINE__, (description), strerror(errno));    \
+			dump_maps();					  \
+			exit(1);					      \
+		} 							      \
+	} while (0)
+
+// Probe for MREMAP_DONTUNMAP with one trivial remap so that an older kernel
+// yields a skip, not a failure. Returns 0 if supported, else mremap()'s errno.
+static int kernel_support_for_mremap_dontunmap(void)
+{
+	int ret = 0;
+	unsigned long num_pages = 1;
+	void *source_mapping = mmap(NULL, num_pages * page_size, PROT_NONE,
+				    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(source_mapping == MAP_FAILED, "mmap");
+
+	// This simple remap should only fail if MREMAP_DONTUNMAP isn't
+	// supported. NULL (not a bare int 0) keeps the vararg pointer-sized.
+	void *dest_mapping =
+	    mremap(source_mapping, num_pages * page_size, num_pages * page_size,
+		   MREMAP_DONTUNMAP | MREMAP_MAYMOVE, NULL);
+	if (dest_mapping == MAP_FAILED) {
+		ret = errno;
+	} else {
+		BUG_ON(munmap(dest_mapping, num_pages * page_size) == -1,
+		       "unable to unmap destination mapping");
+	}
+
+	BUG_ON(munmap(source_mapping, num_pages * page_size) == -1,
+	       "unable to unmap source mapping");
+	return ret;
+}
+
+// This helper will just validate that an entire mapping contains the expected
+// byte.
+static int check_region_contains_byte(void *addr, unsigned long size, char byte)
+{
+	BUG_ON(size & (page_size - 1),
+	       "check_region_contains_byte expects page multiples");
+	BUG_ON((unsigned long)addr & (page_size - 1),
+	       "check_region_contains_byte expects page alignment");
+
+	memset(page_buffer, byte, page_size);
+
+	unsigned long num_pages = size / page_size;
+	unsigned long i;
+
+	// Compare each page checking that it contains our expected byte.
+	for (i = 0; i < num_pages; ++i) {
+		int ret =
+		    memcmp(addr + (i * page_size), page_buffer, page_size);
+		if (ret) {
+			return ret;
+		}
+	}
+
+	return 0;
+}
+
+// This test validates that MREMAP_DONTUNMAP moves the pagetables while leaving
+// the source mapping mapped.
+static void mremap_dontunmap_simple(void)
+{
+	unsigned long num_pages = 5;
+
+	void *source_mapping =
+	    mmap(NULL, num_pages * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(source_mapping == MAP_FAILED, "mmap");
+
+	memset(source_mapping, 'a', num_pages * page_size);
+
+	// Try to just move the whole mapping anywhere (not fixed).
+	void *dest_mapping =
+	    mremap(source_mapping, num_pages * page_size, num_pages * page_size,
+		   MREMAP_DONTUNMAP | MREMAP_MAYMOVE, NULL);
+	BUG_ON(dest_mapping == MAP_FAILED, "mremap");
+
+	// Validate that the pages have been moved: the dest_mapping must now
+	// contain a's, and the still-mapped source must read back as zeros.
+	BUG_ON(check_region_contains_byte
+	       (dest_mapping, num_pages * page_size, 'a') != 0,
+	       "pages did not migrate");
+	BUG_ON(check_region_contains_byte
+	       (source_mapping, num_pages * page_size, 0) != 0,
+	       "source should have no ptes");
+
+	BUG_ON(munmap(dest_mapping, num_pages * page_size) == -1,
+	       "unable to unmap destination mapping");
+	BUG_ON(munmap(source_mapping, num_pages * page_size) == -1,
+	       "unable to unmap source mapping");
+}
+
+// This test validates MREMAP_DONTUNMAP will move page tables to a specific
+// destination using MREMAP_FIXED, also while validating that the source
+// remains intact.
+static void mremap_dontunmap_simple_fixed(void)
+{
+	unsigned long num_pages = 5;
+
+	// Since we want to guarantee that we can remap to a point, we will
+	// create a mapping up front.
+	void *dest_mapping =
+	    mmap(NULL, num_pages * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(dest_mapping == MAP_FAILED, "mmap");
+	memset(dest_mapping, 'X', num_pages * page_size);
+
+	void *source_mapping =
+	    mmap(NULL, num_pages * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(source_mapping == MAP_FAILED, "mmap");
+	memset(source_mapping, 'a', num_pages * page_size);
+
+	void *remapped_mapping =
+	    mremap(source_mapping, num_pages * page_size, num_pages * page_size,
+		   MREMAP_FIXED | MREMAP_DONTUNMAP | MREMAP_MAYMOVE,
+		   dest_mapping);
+	BUG_ON(remapped_mapping == MAP_FAILED, "mremap");
+	BUG_ON(remapped_mapping != dest_mapping,
+	       "mremap should have placed the remapped mapping at dest_mapping");
+
+	// The dest mapping will have been unmapped by mremap so we expect the
+	// Xs to be gone and replaced with a's.
+	BUG_ON(check_region_contains_byte
+	       (dest_mapping, num_pages * page_size, 'a') != 0,
+	       "pages did not migrate");
+
+	// And the source mapping will have had its ptes dropped.
+	BUG_ON(check_region_contains_byte
+	       (source_mapping, num_pages * page_size, 0) != 0,
+	       "source should have no ptes");
+
+	BUG_ON(munmap(dest_mapping, num_pages * page_size) == -1,
+	       "unable to unmap destination mapping");
+	BUG_ON(munmap(source_mapping, num_pages * page_size) == -1,
+	       "unable to unmap source mapping");
+}
+
+// This test validates that we can MREMAP_DONTUNMAP for a portion of an
+// existing mapping.
+static void mremap_dontunmap_partial_mapping(void)
+{
+	/*
+	 *  source mapping:
+	 *  --------------
+	 *  | aaaaaaaaaa |
+	 *  --------------
+	 *  to become:
+	 *  --------------
+	 *  | aaaaa00000 |
+	 *  --------------
+	 *  With the destination mapping containing 5 pages of As.
+	 *  ---------
+	 *  | aaaaa |
+	 *  ---------
+	 */
+	unsigned long num_pages = 10;
+	void *source_mapping =
+	    mmap(NULL, num_pages * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(source_mapping == MAP_FAILED, "mmap");
+	memset(source_mapping, 'a', num_pages * page_size);
+
+	// We will grab the last 5 pages of the source and move them.
+	void *dest_mapping =
+	    mremap(source_mapping + (5 * page_size), 5 * page_size,
+		   5 * page_size,
+		   MREMAP_DONTUNMAP | MREMAP_MAYMOVE, NULL);
+	BUG_ON(dest_mapping == MAP_FAILED, "mremap");
+
+	// We expect the first 5 pages of the source to contain a's and the
+	// final 5 pages to contain zeros.
+	BUG_ON(check_region_contains_byte(source_mapping, 5 * page_size, 'a') !=
+	       0, "first 5 pages of source should have original pages");
+	BUG_ON(check_region_contains_byte
+	       (source_mapping + (5 * page_size), 5 * page_size, 0) != 0,
+	       "final 5 pages of source should have no ptes");
+
+	// Finally we expect the destination to have 5 pages worth of a's.
+	BUG_ON(check_region_contains_byte(dest_mapping, 5 * page_size, 'a') !=
+	       0, "dest mapping should contain ptes from the source");
+
+	BUG_ON(munmap(dest_mapping, 5 * page_size) == -1,
+	       "unable to unmap destination mapping");
+	BUG_ON(munmap(source_mapping, num_pages * page_size) == -1,
+	       "unable to unmap source mapping");
+}
+
+// This test validates that we can remap over only a portion of a mapping.
+static void mremap_dontunmap_partial_mapping_overwrite(void)
+{
+	/*
+	 *  source mapping:
+	 *  ---------
+	 *  | aaaaa |
+	 *  ---------
+	 *  dest mapping initially:
+	 *  --------------
+	 *  | XXXXXXXXXX |
+	 *  --------------
+	 *  Source to become:
+	 *  ---------
+	 *  | 00000 |
+	 *  ---------
+	 *  With the destination mapping containing 5 pages of As.
+	 *  --------------
+	 *  | aaaaaXXXXX |
+	 *  --------------
+	 */
+	void *source_mapping =
+	    mmap(NULL, 5 * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(source_mapping == MAP_FAILED, "mmap");
+	memset(source_mapping, 'a', 5 * page_size);
+
+	void *dest_mapping =
+	    mmap(NULL, 10 * page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(dest_mapping == MAP_FAILED, "mmap");
+	memset(dest_mapping, 'X', 10 * page_size);
+
+	// Move all 5 pages of the source over the first 5 pages of the dest.
+	void *remapped_mapping =
+	    mremap(source_mapping, 5 * page_size,
+		   5 * page_size,
+		   MREMAP_DONTUNMAP | MREMAP_MAYMOVE | MREMAP_FIXED, dest_mapping);
+	BUG_ON(remapped_mapping == MAP_FAILED, "mremap");
+	BUG_ON(dest_mapping != remapped_mapping, "expected to remap to dest_mapping");
+
+	BUG_ON(check_region_contains_byte(source_mapping, 5 * page_size, 0) !=
+	       0, "first 5 pages of source should have no ptes");
+
+	// Finally we expect the destination to have 5 pages worth of a's.
+	BUG_ON(check_region_contains_byte(dest_mapping, 5 * page_size, 'a') != 0,
+			"dest mapping should contain ptes from the source");
+
+	// Finally the last 5 pages shouldn't have been touched.
+	BUG_ON(check_region_contains_byte(dest_mapping + (5 * page_size),
+				5 * page_size, 'X') != 0,
+			"dest mapping should have retained the last 5 pages");
+
+	BUG_ON(munmap(dest_mapping, 10 * page_size) == -1,
+	       "unable to unmap destination mapping");
+	BUG_ON(munmap(source_mapping, 5 * page_size) == -1,
+	       "unable to unmap source mapping");
+}
+
+int main(void)
+{
+	page_size = sysconf(_SC_PAGE_SIZE);
+
+	// test for kernel support for MREMAP_DONTUNMAP skipping the test if
+	// not.
+	if (kernel_support_for_mremap_dontunmap() != 0) {
+		printf("No kernel support for MREMAP_DONTUNMAP\n");
+		return KSFT_SKIP;
+	}
+
+	// Keep a page sized buffer around for when we need it.
+	page_buffer =
+	    mmap(NULL, page_size, PROT_READ | PROT_WRITE,
+		 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+	BUG_ON(page_buffer == MAP_FAILED, "unable to mmap a page.");
+
+	mremap_dontunmap_simple();
+	mremap_dontunmap_simple_fixed();
+	mremap_dontunmap_partial_mapping();
+	mremap_dontunmap_partial_mapping_overwrite();
+
+	BUG_ON(munmap(page_buffer, page_size) == -1,
+	       "unable to unmap page buffer");
+
+	printf("OK\n");
+	return 0;
+}
diff --git a/tools/testing/selftests/vm/pkey-helpers.h b/tools/testing/selftests/vm/pkey-helpers.h
new file mode 100644
index 0000000..622a858
--- /dev/null
+++ b/tools/testing/selftests/vm/pkey-helpers.h
@@ -0,0 +1,225 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _PKEYS_HELPER_H
+#define _PKEYS_HELPER_H
+#define _GNU_SOURCE
+#include <string.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <signal.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <ucontext.h>
+#include <sys/mman.h>
+
+/* Define some kernel-like types */
+#define  u8 __u8
+#define u16 __u16
+#define u32 __u32
+#define u64 __u64
+
+#define PTR_ERR_ENOTSUP ((void *)-ENOTSUP)
+
+#ifndef DEBUG_LEVEL
+#define DEBUG_LEVEL 0
+#endif
+#define DPRINT_IN_SIGNAL_BUF_SIZE 4096
+extern int dprint_in_signal;
+extern char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
+
+extern int test_nr;
+extern int iteration_nr;
+
+#ifdef __GNUC__
+__attribute__((format(printf, 1, 2)))
+#endif
+static inline void sigsafe_printf(const char *format, ...)
+{
+	va_list ap;
+
+	if (!dprint_in_signal) {
+		va_start(ap, format);
+		vprintf(format, ap);
+		va_end(ap);
+	} else {
+		int ret;
+		/*
+		 * No printf() functions are signal-safe.
+		 * They deadlock easily. Write the format
+		 * string to get some output, even if
+		 * incomplete.
+		 */
+		ret = write(1, format, strlen(format));
+		if (ret < 0)
+			exit(1);
+	}
+}
+#define dprintf_level(level, args...) do {	\
+	if (level <= DEBUG_LEVEL)		\
+		sigsafe_printf(args);		\
+} while (0)
+#define dprintf0(args...) dprintf_level(0, args)
+#define dprintf1(args...) dprintf_level(1, args)
+#define dprintf2(args...) dprintf_level(2, args)
+#define dprintf3(args...) dprintf_level(3, args)
+#define dprintf4(args...) dprintf_level(4, args)
+
+extern void abort_hooks(void);
+#define pkey_assert(condition) do {		\
+	if (!(condition)) {			\
+		dprintf0("assert() at %s::%d test_nr: %d iteration: %d\n", \
+				__FILE__, __LINE__,	\
+				test_nr, iteration_nr);	\
+		dprintf0("errno at assert: %d", errno);	\
+		abort_hooks();			\
+		exit(__LINE__);			\
+	}					\
+} while (0)
+
+__attribute__((noinline)) int read_ptr(int *ptr);
+void expected_pkey_fault(int pkey);
+int sys_pkey_alloc(unsigned long flags, unsigned long init_val);
+int sys_pkey_free(unsigned long pkey);
+int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
+		unsigned long pkey);
+void record_pkey_malloc(void *ptr, long size, int prot);
+
+#if defined(__i386__) || defined(__x86_64__) /* arch */
+#include "pkey-x86.h"
+#elif defined(__powerpc64__) /* arch */
+#include "pkey-powerpc.h"
+#else /* arch */
+#error Architecture not supported
+#endif /* arch */
+
+#define PKEY_MASK	(PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE)
+
+static inline u64 set_pkey_bits(u64 reg, int pkey, u64 flags)
+{
+	u32 shift = pkey_bit_position(pkey);
+	/* mask out bits from pkey in old value */
+	reg &= ~((u64)PKEY_MASK << shift);
+	/* OR in new bits for pkey */
+	reg |= (flags & PKEY_MASK) << shift;
+	return reg;
+}
+
+static inline u64 get_pkey_bits(u64 reg, int pkey)
+{
+	u32 shift = pkey_bit_position(pkey);
+	/*
+	 * shift down the relevant bits to the lowest two, then
+	 * mask off all the other higher bits
+	 */
+	return ((reg >> shift) & PKEY_MASK);
+}
+
+extern u64 shadow_pkey_reg;
+
+static inline u64 _read_pkey_reg(int line)
+{
+	u64 pkey_reg = __read_pkey_reg();
+
+	dprintf4("read_pkey_reg(line=%d) pkey_reg: %016llx"
+			" shadow: %016llx\n",
+			line, pkey_reg, shadow_pkey_reg);
+	assert(pkey_reg == shadow_pkey_reg);
+
+	return pkey_reg;
+}
+
+#define read_pkey_reg() _read_pkey_reg(__LINE__)
+
+static inline void write_pkey_reg(u64 pkey_reg)
+{
+	dprintf4("%s() changing %016llx to %016llx\n", __func__,
+			__read_pkey_reg(), pkey_reg);
+	/* will do the shadow check for us: */
+	read_pkey_reg();
+	__write_pkey_reg(pkey_reg);
+	shadow_pkey_reg = pkey_reg;
+	dprintf4("%s(%016llx) pkey_reg: %016llx\n", __func__,
+			pkey_reg, __read_pkey_reg());
+}
+
+/*
+ * These are technically racy, since something could change the PKEY
+ * register between the read and the write. NOTE(review): "pkey * 2" is the
+ * x86 bit layout; powerpc's pkey_bit_position() differs - confirm before use.
+ */
+static inline void __pkey_access_allow(int pkey, int do_allow)
+{
+	u64 pkey_reg = read_pkey_reg();
+	int bit = pkey * 2;
+
+	if (do_allow)
+		pkey_reg &= ~((u64)1 << bit);	/* clear only the access-disable bit */
+	else
+		pkey_reg |= ((u64)1 << bit);	/* set the access-disable bit */
+
+	write_pkey_reg(pkey_reg);
+	dprintf4("pkey_reg now: %016llx\n", read_pkey_reg());
+}
+
+static inline void __pkey_write_allow(int pkey, int do_allow_write)
+{
+	u64 pkey_reg = read_pkey_reg();
+	int bit = pkey * 2 + 1;	/* NOTE(review): x86 layout, as in pkey-x86.h */
+
+	if (do_allow_write)
+		pkey_reg &= ~((u64)1 << bit);	/* clear only the write-disable bit */
+	else
+		pkey_reg |= ((u64)1 << bit);	/* set the write-disable bit */
+
+	write_pkey_reg(pkey_reg);
+	dprintf4("pkey_reg now: %016llx\n", read_pkey_reg());
+}
+
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof(*(x)))
+#define ALIGN_UP(x, align_to)	(((x) + ((align_to)-1)) & ~((align_to)-1))
+#define ALIGN_DOWN(x, align_to) ((x) & ~((align_to)-1))
+#define ALIGN_PTR_UP(p, ptr_align_to)	\
+	((typeof(p))ALIGN_UP((unsigned long)(p), ptr_align_to))
+#define ALIGN_PTR_DOWN(p, ptr_align_to)	\
+	((typeof(p))ALIGN_DOWN((unsigned long)(p), ptr_align_to))
+#define __stringify_1(x...)     #x
+#define __stringify(x...)       __stringify_1(x)
+
+static inline u32 *siginfo_get_pkey_ptr(siginfo_t *si)
+{
+#ifdef si_pkey
+	return &si->si_pkey;
+#else
+	return (u32 *)(((u8 *)si) + si_pkey_offset);
+#endif
+}
+
+static inline int kernel_has_pkeys(void)
+{
+	/* try allocating a key and see if it succeeds */
+	int ret = sys_pkey_alloc(0, 0);
+	if (ret <= 0) {
+		return 0;
+	}
+	sys_pkey_free(ret);
+	return 1;
+}
+
+static inline int is_pkeys_supported(void)
+{
+	/* check if the cpu supports pkeys */
+	if (!cpu_has_pkeys()) {
+		dprintf1("SKIP: %s: no CPU support\n", __func__);
+		return 0;
+	}
+
+	/* check if the kernel supports pkeys */
+	if (!kernel_has_pkeys()) {
+		dprintf1("SKIP: %s: no kernel support\n", __func__);
+		return 0;
+	}
+
+	return 1;
+}
+
+#endif /* _PKEYS_HELPER_H */
diff --git a/tools/testing/selftests/vm/pkey-powerpc.h b/tools/testing/selftests/vm/pkey-powerpc.h
new file mode 100644
index 0000000..1ebb586
--- /dev/null
+++ b/tools/testing/selftests/vm/pkey-powerpc.h
@@ -0,0 +1,133 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#ifndef _PKEYS_POWERPC_H
+#define _PKEYS_POWERPC_H
+
+#ifndef SYS_mprotect_key
+# define SYS_mprotect_key	386
+#endif
+#ifndef SYS_pkey_alloc
+# define SYS_pkey_alloc		384
+# define SYS_pkey_free		385
+#endif
+#define REG_IP_IDX		PT_NIP
+#define REG_TRAPNO		PT_TRAP
+#define gregs			gp_regs
+#define fpregs			fp_regs
+#define si_pkey_offset		0x20
+
+#undef PKEY_DISABLE_ACCESS
+#define PKEY_DISABLE_ACCESS	0x3  /* disable read and write */
+
+#undef PKEY_DISABLE_WRITE
+#define PKEY_DISABLE_WRITE	0x2
+
+#define NR_PKEYS		32
+#define NR_RESERVED_PKEYS_4K	27 /* pkey-0, pkey-1, exec-only-pkey
+				      and 24 other keys that cannot be
+				      represented in the PTE */
+#define NR_RESERVED_PKEYS_64K_3KEYS	3 /* PowerNV and KVM: pkey-0,
+					     pkey-1 and exec-only key */
+#define NR_RESERVED_PKEYS_64K_4KEYS	4 /* PowerVM: pkey-0, pkey-1,
+					     pkey-31 and exec-only key */
+#define PKEY_BITS_PER_PKEY	2
+#define HPAGE_SIZE		(1UL << 24)
+#define PAGE_SIZE		sysconf(_SC_PAGESIZE)
+
+static inline u32 pkey_bit_position(int pkey)
+{
+	return (NR_PKEYS - pkey - 1) * PKEY_BITS_PER_PKEY;
+}
+
+static inline u64 __read_pkey_reg(void)
+{
+	u64 pkey_reg;
+
+	asm volatile("mfspr %0, 0xd" : "=r" (pkey_reg));
+
+	return pkey_reg;
+}
+
+static inline void __write_pkey_reg(u64 pkey_reg)
+{
+	u64 amr = pkey_reg;
+
+	dprintf4("%s() changing %016llx to %016llx\n",
+			 __func__, __read_pkey_reg(), pkey_reg);
+
+	asm volatile("isync; mtspr 0xd, %0; isync"
+		     : : "r" ((unsigned long)(amr)) : "memory");
+
+	dprintf4("%s() pkey register after changing %016llx to %016llx\n",
+			__func__, __read_pkey_reg(), pkey_reg);
+}
+
+static inline int cpu_has_pkeys(void)
+{
+	/* No simple way to determine this */
+	return 1;
+}
+
+static inline bool arch_is_powervm(void)
+{
+	struct stat buf;
+
+	if ((stat("/sys/firmware/devicetree/base/ibm,partition-name", &buf) == 0) &&
+	    (stat("/sys/firmware/devicetree/base/hmc-managed?", &buf) == 0) &&
+	    (stat("/sys/firmware/devicetree/base/chosen/qemu,graphic-width", &buf) == -1))
+		return true;	/* both partition nodes exist, QEMU node absent */
+
+	return false;
+}
+
+static inline int get_arch_reserved_keys(void)
+{
+	if (sysconf(_SC_PAGESIZE) == 4096)
+		return NR_RESERVED_PKEYS_4K;
+	else
+		if (arch_is_powervm())
+			return NR_RESERVED_PKEYS_64K_4KEYS;
+		else
+			return NR_RESERVED_PKEYS_64K_3KEYS;
+}
+
+void expect_fault_on_read_execonly_key(void *p1, int pkey)
+{
+	/*
+	 * powerpc does not allow userspace to change permissions of exec-only
+	 * keys since those keys are not allocated by userspace. The signal
+	 * handler wont be able to reset the permissions, which means the code
+	 * will infinitely continue to segfault here.
+	 */
+	return;
+}
+
+/* 4-byte instructions * 16384 = 64K page */
+#define __page_o_noops() asm(".rept 16384 ; nop; .endr")
+
+void *malloc_pkey_with_mprotect_subpage(long size, int prot, u16 pkey)
+{
+	void *ptr;
+	int ret;
+
+	dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
+			size, prot, pkey);
+	pkey_assert(pkey < NR_PKEYS);
+	ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
+	pkey_assert(ptr != (void *)-1);
+
+	ret = syscall(__NR_subpage_prot, ptr, size, NULL);
+	if (ret) {
+		perror("subpage_perm");
+		return PTR_ERR_ENOTSUP;
+	}
+
+	ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
+	pkey_assert(!ret);
+	record_pkey_malloc(ptr, size, prot);
+
+	dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
+	return ptr;
+}
+
+#endif /* _PKEYS_POWERPC_H */
diff --git a/tools/testing/selftests/vm/pkey-x86.h b/tools/testing/selftests/vm/pkey-x86.h
new file mode 100644
index 0000000..3be20f5
--- /dev/null
+++ b/tools/testing/selftests/vm/pkey-x86.h
@@ -0,0 +1,181 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+
+#ifndef _PKEYS_X86_H
+#define _PKEYS_X86_H
+
+#ifdef __i386__
+
+#ifndef SYS_mprotect_key
+# define SYS_mprotect_key	380
+#endif
+
+#ifndef SYS_pkey_alloc
+# define SYS_pkey_alloc		381
+# define SYS_pkey_free		382
+#endif
+
+#define REG_IP_IDX		REG_EIP
+#define si_pkey_offset		0x14
+
+#else
+
+#ifndef SYS_mprotect_key
+# define SYS_mprotect_key	329
+#endif
+
+#ifndef SYS_pkey_alloc
+# define SYS_pkey_alloc		330
+# define SYS_pkey_free		331
+#endif
+
+#define REG_IP_IDX		REG_RIP
+#define si_pkey_offset		0x20
+
+#endif
+
+#ifndef PKEY_DISABLE_ACCESS
+# define PKEY_DISABLE_ACCESS	0x1
+#endif
+
+#ifndef PKEY_DISABLE_WRITE
+# define PKEY_DISABLE_WRITE	0x2
+#endif
+
+#define NR_PKEYS		16
+#define NR_RESERVED_PKEYS	2 /* pkey-0 and exec-only-pkey */
+#define PKEY_BITS_PER_PKEY	2
+#define HPAGE_SIZE		(1UL<<21)
+#define PAGE_SIZE		4096
+#define MB			(1<<20)
+
+static inline void __page_o_noops(void)
+{
+	/* 8 bytes of instruction * 512 repetitions = 4096 bytes = 1 page */
+	asm(".rept 512 ; nopl 0x7eeeeeee(%eax) ; .endr");
+}
+
+static inline u64 __read_pkey_reg(void)
+{
+	unsigned int eax, edx;
+	unsigned int ecx = 0;
+	unsigned pkey_reg;
+
+	asm volatile(".byte 0x0f,0x01,0xee\n\t"
+		     : "=a" (eax), "=d" (edx)
+		     : "c" (ecx));
+	pkey_reg = eax;
+	return pkey_reg;
+}
+
+static inline void __write_pkey_reg(u64 pkey_reg)
+{
+	unsigned int eax = pkey_reg;
+	unsigned int ecx = 0;
+	unsigned int edx = 0;
+
+	dprintf4("%s() changing %016llx to %016llx\n", __func__,
+			__read_pkey_reg(), pkey_reg);
+	asm volatile(".byte 0x0f,0x01,0xef\n\t"
+		     : : "a" (eax), "c" (ecx), "d" (edx));
+	assert(pkey_reg == __read_pkey_reg());
+}
+
+static inline void __cpuid(unsigned int *eax, unsigned int *ebx,
+		unsigned int *ecx, unsigned int *edx)
+{
+	/* ecx is often an input as well as an output. */
+	asm volatile(
+		"cpuid;"
+		: "=a" (*eax),
+		  "=b" (*ebx),
+		  "=c" (*ecx),
+		  "=d" (*edx)
+		: "0" (*eax), "2" (*ecx));
+}
+
+/* Intel-defined CPU features, CPUID level 0x00000007:0 (ecx) */
+#define X86_FEATURE_PKU        (1<<3) /* Protection Keys for Userspace */
+#define X86_FEATURE_OSPKE      (1<<4) /* OS Protection Keys Enable */
+
+static inline int cpu_has_pkeys(void)
+{
+	unsigned int eax;
+	unsigned int ebx;
+	unsigned int ecx;
+	unsigned int edx;
+
+	eax = 0x7;
+	ecx = 0x0;
+	__cpuid(&eax, &ebx, &ecx, &edx);
+
+	if (!(ecx & X86_FEATURE_PKU)) {
+		dprintf2("cpu does not have PKU\n");
+		return 0;
+	}
+	if (!(ecx & X86_FEATURE_OSPKE)) {
+		dprintf2("cpu does not have OSPKE\n");
+		return 0;
+	}
+	return 1;
+}
+
+static inline u32 pkey_bit_position(int pkey)
+{
+	return pkey * PKEY_BITS_PER_PKEY;
+}
+
+#define XSTATE_PKEY_BIT	(9)
+#define XSTATE_PKEY	0x200
+
+int pkey_reg_xstate_offset(void)
+{
+	unsigned int eax;
+	unsigned int ebx;
+	unsigned int ecx;
+	unsigned int edx;
+	int xstate_offset;
+	int xstate_size;
+	unsigned long XSTATE_CPUID = 0xd;
+	int leaf;
+
+	/* assume that XSTATE_PKEY is set in XCR0 */
+	leaf = XSTATE_PKEY_BIT;
+	{
+		eax = XSTATE_CPUID;
+		ecx = leaf;
+		__cpuid(&eax, &ebx, &ecx, &edx);
+
+		if (leaf == XSTATE_PKEY_BIT) {
+			xstate_offset = ebx;
+			xstate_size = eax;
+		}
+	}
+
+	if (xstate_size == 0) {
+		printf("could not find size/offset of PKEY in xsave state\n");
+		return 0;
+	}
+
+	return xstate_offset;
+}
+
+static inline int get_arch_reserved_keys(void)
+{
+	return NR_RESERVED_PKEYS;
+}
+
+void expect_fault_on_read_execonly_key(void *p1, int pkey)
+{
+	int ptr_contents;
+
+	ptr_contents = read_ptr(p1);
+	dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
+	expected_pkey_fault(pkey);
+}
+
+void *malloc_pkey_with_mprotect_subpage(long size, int prot, u16 pkey)
+{
+	return PTR_ERR_ENOTSUP;
+}
+
+#endif /* _PKEYS_X86_H */
diff --git a/tools/testing/selftests/vm/protection_keys.c b/tools/testing/selftests/vm/protection_keys.c
new file mode 100644
index 0000000..87eecd5
--- /dev/null
+++ b/tools/testing/selftests/vm/protection_keys.c
@@ -0,0 +1,1588 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Tests Memory Protection Keys (see Documentation/core-api/protection-keys.rst)
+ *
+ * There are examples in here of:
+ *  * how to set protection keys on memory
+ *  * how to set/clear bits in pkey registers (the rights register)
+ *  * how to handle SEGV_PKUERR signals and extract pkey-relevant
+ *    information from the siginfo
+ *
+ * Things to add:
+ *	make sure KSM and KSM COW breaking works
+ *	prefault pages in at malloc, or not
+ *	protect MPX bounds tables with protection keys?
+ *	make sure VMA splitting/merging is working correctly
+ *	OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
+ *	look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
+ *	do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
+ *
+ * Compile like this:
+ *	gcc      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
+ *	gcc -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
+ */
+#define _GNU_SOURCE
+#define __SANE_USERSPACE_TYPES__
+#include <errno.h>
+#include <linux/futex.h>
+#include <time.h>
+#include <sys/time.h>
+#include <sys/syscall.h>
+#include <string.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <signal.h>
+#include <assert.h>
+#include <stdlib.h>
+#include <ucontext.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/ptrace.h>
+#include <setjmp.h>
+
+#include "pkey-helpers.h"
+
+int iteration_nr = 1;
+int test_nr;
+
+u64 shadow_pkey_reg;
+int dprint_in_signal;
+char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
+
+/*
+ * Write the string 'str' into the file 'file'.  The file is opened O_RDWR
+ * without O_CREAT, so it has to exist already.
+ *
+ * Failing to open the file, or to write the complete string, is fatal: a
+ * diagnostic is printed to stderr and the process exits with the source
+ * line number of the exit() call as its status.
+ */
+void cat_into_file(char *str, char *file)
+{
+	int fd = open(file, O_RDWR);
+	/* saved right away: the debug print below may modify errno */
+	int open_errno = errno;
+	size_t len = strlen(str);
+	ssize_t ret;
+
+	dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
+	/*
+	 * these need to be raw because they are called under
+	 * pkey_assert()
+	 */
+	if (fd < 0) {
+		/* name the file that could not be opened, not the payload */
+		fprintf(stderr, "error opening '%s'\n", file);
+		errno = open_errno;
+		perror("error: ");
+		exit(__LINE__);
+	}
+
+	ret = write(fd, str, len);
+	/* both an error (-1) and a short write count as failure */
+	if (ret < 0 || (size_t)ret != len) {
+		perror("write to file failed");
+		fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
+		exit(__LINE__);
+	}
+	close(fd);
+}
+
+#if CONTROL_TRACING > 0
+static int warned_tracing;
+/*
+ * Returns 1 when the effective uid is 0, otherwise 0.  The first non-root
+ * call also prints a warning to stderr; later calls stay silent.
+ */
+int tracing_root_ok(void)
+{
+	if (geteuid() == 0)
+		return 1;
+
+	if (warned_tracing == 0)
+		fprintf(stderr, "WARNING: not run as root, "
+				"can not do tracing control\n");
+	warned_tracing = 1;
+	return 0;
+}
+#endif
+
+/*
+ * When built with CONTROL_TRACING > 0 and tracing_root_ok() succeeds, write
+ * the control files under TRACEDIR so that the function_graph tracer is
+ * selected, set_ftrace_pid holds this process' pid and tracing_on is 1.
+ * In every other case this is a no-op.
+ */
+void tracing_on(void)
+{
+#if CONTROL_TRACING > 0
+#define TRACEDIR "/sys/kernel/debug/tracing"
+	char pidstr[32];
+
+	if (!tracing_root_ok())
+		return;
+
+	sprintf(pidstr, "%d", getpid());
+	/* switch tracing off while the tracer is being reconfigured */
+	cat_into_file("0", TRACEDIR "/tracing_on");
+	cat_into_file("\n", TRACEDIR "/trace");
+	/* hard-wired toggle: the "nop" tracer branch below is never taken */
+	if (1) {
+		cat_into_file("function_graph", TRACEDIR "/current_tracer");
+		cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
+	} else {
+		cat_into_file("nop", TRACEDIR "/current_tracer");
+	}
+	cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
+	cat_into_file("1", TRACEDIR "/tracing_on");
+	dprintf1("enabled tracing\n");
+#endif
+}
+
+/*
+ * Counterpart of tracing_on(): when built with CONTROL_TRACING > 0 and
+ * tracing_root_ok() succeeds, write "0" to the tracing_on control file.
+ * Otherwise a no-op.
+ */
+void tracing_off(void)
+{
+#if CONTROL_TRACING > 0
+	if (!tracing_root_ok())
+		return;
+	cat_into_file("0", "/sys/kernel/debug/tracing/tracing_on");
+#endif
+}
+
+/*
+ * Announce itself on stderr and call tracing_off().  When SLEEP_ON_ABORT is
+ * defined, additionally sleeps for that many seconds before returning.
+ */
+void abort_hooks(void)
+{
+	fprintf(stderr, "running %s()...\n", __func__);
+	tracing_off();
+#ifdef SLEEP_ON_ABORT
+	sleep(SLEEP_ON_ABORT);
+#endif
+}
+
+/*
+ * This attempts to have roughly a page of instructions followed by a few
+ * instructions that do a write, and another page of instructions.  That
+ * way, we are pretty sure that the write is in the second page of
+ * instructions and has at least a page of padding behind it.
+ *
+ * *That* lets us be sure to madvise() away the write instruction, which
+ * will then fault, which makes sure that the fault code handles
+ * execute-only memory properly.
+ *
+ * 'write_to_me' receives the source line number of the store; the value
+ * itself is arbitrary, only the fact that a store happens matters here.
+ */
+#ifdef __powerpc64__
+/* This way, both 4K and 64K alignment are maintained */
+__attribute__((__aligned__(65536)))
+#else
+__attribute__((__aligned__(PAGE_SIZE)))
+#endif
+void lots_o_noops_around_write(int *write_to_me)
+{
+	dprintf3("running %s()\n", __func__);
+	__page_o_noops();
+	/* Assume this happens in the second page of instructions: */
+	*write_to_me = __LINE__;
+	/* pad out by another page: */
+	__page_o_noops();
+	dprintf3("%s() done\n", __func__);
+}
+
+/*
+ * Debug helper: walk the memory at 'dumpme' in u64 steps while the byte
+ * offset is below 'len_bytes', printing offset, address and value of each
+ * u64 through dprintf1().
+ */
+void dump_mem(void *dumpme, int len_bytes)
+{
+	char *base = (void *)dumpme;
+	int off = 0;
+
+	while (off < len_bytes) {
+		u64 *word = (u64 *)(base + off);
+
+		dprintf1("dump[%03d][@%p]: %016llx\n", off, word, *word);
+		off += sizeof(u64);
+	}
+}
+
+/*
+ * Read the pkey register with __read_pkey_reg() and return the bits that
+ * get_pkey_bits() extracts for 'pkey'.  'flags' is only logged.
+ */
+static u32 hw_pkey_get(int pkey, unsigned long flags)
+{
+	u64 reg = __read_pkey_reg();
+	u32 rights;
+
+	dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
+			__func__, pkey, flags, 0, 0);
+	dprintf2("%s() raw pkey_reg: %016llx\n", __func__, reg);
+
+	rights = (u32) get_pkey_bits(reg, pkey);
+	return rights;
+}
+
+/*
+ * Replace the rights bits of 'pkey' in the pkey register with 'rights' and
+ * write the register back.  'rights' may only contain PKEY_DISABLE_ACCESS
+ * and/or PKEY_DISABLE_WRITE (enforced with assert()).  'flags' is only
+ * logged.  Always returns 0.
+ */
+static int hw_pkey_set(int pkey, unsigned long rights, unsigned long flags)
+{
+	const u32 allowed_rights = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
+	u64 prev_reg = __read_pkey_reg();
+	u64 next_reg;
+
+	/* make sure that 'rights' only contains the bits we expect: */
+	assert((rights & ~allowed_rights) == 0);
+
+	/* splice the new bits into the previous register value ... */
+	next_reg = set_pkey_bits(prev_reg, pkey, rights);
+	/* ... and make it the live register contents */
+	__write_pkey_reg(next_reg);
+
+	dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x"
+		" pkey_reg now: %016llx old_pkey_reg: %016llx\n",
+		__func__, pkey, rights, flags, 0, __read_pkey_reg(),
+		prev_reg);
+	return 0;
+}
+
+/*
+ * OR the disable bits in 'flags' (PKEY_DISABLE_ACCESS and/or
+ * PKEY_DISABLE_WRITE, at least one required) into the current rights of
+ * 'pkey', both in the pkey register and in shadow_pkey_reg.
+ */
+void pkey_disable_set(int pkey, int flags)
+{
+	unsigned long syscall_flags = 0;
+	int ret;
+	int pkey_rights;
+	/* remembered for the sanity check at the end */
+	u64 orig_pkey_reg = read_pkey_reg();
+
+	dprintf1("START->%s(%d, 0x%x)\n", __func__,
+		pkey, flags);
+	pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
+
+	pkey_rights = hw_pkey_get(pkey, syscall_flags);
+
+	dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
+			pkey, pkey, pkey_rights);
+
+	pkey_assert(pkey_rights >= 0);
+
+	/* add the requested restrictions to whatever is already set */
+	pkey_rights |= flags;
+
+	ret = hw_pkey_set(pkey, pkey_rights, syscall_flags);
+	assert(!ret);
+	/* pkey_reg and flags have the same format */
+	shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
+	dprintf1("%s(%d) shadow: 0x%016llx\n",
+		__func__, pkey, shadow_pkey_reg);
+
+	pkey_assert(ret >= 0);
+
+	/* re-read only for the debug output below */
+	pkey_rights = hw_pkey_get(pkey, syscall_flags);
+	dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
+			pkey, pkey, pkey_rights);
+
+	dprintf1("%s(%d) pkey_reg: 0x%016llx\n",
+		__func__, pkey, read_pkey_reg());
+	/* setting disable bits must not make the register value smaller */
+	if (flags)
+		pkey_assert(read_pkey_reg() >= orig_pkey_reg);
+	dprintf1("END<---%s(%d, 0x%x)\n", __func__,
+		pkey, flags);
+}
+
+/*
+ * Clear the disable bits in 'flags' (PKEY_DISABLE_ACCESS and/or
+ * PKEY_DISABLE_WRITE, at least one required) from the current rights of
+ * 'pkey', both in the pkey register and in shadow_pkey_reg.
+ */
+void pkey_disable_clear(int pkey, int flags)
+{
+	unsigned long syscall_flags = 0;
+	int ret;
+	int pkey_rights = hw_pkey_get(pkey, syscall_flags);
+	/* remembered for the sanity check at the end */
+	u64 orig_pkey_reg = read_pkey_reg();
+
+	pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
+
+	dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
+			pkey, pkey, pkey_rights);
+	pkey_assert(pkey_rights >= 0);
+
+	/* drop the requested restrictions, keep the others */
+	pkey_rights &= ~flags;
+
+	ret = hw_pkey_set(pkey, pkey_rights, 0);
+	/* keep the shadow copy in step with the register */
+	shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
+	pkey_assert(ret >= 0);
+
+	/* re-read only for the debug output below */
+	pkey_rights = hw_pkey_get(pkey, syscall_flags);
+	dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
+			pkey, pkey, pkey_rights);
+
+	dprintf1("%s(%d) pkey_reg: 0x%016llx\n", __func__,
+			pkey, read_pkey_reg());
+	/* clearing disable bits must not make the register value larger */
+	if (flags)
+		assert(read_pkey_reg() <= orig_pkey_reg);
+}
+
+/* Clear PKEY_DISABLE_WRITE for 'pkey'. */
+void pkey_write_allow(int pkey)
+{
+	pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
+}
+/* Set PKEY_DISABLE_WRITE for 'pkey'. */
+void pkey_write_deny(int pkey)
+{
+	pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
+}
+/* Clear PKEY_DISABLE_ACCESS for 'pkey'. */
+void pkey_access_allow(int pkey)
+{
+	pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
+}
+/* Set PKEY_DISABLE_ACCESS for 'pkey'. */
+void pkey_access_deny(int pkey)
+{
+	pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
+}
+
+/* Failed address bound checks: */
+#ifndef SEGV_BNDERR
+# define SEGV_BNDERR		3
+#endif
+
+#ifndef SEGV_PKUERR
+# define SEGV_PKUERR		4
+#endif
+
+/*
+ * Map a SIGSEGV si_code to its symbolic name for log output.  Codes other
+ * than the four handled below yield "UNKNOWN".
+ */
+static char *si_code_str(int si_code)
+{
+	switch (si_code) {
+	case SEGV_MAPERR:
+		return "SEGV_MAPERR";
+	case SEGV_ACCERR:
+		return "SEGV_ACCERR";
+	case SEGV_BNDERR:
+		return "SEGV_BNDERR";
+	case SEGV_PKUERR:
+		return "SEGV_PKUERR";
+	default:
+		return "UNKNOWN";
+	}
+}
+
+int pkey_faults;
+int last_si_pkey = -1;
+/*
+ * SA_SIGINFO handler installed by setup_sigsegv_handler().
+ *
+ * Stores the pkey reported in the siginfo in last_si_pkey, increments
+ * pkey_faults, and then lets the interrupted access be retried: on x86 by
+ * zeroing the pkey register image found in the signal frame's FPU state,
+ * on powerpc64 by calling pkey_access_allow() for the faulting key.
+ * A si_code of SEGV_MAPERR, SEGV_ACCERR or SEGV_BNDERR ends the process
+ * with exit status 4.
+ *
+ * @signum:	signal number (not used)
+ * @si:		siginfo describing the fault
+ * @vucontext:	the ucontext_t of the interrupted context
+ */
+void signal_handler(int signum, siginfo_t *si, void *vucontext)
+{
+	ucontext_t *uctxt = vucontext;
+	int trapno;
+	unsigned long ip;
+	char *fpregs;
+#if defined(__i386__) || defined(__x86_64__) /* arch */
+	u32 *pkey_reg_ptr;
+	int pkey_reg_offset;
+#endif /* arch */
+	u64 siginfo_pkey;
+	u32 *si_pkey_ptr;
+
+	dprint_in_signal = 1;
+	dprintf1(">>>>===============SIGSEGV============================\n");
+	dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
+			__func__, __LINE__,
+			__read_pkey_reg(), shadow_pkey_reg);
+
+	trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
+	ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
+	fpregs = (char *) uctxt->uc_mcontext.fpregs;
+
+	dprintf2("%s() trapno: %d ip: 0x%016lx info->si_code: %s/%d\n",
+			__func__, trapno, ip, si_code_str(si->si_code),
+			si->si_code);
+
+#if defined(__i386__) || defined(__x86_64__) /* arch */
+#ifdef __i386__
+	/*
+	 * 32-bit has some extra padding so that userspace can tell whether
+	 * the XSTATE header is present in addition to the "legacy" FPU
+	 * state.  We just assume that it is here.
+	 */
+	fpregs += 0x70;
+#endif /* i386 */
+	pkey_reg_offset = pkey_reg_xstate_offset();
+	pkey_reg_ptr = (void *)(&fpregs[pkey_reg_offset]);
+
+	/*
+	 * If we got a PKEY fault, we *HAVE* to have at least one bit set in
+	 * here.
+	 */
+	/* reuse the offset computed above rather than querying it again */
+	dprintf1("pkey_reg_xstate_offset: %d\n", pkey_reg_offset);
+	/*
+	 * Byte arithmetic, so that the 256 bytes dumped are the 128 bytes
+	 * on either side of the register image.  Arithmetic on the u32
+	 * pointer would start 512 bytes earlier and miss the register.
+	 */
+	if (DEBUG_LEVEL > 4)
+		dump_mem((u8 *)pkey_reg_ptr - 128, 256);
+	pkey_assert(*pkey_reg_ptr);
+#endif /* arch */
+
+	dprintf1("siginfo: %p\n", si);
+	dprintf1(" fpregs: %p\n", fpregs);
+
+	/* anything that is not a pkey fault is unexpected: give up */
+	if ((si->si_code == SEGV_MAPERR) ||
+	    (si->si_code == SEGV_ACCERR) ||
+	    (si->si_code == SEGV_BNDERR)) {
+		printf("non-PK si_code, exiting...\n");
+		exit(4);
+	}
+
+	si_pkey_ptr = siginfo_get_pkey_ptr(si);
+	dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
+	dump_mem((u8 *)si_pkey_ptr - 8, 24);
+	siginfo_pkey = *si_pkey_ptr;
+	pkey_assert(siginfo_pkey < NR_PKEYS);
+	/* consumed (and reset) by expected_pkey_fault() */
+	last_si_pkey = siginfo_pkey;
+
+	/*
+	 * need __read_pkey_reg() version so we do not do shadow_pkey_reg
+	 * checking
+	 */
+	dprintf1("signal pkey_reg from  pkey_reg: %016llx\n",
+			__read_pkey_reg());
+	dprintf1("pkey from siginfo: %016llx\n", siginfo_pkey);
+#if defined(__i386__) || defined(__x86_64__) /* arch */
+	dprintf1("signal pkey_reg from xsave: %08x\n", *pkey_reg_ptr);
+	*(u64 *)pkey_reg_ptr = 0x00000000;
+	dprintf1("WARNING: set PKEY_REG=0 to allow faulting instruction to continue\n");
+#elif defined(__powerpc64__) /* arch */
+	/* restore access and let the faulting instruction continue */
+	pkey_access_allow(siginfo_pkey);
+#endif /* arch */
+	pkey_faults++;
+	dprintf1("<<<<==================================================\n");
+	dprint_in_signal = 0;
+}
+
+/*
+ * Issue a single blocking waitpid(-1, ...) and return its result.  The
+ * child's exit status is collected but not reported to the caller.
+ */
+int wait_all_children(void)
+{
+	int status;
+	pid_t reaped = waitpid(-1, &status, 0);
+
+	return reaped;
+}
+
+/*
+ * SIGCHLD handler (installed by setup_handlers()): only logs the pid and
+ * the signal number.  dprint_in_signal is raised around the debug print.
+ */
+void sig_chld(int x)
+{
+	dprint_in_signal = 1;
+	dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
+	dprint_in_signal = 0;
+}
+
+/*
+ * Install signal_handler() as the SA_SIGINFO handler for both SIGSEGV and
+ * SIGALRM.  The handler's blocked-signal mask is the process' current
+ * signal mask.  Any failure trips pkey_assert().
+ */
+void setup_sigsegv_handler(void)
+{
+	int r, rs;
+	struct sigaction newact;
+	struct sigaction oldact;
+
+	/* #PF is mapped to sigsegv */
+	int signum  = SIGSEGV;
+
+	/* start from a fully defined structure, whatever fields it has */
+	memset(&newact, 0, sizeof(newact));
+	newact.sa_sigaction = signal_handler;
+
+	/* sigset_t - signals to block while in the handler */
+	/* get the old signal mask. */
+	rs = sigprocmask(SIG_SETMASK, NULL, &newact.sa_mask);
+	pkey_assert(rs == 0);
+
+	/* call sa_sigaction, not sa_handler */
+	newact.sa_flags = SA_SIGINFO;
+
+	/* check each installation: a later call must not hide a failure */
+	r = sigaction(signum, &newact, &oldact);
+	pkey_assert(r == 0);
+	r = sigaction(SIGALRM, &newact, &oldact);
+	pkey_assert(r == 0);
+}
+
+/* Install sig_chld() for SIGCHLD, then the handlers set up by setup_sigsegv_handler(). */
+void setup_handlers(void)
+{
+	signal(SIGCHLD, &sig_chld);
+	setup_sigsegv_handler();
+}
+
+/*
+ * Fork a child that does nothing but sleep in 30 second steps forever.
+ * Only the parent returns from this function; it gets the child's pid.
+ * A failed fork() trips pkey_assert().
+ */
+pid_t fork_lazy_child(void)
+{
+	pid_t pid = fork();
+
+	pkey_assert(pid >= 0);
+	dprintf3("[%d] fork() ret: %d\n", getpid(), pid);
+
+	if (pid)
+		return pid;
+
+	/* in the child */
+	for (;;) {
+		dprintf1("child sleeping...\n");
+		sleep(30);
+	}
+}
+
+/*
+ * Thin wrapper around the SYS_mprotect_key system call.  Returns whatever
+ * syscall() returned.  errno is cleared before the call; when it is set
+ * afterwards, the failure is logged at debug level 2.
+ */
+int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
+		unsigned long pkey)
+{
+	int sret;
+
+	dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
+			ptr, size, orig_prot, pkey);
+
+	errno = 0;
+	sret = syscall(SYS_mprotect_key, ptr, size, orig_prot, pkey);
+	if (!errno)
+		return sret;
+
+	dprintf2("SYS_mprotect_key sret: %d\n", sret);
+	dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
+	dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
+	if (DEBUG_LEVEL >= 2)
+		perror("SYS_mprotect_pkey");
+	return sret;
+}
+
+/*
+ * Thin wrapper around the SYS_pkey_alloc system call; returns the value
+ * syscall() returned.  Note that the errno logged here is whatever errno
+ * currently holds: it is not cleared before the call.
+ */
+int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
+{
+	int ret = syscall(SYS_pkey_alloc, flags, init_val);
+	dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
+			__func__, flags, init_val, ret, errno);
+	return ret;
+}
+
+/*
+ * Allocate a pkey with sys_pkey_alloc(0, 0) and mirror the new key's
+ * rights into shadow_pkey_reg.  Returns the value sys_pkey_alloc()
+ * returned; the shadow register is only touched when that value is > 0.
+ */
+int alloc_pkey(void)
+{
+	int ret;
+	unsigned long init_val = 0x0;
+
+	dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
+			__func__, __LINE__, __read_pkey_reg(), shadow_pkey_reg);
+	ret = sys_pkey_alloc(0, init_val);
+	/*
+	 * pkey_alloc() sets PKEY register, so we need to reflect it in
+	 * shadow_pkey_reg:
+	 */
+	dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			" shadow: 0x%016llx\n",
+			__func__, __LINE__, ret, __read_pkey_reg(),
+			shadow_pkey_reg);
+	if (ret > 0) {
+		/* clear both the bits: */
+		shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
+						~PKEY_MASK);
+		dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+				" shadow: 0x%016llx\n",
+				__func__,
+				__LINE__, ret, __read_pkey_reg(),
+				shadow_pkey_reg);
+		/*
+		 * move the new state in from init_val
+		 * (remember, we cheated and init_val == pkey_reg format)
+		 */
+		shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
+						init_val);
+	}
+	dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			" shadow: 0x%016llx\n",
+			__func__, __LINE__, ret, __read_pkey_reg(),
+			shadow_pkey_reg);
+	dprintf1("%s()::%d errno: %d\n", __func__, __LINE__, errno);
+	/* for shadow checking: */
+	read_pkey_reg();
+	dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+		 " shadow: 0x%016llx\n",
+		__func__, __LINE__, ret, __read_pkey_reg(),
+		shadow_pkey_reg);
+	return ret;
+}
+
+/* Thin wrapper around the SYS_pkey_free system call; logs and returns its result. */
+int sys_pkey_free(unsigned long pkey)
+{
+	int ret = syscall(SYS_pkey_free, pkey);
+	dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
+	return ret;
+}
+
+/*
+ * I had a bug where pkey bits could be set by mprotect() but
+ * not cleared.  This ensures we get lots of random bit sets
+ * and clears on the vma and pte pkey bits.
+ *
+ * Allocates pkeys until alloc_pkey() fails (at most NR_PKEYS attempts),
+ * keeps one of them picked with rand(), frees all the others again and
+ * returns the one that was kept.
+ */
+int alloc_random_pkey(void)
+{
+	int alloced_pkeys[NR_PKEYS];
+	int nr_alloced = 0;
+	int chosen_idx;
+	int chosen;
+	int i;
+
+	memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
+
+	/* allocate every possible key and make a note of which ones we got */
+	for (i = 0; i < NR_PKEYS; i++) {
+		int new_pkey = alloc_pkey();
+
+		if (new_pkey < 0)
+			break;
+		alloced_pkeys[nr_alloced] = new_pkey;
+		nr_alloced++;
+	}
+
+	pkey_assert(nr_alloced > 0);
+	/* select a random one out of the allocated ones */
+	chosen_idx = rand() % nr_alloced;
+	chosen = alloced_pkeys[chosen_idx];
+	/* a zeroed entry is skipped by the free loop, so this one survives */
+	alloced_pkeys[chosen_idx] = 0;
+
+	/* go through the allocated ones that we did not want and free them */
+	for (i = 0; i < nr_alloced; i++) {
+		int free_ret;
+
+		if (alloced_pkeys[i] == 0)
+			continue;
+		free_ret = sys_pkey_free(alloced_pkeys[i]);
+		pkey_assert(!free_ret);
+	}
+	dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			 " shadow: 0x%016llx\n", __func__,
+			__LINE__, chosen, __read_pkey_reg(), shadow_pkey_reg);
+	return chosen;
+}
+
+/*
+ * Apply 'orig_prot' and 'pkey' to [ptr, ptr + size) via sys_mprotect_pkey().
+ * Both an out-of-range 'pkey' (>= NR_PKEYS) and a failing system call trip
+ * pkey_assert(), so a normal return is always 0.
+ */
+int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
+		unsigned long pkey)
+{
+	int nr_iterations = random() % 100;
+	int ret;
+
+	/*
+	 * Disabled stress loop: with "while (0)" the body never executes.
+	 * It would repeatedly allocate a random pkey, redo the mprotect
+	 * and free the key again, 'nr_iterations' times.
+	 */
+	while (0) {
+		int rpkey = alloc_random_pkey();
+		ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
+		dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
+				ptr, size, orig_prot, pkey, ret);
+		if (nr_iterations-- < 0)
+			break;
+
+		dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			" shadow: 0x%016llx\n",
+			__func__, __LINE__, ret, __read_pkey_reg(),
+			shadow_pkey_reg);
+		sys_pkey_free(rpkey);
+		dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			" shadow: 0x%016llx\n",
+			__func__, __LINE__, ret, __read_pkey_reg(),
+			shadow_pkey_reg);
+	}
+	pkey_assert(pkey < NR_PKEYS);
+
+	/* the one call that actually does the work */
+	ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
+	dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
+			ptr, size, orig_prot, pkey, ret);
+	pkey_assert(!ret);
+	dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
+			" shadow: 0x%016llx\n", __func__,
+			__LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
+	return ret;
+}
+
+/*
+ * Bookkeeping for the mappings handed out by the malloc_pkey_*() helpers,
+ * so that free_pkey_malloc() can find and unmap them again.
+ */
+struct pkey_malloc_record {
+	void *ptr;	/* start of the mapping; NULL marks a free slot */
+	long size;	/* length of the mapping in bytes */
+	int prot;	/* protection the mapping was requested with */
+};
+struct pkey_malloc_record *pkey_malloc_records;
+struct pkey_malloc_record *pkey_last_malloc_record;
+/* number of slots in pkey_malloc_records[], whether in use or free */
+long nr_pkey_malloc_records;
+
+/*
+ * Remember the mapping {ptr, size, prot} in a free slot of
+ * pkey_malloc_records[], growing the array when every slot is in use, and
+ * point pkey_last_malloc_record at the slot that was filled.
+ */
+void record_pkey_malloc(void *ptr, long size, int prot)
+{
+	long i;
+	struct pkey_malloc_record *rec = NULL;
+
+	/* find a free record: one whose ->ptr is NULL */
+	for (i = 0; i < nr_pkey_malloc_records; i++) {
+		if (!pkey_malloc_records[i].ptr) {
+			rec = &pkey_malloc_records[i];
+			break;
+		}
+	}
+	if (!rec) {
+		/* every record is full */
+		size_t old_nr_records = nr_pkey_malloc_records;
+		size_t new_nr_records = (old_nr_records * 2 + 1);
+		size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
+
+		dprintf2("new_nr_records: %zd\n", new_nr_records);
+		dprintf2("new_size: %zd\n", new_size);
+		pkey_malloc_records = realloc(pkey_malloc_records, new_size);
+		pkey_assert(pkey_malloc_records != NULL);
+		rec = &pkey_malloc_records[old_nr_records];
+		/*
+		 * realloc() does not initialize memory, so zero it from
+		 * the first new record all the way to the end.
+		 */
+		memset(rec, 0, (new_nr_records - old_nr_records) * sizeof(*rec));
+		nr_pkey_malloc_records = new_nr_records;
+	}
+	dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
+		(int)(rec - pkey_malloc_records), rec, ptr, size);
+	rec->ptr = ptr;
+	rec->size = size;
+	rec->prot = prot;
+	pkey_last_malloc_record = rec;
+}
+
+/*
+ * Unmap the recorded mapping that contains 'ptr' and mark its slot free.
+ * Trips pkey_assert() when no recorded mapping contains 'ptr' or when
+ * munmap() fails.
+ */
+void free_pkey_malloc(void *ptr)
+{
+	long i;
+	int ret;
+	dprintf3("%s(%p)\n", __func__, ptr);
+	for (i = 0; i < nr_pkey_malloc_records; i++) {
+		struct pkey_malloc_record *rec = &pkey_malloc_records[i];
+		dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
+				ptr, i, rec, rec->ptr, rec->size);
+		/* free slots hold no mapping (their size may be stale) */
+		if (!rec->ptr)
+			continue;
+		if ((ptr <  rec->ptr) ||
+		    (ptr >= rec->ptr + rec->size))
+			continue;
+
+		dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
+				ptr, i, rec, rec->ptr, rec->size);
+		ret = munmap(rec->ptr, rec->size);
+		dprintf3("munmap ret: %d\n", ret);
+		pkey_assert(!ret);
+		/* the slot stays in the array and becomes reusable */
+		dprintf3("clearing rec->ptr, rec: %p\n", rec);
+		rec->ptr = NULL;
+		dprintf3("done clearing rec->ptr, rec: %p\n", rec);
+		return;
+	}
+	pkey_assert(false);
+}
+
+
+/*
+ * Map 'size' bytes of private anonymous memory with protection 'prot',
+ * attach 'pkey' to its first PAGE_SIZE bytes and record the mapping.
+ * Returns the start of the mapping; failures trip pkey_assert().
+ *
+ * NOTE(review): only PAGE_SIZE bytes get the pkey although 'size' bytes
+ * are mapped - confirm callers never ask for more than one page.
+ */
+void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
+{
+	void *mapping;
+	int err;
+
+	read_pkey_reg();
+	dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
+			size, prot, pkey);
+	pkey_assert(pkey < NR_PKEYS);
+
+	mapping = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
+	pkey_assert(mapping != (void *)-1);
+
+	err = mprotect_pkey((void *)mapping, PAGE_SIZE, prot, pkey);
+	pkey_assert(!err);
+
+	record_pkey_malloc(mapping, size, prot);
+	read_pkey_reg();
+
+	dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, mapping);
+	return mapping;
+}
+
+/*
+ * Map anonymous memory tagged with 'pkey' and ask for a transparent huge
+ * page in it.  'size' is rounded up to a multiple of two huge pages.  The
+ * returned pointer is the first HPAGE_SIZE-aligned address inside the
+ * mapping; its first HPAGE_SIZE bytes have been zeroed.
+ */
+void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
+{
+	int ret;
+	void *ptr;
+
+	dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
+			size, prot, pkey);
+	/*
+	 * Guarantee we can fit at least one huge page in the resulting
+	 * allocation by allocating space for 2:
+	 */
+	size = ALIGN_UP(size, HPAGE_SIZE * 2);
+	ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
+	pkey_assert(ptr != (void *)-1);
+	/* the record keeps the unaligned start of the mapping */
+	record_pkey_malloc(ptr, size, prot);
+	mprotect_pkey(ptr, size, prot, pkey);
+
+	dprintf1("unaligned ptr: %p\n", ptr);
+	ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
+	dprintf1("  aligned ptr: %p\n", ptr);
+	/* the madvise() results are only logged, not checked */
+	ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
+	dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
+	ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
+	dprintf1("MADV_WILLNEED ret: %d\n", ret);
+	memset(ptr, 0, HPAGE_SIZE);
+
+	dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
+	return ptr;
+}
+
+int hugetlb_setup_ok;
+#define SYSFS_FMT_NR_HUGE_PAGES "/sys/kernel/mm/hugepages/hugepages-%ldkB/nr_hugepages"
+#define GET_NR_HUGE_PAGES 10
+/*
+ * Ask the kernel for GET_NR_HUGE_PAGES huge pages and set
+ * hugetlb_setup_ok to 1 once sysfs confirms that this many pages of size
+ * HPAGE_SIZE exist.  Requires root; on any problem a message is printed to
+ * stderr and hugetlb_setup_ok is left untouched.
+ */
+void setup_hugetlbfs(void)
+{
+	int err;
+	int fd;
+	char buf[256];
+	long hpagesz_kb;
+	long hpagesz_mb;
+
+	if (geteuid() != 0) {
+		fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
+		return;
+	}
+
+	cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
+
+	/*
+	 * Now go make sure that we got the pages and that they
+	 * are PMD-level pages. Someone might have made PUD-level
+	 * pages the default.
+	 */
+	hpagesz_kb = HPAGE_SIZE / 1024;
+	hpagesz_mb = hpagesz_kb / 1024;
+	sprintf(buf, SYSFS_FMT_NR_HUGE_PAGES, hpagesz_kb);
+	fd = open(buf, O_RDONLY);
+	if (fd < 0) {
+		fprintf(stderr, "opening sysfs %ldM hugetlb config: %s\n",
+			hpagesz_mb, strerror(errno));
+		return;
+	}
+
+	/* -1 to guarantee leaving room for the trailing \0 */
+	err = read(fd, buf, sizeof(buf)-1);
+	close(fd);
+	if (err <= 0) {
+		fprintf(stderr, "reading sysfs %ldM hugetlb config: %s\n",
+			hpagesz_mb, strerror(errno));
+		return;
+	}
+	/*
+	 * read() does not terminate the data, and 'buf' still holds the
+	 * rest of the sysfs path behind it: cut the string off here.
+	 */
+	buf[err] = '\0';
+
+	if (atoi(buf) != GET_NR_HUGE_PAGES) {
+		fprintf(stderr, "could not confirm %ldM pages, got: '%s' expected %d\n",
+			hpagesz_mb, buf, GET_NR_HUGE_PAGES);
+		return;
+	}
+
+	hugetlb_setup_ok = 1;
+}
+
+/*
+ * Map MAP_HUGETLB memory tagged with 'pkey' and record it.  'size' is
+ * rounded up to a multiple of two huge pages.  Returns PTR_ERR_ENOTSUP
+ * when setup_hugetlbfs() did not succeed.
+ */
+void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
+{
+	const int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
+	void *mapping;
+
+	if (hugetlb_setup_ok == 0)
+		return PTR_ERR_ENOTSUP;
+
+	dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
+	size = ALIGN_UP(size, HPAGE_SIZE * 2);
+	pkey_assert(pkey < NR_PKEYS);
+
+	/* map without any access first, then apply prot and pkey together */
+	mapping = mmap(NULL, size, PROT_NONE, flags, -1, 0);
+	pkey_assert(mapping != (void *)-1);
+	mprotect_pkey(mapping, size, prot, pkey);
+
+	record_pkey_malloc(mapping, size, prot);
+
+	dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, mapping);
+	return mapping;
+}
+
+/*
+ * Map 'size' bytes of the file "/dax/foo" MAP_SHARED, tag the mapping with
+ * 'pkey' and record it.  Failing to open or map the file trips
+ * pkey_assert().  Not part of the pkey_malloc[] table.
+ */
+void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
+{
+	void *ptr;
+	int fd;
+
+	dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
+			size, prot, pkey);
+	pkey_assert(pkey < NR_PKEYS);
+	fd = open("/dax/foo", O_RDWR);
+	pkey_assert(fd >= 0);
+
+	ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
+	pkey_assert(ptr != (void *)-1);
+
+	mprotect_pkey(ptr, size, prot, pkey);
+
+	record_pkey_malloc(ptr, size, prot);
+
+	dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
+	/* the descriptor is closed here; the mapping itself is returned */
+	close(fd);
+	return ptr;
+}
+
+/*
+ * The allocators malloc_pkey() chooses from.  An entry may return
+ * PTR_ERR_ENOTSUP, in which case malloc_pkey() tries another one.
+ */
+void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
+
+	malloc_pkey_with_mprotect,
+	malloc_pkey_with_mprotect_subpage,
+	malloc_pkey_anon_huge,
+	malloc_pkey_hugetlb
+/* can not do direct with the pkey_mprotect() API:
+	malloc_pkey_mmap_direct,
+	malloc_pkey_mmap_dax,
+*/
+};
+
+/*
+ * Allocate memory tagged with 'pkey' using one of the pkey_malloc[]
+ * allocators.  Successive calls step through the table; once its end is
+ * reached, the next allocator is picked with random().  Allocators that
+ * return PTR_ERR_ENOTSUP are skipped.
+ */
+void *malloc_pkey(long size, int prot, u16 pkey)
+{
+	/* table index to use next; persists across calls */
+	static int malloc_type;
+	int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
+	void *ret;
+
+	pkey_assert(pkey < NR_PKEYS);
+
+	do {
+		pkey_assert(malloc_type < nr_malloc_types);
+
+		ret = pkey_malloc[malloc_type](size, prot, pkey);
+		pkey_assert(ret != (void *)-1);
+
+		malloc_type++;
+		if (malloc_type >= nr_malloc_types)
+			malloc_type = (random()%nr_malloc_types);
+
+		/* try again if the malloc_type we tried is unsupported */
+	} while (ret == PTR_ERR_ENOTSUP);
+
+	dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
+			size, prot, pkey, ret);
+	return ret;
+}
+
+int last_pkey_faults;
+#define UNKNOWN_PKEY -2
+/*
+ * Verify that exactly one pkey fault was taken since the last check and,
+ * unless 'pkey' is UNKNOWN_PKEY, that the signal handler saw it for 'pkey'.
+ * Afterwards the pkey register is reloaded from shadow_pkey_reg and the
+ * fault bookkeeping (last_pkey_faults, last_si_pkey) is reset.
+ */
+void expected_pkey_fault(int pkey)
+{
+	dprintf2("%s(): last_pkey_faults: %d pkey_faults: %d\n",
+			__func__, last_pkey_faults, pkey_faults);
+	dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
+	pkey_assert(last_pkey_faults + 1 == pkey_faults);
+
+	/*
+	 * For exec-only memory, we do not know the pkey in
+	 * advance, so skip this check.
+	 */
+	if (pkey != UNKNOWN_PKEY)
+		pkey_assert(last_si_pkey == pkey);
+
+#if defined(__i386__) || defined(__x86_64__) /* arch */
+	/*
+	 * The signal handler should have cleared out PKEY register to let the
+	 * test program continue.  We now have to restore it.
+	 */
+	if (__read_pkey_reg() != 0)
+#else /* arch */
+	if (__read_pkey_reg() != shadow_pkey_reg)
+#endif /* arch */
+		pkey_assert(0);
+
+	__write_pkey_reg(shadow_pkey_reg);
+	dprintf1("%s() set pkey_reg=%016llx to restore state after signal "
+		       "nuked it\n", __func__, shadow_pkey_reg);
+	last_pkey_faults = pkey_faults;
+	last_si_pkey = -1;
+}
+
+/*
+ * Assert that no pkey fault was taken since the last check; 'msg' is
+ * printed first when one was.
+ */
+#define do_not_expect_pkey_fault(msg)	do {			\
+	if (last_pkey_faults != pkey_faults)			\
+		dprintf0("unexpected PKey fault: %s\n", msg);	\
+	pkey_assert(last_pkey_faults == pkey_faults);		\
+} while (0)
+
+/* Descriptors opened by tests; only the first nr_test_fds entries are in use. */
+int test_fds[10] = { -1 };
+int nr_test_fds;
+/*
+ * Append 'fd' to test_fds[] so that close_test_fds() closes it later.
+ * Trips pkey_assert() for a negative 'fd' or when the table is full.
+ */
+void __save_test_fd(int fd)
+{
+	pkey_assert(fd >= 0);
+	pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
+	test_fds[nr_test_fds] = fd;
+	nr_test_fds++;
+}
+
+/*
+ * Open "/etc/passwd" read-only, register the descriptor with
+ * __save_test_fd() (which asserts that the open succeeded) and return it.
+ */
+int get_test_read_fd(void)
+{
+	int test_fd = open("/etc/passwd", O_RDONLY);
+	__save_test_fd(test_fd);
+	return test_fd;
+}
+
+/*
+ * Close every descriptor registered in test_fds[], mark the used slots
+ * with -1 and reset nr_test_fds to 0.
+ */
+void close_test_fds(void)
+{
+	int i = 0;
+
+	while (i < nr_test_fds) {
+		int fd = test_fds[i];
+
+		if (fd >= 0) {
+			close(fd);
+			test_fds[i] = -1;
+		}
+		i++;
+	}
+	nr_test_fds = 0;
+}
+
+#define barrier() __asm__ __volatile__("": : :"memory")
+/*
+ * Return the int stored at 'ptr'.  Kept out of line, with a compiler
+ * barrier in front of the load.
+ */
+__attribute__((noinline)) int read_ptr(int *ptr)
+{
+	int val;
+
+	/*
+	 * Keep GCC from optimizing this away somehow
+	 */
+	barrier();
+	val = *ptr;
+	return val;
+}
+
+/*
+ * Check that pkey 0 is never handed out by alloc_pkey(), and that pkey 0
+ * can be attached to the most recently recorded allocation with a range of
+ * protections.  'pkey' is not used.
+ */
+void test_pkey_alloc_free_attach_pkey0(int *ptr, u16 pkey)
+{
+	/* the protections pkey 0 gets attached with, in this order */
+	static const int prots[] = {
+		PROT_READ,
+		PROT_WRITE,
+		PROT_EXEC,
+		PROT_READ|PROT_WRITE,
+		PROT_READ|PROT_WRITE|PROT_EXEC,
+	};
+	const int nr_prots = sizeof(prots) / sizeof(prots[0]);
+	int alloced_pkeys[NR_PKEYS];
+	int nr_alloced = 0;
+	int i, err;
+	long size;
+
+	pkey_assert(pkey_last_malloc_record);
+	size = pkey_last_malloc_record->size;
+	/*
+	 * This is a bit of a hack.  But mprotect() requires
+	 * huge-page-aligned sizes when operating on hugetlbfs.
+	 * So, make sure that we use something that's a multiple
+	 * of a huge page when we can.
+	 */
+	if (size >= HPAGE_SIZE)
+		size = HPAGE_SIZE;
+
+	/* allocate every possible key and make sure key-0 never got allocated */
+	for (i = 0; i < NR_PKEYS; i++) {
+		int new_pkey = alloc_pkey();
+
+		pkey_assert(new_pkey != 0);
+		if (new_pkey < 0)
+			break;
+		alloced_pkeys[nr_alloced] = new_pkey;
+		nr_alloced++;
+	}
+	/* free all the allocated keys */
+	for (i = 0; i < nr_alloced; i++) {
+		int free_ret;
+
+		if (alloced_pkeys[i] == 0)
+			continue;
+		free_ret = sys_pkey_free(alloced_pkeys[i]);
+		pkey_assert(!free_ret);
+	}
+
+	/* attach key-0 in various modes */
+	for (i = 0; i < nr_prots; i++) {
+		err = sys_mprotect_pkey(ptr, size, prots[i], 0);
+		pkey_assert(!err);
+	}
+}
+
+/*
+ * Deny writes through 'pkey' and read from 'ptr'.  The function does not
+ * call expected_pkey_fault(), i.e. the read is not expected to fault.
+ */
+void test_read_of_write_disabled_region(int *ptr, u16 pkey)
+{
+	int ptr_contents;
+
+	/* log the key actually under test, like the sibling tests do */
+	dprintf1("disabling write access to PKEY[%02d], doing read\n", pkey);
+	pkey_write_deny(pkey);
+	ptr_contents = read_ptr(ptr);
+	dprintf1("*ptr: %d\n", ptr_contents);
+	dprintf1("\n");
+}
+/*
+ * Deny all access through 'pkey', read from 'ptr' and require that the
+ * read raised a pkey fault for 'pkey'.
+ */
+void test_read_of_access_disabled_region(int *ptr, u16 pkey)
+{
+	int ptr_contents;
+
+	dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
+	read_pkey_reg();
+	pkey_access_deny(pkey);
+	/* the access that is expected to fault */
+	ptr_contents = read_ptr(ptr);
+	dprintf1("*ptr: %d\n", ptr_contents);
+	expected_pkey_fault(pkey);
+}
+
+/*
+ * Like test_read_of_access_disabled_region(), except that 'ptr' is read
+ * once before access is denied, so the page has already been touched when
+ * the faulting read happens.
+ */
+void test_read_of_access_disabled_region_with_page_already_mapped(int *ptr,
+		u16 pkey)
+{
+	int ptr_contents;
+
+	dprintf1("disabling access to PKEY[%02d], doing read @ %p\n",
+				pkey, ptr);
+	/* access is still allowed here: this read must succeed */
+	ptr_contents = read_ptr(ptr);
+	dprintf1("reading ptr before disabling the read : %d\n",
+			ptr_contents);
+	read_pkey_reg();
+	pkey_access_deny(pkey);
+	/* the access that is expected to fault */
+	ptr_contents = read_ptr(ptr);
+	dprintf1("*ptr: %d\n", ptr_contents);
+	expected_pkey_fault(pkey);
+}
+
+/*
+ * Write to 'ptr' once while that is still allowed, then deny writes
+ * through 'pkey', write again and require a pkey fault for 'pkey'.
+ */
+void test_write_of_write_disabled_region_with_page_already_mapped(int *ptr,
+		u16 pkey)
+{
+	/* the value written is arbitrary; this store must succeed */
+	*ptr = __LINE__;
+	dprintf1("disabling write access; after accessing the page, "
+		"to PKEY[%02d], doing write\n", pkey);
+	pkey_write_deny(pkey);
+	/* the access that is expected to fault */
+	*ptr = __LINE__;
+	expected_pkey_fault(pkey);
+}
+
+/* Deny writes through 'pkey', write to 'ptr' and require a pkey fault for 'pkey'. */
+void test_write_of_write_disabled_region(int *ptr, u16 pkey)
+{
+	dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
+	pkey_write_deny(pkey);
+	/* the access that is expected to fault */
+	*ptr = __LINE__;
+	expected_pkey_fault(pkey);
+}
+/* Deny all access through 'pkey', write to 'ptr' and require a pkey fault for 'pkey'. */
+void test_write_of_access_disabled_region(int *ptr, u16 pkey)
+{
+	dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
+	pkey_access_deny(pkey);
+	/* the access that is expected to fault */
+	*ptr = __LINE__;
+	expected_pkey_fault(pkey);
+}
+
+/*
+ * Write to 'ptr' once while that is still allowed, then deny all access
+ * through 'pkey', write again and require a pkey fault for 'pkey'.
+ */
+void test_write_of_access_disabled_region_with_page_already_mapped(int *ptr,
+			u16 pkey)
+{
+	/* the value written is arbitrary; this store must succeed */
+	*ptr = __LINE__;
+	dprintf1("disabling access; after accessing the page, "
+		" to PKEY[%02d], doing write\n", pkey);
+	pkey_access_deny(pkey);
+	/* the access that is expected to fault */
+	*ptr = __LINE__;
+	expected_pkey_fault(pkey);
+}
+
+/*
+ * Deny all access through 'pkey', then have the kernel store into 'ptr' by
+ * read()ing one byte from the test file into it.
+ *
+ * NOTE(review): pkey_assert(ret) accepts any non-zero result, so a
+ * successful 1-byte read passes as well as a -1 failure - confirm that
+ * this is intended.
+ */
+void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
+{
+	int ret;
+	int test_fd = get_test_read_fd();
+
+	dprintf1("disabling access to PKEY[%02d], "
+		 "having kernel read() to buffer\n", pkey);
+	pkey_access_deny(pkey);
+	ret = read(test_fd, ptr, 1);
+	dprintf1("read ret: %d\n", ret);
+	pkey_assert(ret);
+}
+/*
+ * Deny writes through 'pkey', then have the kernel store into 'ptr' by
+ * read()ing up to 100 bytes from the test file into it.
+ *
+ * NOTE(review): as in the access-disabled variant, pkey_assert(ret) only
+ * rejects a return value of 0.
+ */
+void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
+{
+	int ret;
+	int test_fd = get_test_read_fd();
+
+	pkey_write_deny(pkey);
+	ret = read(test_fd, ptr, 100);
+	dprintf1("read ret: %d\n", ret);
+	if (ret < 0 && (DEBUG_LEVEL > 0))
+		perror("verbose read result (OK for this to be bad)");
+	pkey_assert(ret);
+}
+
+/*
+ * Deny all access through 'pkey', then vmsplice() one page starting at
+ * 'ptr' into a pipe and require that the call fails with -1.
+ */
+void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
+{
+	int pipe_ret, vmsplice_ret;
+	struct iovec iov;
+	int pipe_fds[2];
+
+	pipe_ret = pipe(pipe_fds);
+
+	pkey_assert(pipe_ret == 0);
+	dprintf1("disabling access to PKEY[%02d], "
+		 "having kernel vmsplice from buffer\n", pkey);
+	pkey_access_deny(pkey);
+	iov.iov_base = ptr;
+	iov.iov_len = PAGE_SIZE;
+	vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
+	dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
+	pkey_assert(vmsplice_ret == -1);
+
+	close(pipe_fds[0]);
+	close(pipe_fds[1]);
+}
+
+/*
+ * Store a value at 'ptr', deny writes through 'pkey', then issue a
+ * FUTEX_WAIT on 'ptr' with an expected value that differs from the stored
+ * one.  The futex result is only logged; nothing is asserted about it.
+ */
+void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
+{
+	int ignored = 0xdada;
+	int futex_ret;
+	int some_int = __LINE__;
+
+	dprintf1("disabling write to PKEY[%02d], "
+		 "doing futex gunk in buffer\n", pkey);
+	*ptr = some_int;
+	pkey_write_deny(pkey);
+	/* some_int-1 never matches *ptr, which holds some_int */
+	futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
+			&ignored, ignored);
+	if (DEBUG_LEVEL > 0)
+		perror("futex");
+	dprintf1("futex() ret: %d\n", futex_ret);
+}
+
+/*
+ * Assumes that all pkeys other than 'pkey' are unallocated.
+ *
+ * For every pkey from 1 to NR_PKEYS-1 other than 'pkey', require that
+ * sys_pkey_free() (tried twice) and sys_mprotect_pkey() on 'ptr' each
+ * return non-zero.
+ */
+void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
+{
+	int err;
+	int i;
+
+	/* Note: 0 is the default pkey, so don't mess with it */
+	for (i = 1; i < NR_PKEYS; i++) {
+		if (pkey == i)
+			continue;
+
+		dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
+		err = sys_pkey_free(i);
+		pkey_assert(err);
+
+		err = sys_pkey_free(i);
+		pkey_assert(err);
+
+		err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
+		pkey_assert(err);
+	}
+}
+
+/*
+ * Assumes that all pkeys other than 'pkey' are unallocated.
+ *
+ * Require that sys_mprotect_pkey() returns non-zero for a pkey number
+ * beyond NR_PKEYS.  'pkey' itself is not used.
+ */
+void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
+{
+	int err;
+	int bad_pkey = NR_PKEYS+99;
+
+	/* pass a known-invalid pkey in: */
+	err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
+	pkey_assert(err);
+}
+
+void become_child(void)
+{
+	pid_t forkret;
+
+	forkret = fork();
+	pkey_assert(forkret >= 0);
+	dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
+
+	if (!forkret) {
+		/* in the child */
+		return;
+	}
+	exit(0);
+}
+
+/* Allocate pkeys to exhaustion; assumes all pkeys but 'pkey' are unallocated */
+void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
+{
+	int err;
+	int allocated_pkeys[NR_PKEYS] = {0};
+	int nr_allocated_pkeys = 0;
+	int i;
+
+	for (i = 0; i < NR_PKEYS*3; i++) {
+		int new_pkey;
+		dprintf1("%s() alloc loop: %d\n", __func__, i);
+		new_pkey = alloc_pkey();
+		/* report the key just obtained ('err' is not assigned yet here) */
+		dprintf4("%s()::%d, new_pkey: %d pkey_reg: 0x%016llx shadow: 0x%016llx\n",
+				__func__, __LINE__, new_pkey, __read_pkey_reg(),
+				shadow_pkey_reg);
+		read_pkey_reg(); /* for shadow checking */
+		dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
+		if ((new_pkey == -1) && (errno == ENOSPC)) {
+			dprintf2("%s() failed to allocate pkey after %d tries\n",
+				__func__, nr_allocated_pkeys);
+		} else {
+			/*
+			 * Ensure the number of successes never
+			 * exceeds the number of keys supported
+			 * in the hardware.
+			 */
+			pkey_assert(nr_allocated_pkeys < NR_PKEYS);
+			allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
+		}
+
+		/*
+		 * Make sure that allocation state is properly
+		 * preserved across fork().
+		 */
+		if (i == NR_PKEYS*2)
+			become_child();
+	}
+
+	dprintf3("%s()::%d\n", __func__, __LINE__);
+
+	/*
+	 * On x86:
+	 * There are 16 pkeys supported in hardware.  Three are
+	 * allocated by the time we get here:
+	 *   1. The default key (0)
+	 *   2. One possibly consumed by an execute-only mapping.
+	 *   3. One allocated by the test code and passed in via
+	 *      'pkey' to this function.
+	 * Ensure that we can allocate at least another 13 (16-3).
+	 *
+	 * On powerpc:
+	 * There are either 5, 28, 29 or 32 pkeys supported in
+	 * hardware depending on the page size (4K or 64K) and
+	 * platform (powernv or powervm). Four are allocated by
+	 * the time we get here. These include pkey-0, pkey-1,
+	 * exec-only pkey and the one allocated by the test code.
+	 * Ensure that we can allocate the remaining.
+	 */
+	pkey_assert(i >= (NR_PKEYS - get_arch_reserved_keys() - 1));
+
+	for (i = 0; i < nr_allocated_pkeys; i++) {
+		err = sys_pkey_free(allocated_pkeys[i]);
+		pkey_assert(!err);
+		read_pkey_reg(); /* for shadow checking */
+	}
+}
+
+/*
+ * pkey 0 is special.  It is allocated by default, so you do not
+ * have to call pkey_alloc() to use it first.  Make sure that it
+ * is usable.
+ */
+void test_mprotect_with_pkey_0(int *ptr, u16 pkey)
+{
+	long size;
+	int prot;
+
+	assert(pkey_last_malloc_record);
+	size = pkey_last_malloc_record->size;
+	/*
+	 * This is a bit of a hack.  But mprotect() requires
+	 * huge-page-aligned sizes when operating on hugetlbfs.
+	 * So, make sure that we use something that's a multiple
+	 * of a huge page when we can.
+	 */
+	if (size >= HPAGE_SIZE)
+		size = HPAGE_SIZE;
+	prot = pkey_last_malloc_record->prot;
+
+	/* Use pkey 0 */
+	mprotect_pkey(ptr, size, prot, 0);
+
+	/* Make sure that we can set it back to the original pkey. */
+	mprotect_pkey(ptr, size, prot, pkey);
+}
+
+void test_ptrace_of_child(int *ptr, u16 pkey)
+{
+	__attribute__((__unused__)) int peek_result;
+	pid_t child_pid;
+	void *ignored = 0;
+	long ret;
+	int status;
+	/*
+	 * This is the "control" for our little experiment.  Make sure
+	 * we can always access it when ptracing.
+	 */
+	int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
+	int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
+
+	/*
+	 * Fork a child which is an exact copy of this process, of course.
+	 * That means we can do all of our tests via ptrace() and then plain
+	 * memory access and ensure they work differently.
+	 */
+	child_pid = fork_lazy_child();
+	dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
+
+	ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
+	if (ret)
+		perror("attach");
+	dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
+	pkey_assert(ret != -1);
+	ret = waitpid(child_pid, &status, WUNTRACED);
+	if ((ret != child_pid) || !(WIFSTOPPED(status))) {
+		fprintf(stderr, "weird waitpid result %ld stat %x\n",
+				ret, status);
+		pkey_assert(0);
+	}
+	dprintf2("waitpid ret: %ld\n", ret);
+	dprintf2("waitpid status: %d\n", status);
+
+	pkey_access_deny(pkey);
+	pkey_write_deny(pkey);
+
+	/* Write access, untested for now:
+	ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
+	pkey_assert(ret != -1);
+	dprintf1("poke at %p: %ld\n", peek_at, ret);
+	*/
+
+	/*
+	 * Try to access the pkey-protected "ptr" via ptrace:
+	 */
+	ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
+	/* expect it to work, without an error: */
+	pkey_assert(ret != -1);
+	/* Now access from the current task, and expect an exception: */
+	peek_result = read_ptr(ptr);
+	expected_pkey_fault(pkey);
+
+	/*
+	 * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
+	 */
+	ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
+	/* expect it to work, without an error: */
+	pkey_assert(ret != -1);
+	/* Now access from the current task, and expect NO exception: */
+	peek_result = read_ptr(plain_ptr);
+	do_not_expect_pkey_fault("read plain pointer after ptrace");
+
+	ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
+	pkey_assert(ret != -1);
+
+	ret = kill(child_pid, SIGKILL);
+	pkey_assert(ret != -1);
+
+	wait(&status);
+
+	free(plain_ptr_unaligned);
+}
+
+void *get_pointer_to_instructions(void)
+{
+	void *p1;
+
+	p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
+	dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
+	/* lots_o_noops_around_write should be page-aligned already */
+	assert(p1 == &lots_o_noops_around_write);
+
+	/* Point 'p1' at the *second* page of the function: */
+	p1 += PAGE_SIZE;
+
+	/*
+	 * Try to ensure we fault this in on next touch to ensure
+	 * we get an instruction fault as opposed to a data one
+	 */
+	madvise(p1, PAGE_SIZE, MADV_DONTNEED);
+
+	return p1;
+}
+
+void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
+{
+	void *p1;
+	int scratch;
+	int ptr_contents;
+	int ret;
+
+	p1 = get_pointer_to_instructions();
+	lots_o_noops_around_write(&scratch);
+	ptr_contents = read_ptr(p1);
+	dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
+
+	ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
+	pkey_assert(!ret);
+	pkey_access_deny(pkey);
+
+	dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
+
+	/*
+	 * Make sure this is an *instruction* fault
+	 */
+	madvise(p1, PAGE_SIZE, MADV_DONTNEED);
+	lots_o_noops_around_write(&scratch);
+	do_not_expect_pkey_fault("executing on PROT_EXEC memory");
+	expect_fault_on_read_execonly_key(p1, pkey);
+}
+
+void test_implicit_mprotect_exec_only_memory(int *ptr, u16 pkey)
+{
+	void *p1;
+	int scratch;
+	int ptr_contents;
+	int ret;
+
+	dprintf1("%s() start\n", __func__);
+
+	p1 = get_pointer_to_instructions();
+	lots_o_noops_around_write(&scratch);
+	ptr_contents = read_ptr(p1);
+	dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
+
+	/* Use a *normal* mprotect(), not mprotect_pkey(): */
+	ret = mprotect(p1, PAGE_SIZE, PROT_EXEC);
+	pkey_assert(!ret);
+
+	/*
+	 * Reset the shadow, assuming that the above mprotect()
+	 * correctly changed PKRU, but to an unknown value since
+	 * the actual allocated pkey is unknown.
+	 */
+	shadow_pkey_reg = __read_pkey_reg();
+
+	dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
+
+	/* Make sure this is an *instruction* fault */
+	madvise(p1, PAGE_SIZE, MADV_DONTNEED);
+	lots_o_noops_around_write(&scratch);
+	do_not_expect_pkey_fault("executing on PROT_EXEC memory");
+	expect_fault_on_read_execonly_key(p1, UNKNOWN_PKEY);
+
+	/*
+	 * Put the memory back to non-PROT_EXEC.  Should clear the
+	 * exec-only pkey off the VMA and allow it to be readable
+	 * again.  Go to PROT_NONE first to check for a kernel bug
+	 * that did not clear the pkey when doing PROT_NONE.
+	 */
+	ret = mprotect(p1, PAGE_SIZE, PROT_NONE);
+	pkey_assert(!ret);
+
+	ret = mprotect(p1, PAGE_SIZE, PROT_READ|PROT_EXEC);
+	pkey_assert(!ret);
+	ptr_contents = read_ptr(p1);
+	do_not_expect_pkey_fault("plain read on recently PROT_EXEC area");
+}
+
+/* SYS_mprotect_key must fail when the CPU has no pkeys; skipped otherwise */
+void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
+{
+	int size = PAGE_SIZE;
+	int sret;
+
+	if (cpu_has_pkeys()) {
+		dprintf1("SKIP: %s: CPU supports pkeys\n", __func__);
+		return;
+	}
+	sret = syscall(SYS_mprotect_key, ptr, size, PROT_READ, pkey);
+	pkey_assert(sret < 0);
+}
+
+void (*pkey_tests[])(int *ptr, u16 pkey) = {
+	test_read_of_write_disabled_region,
+	test_read_of_access_disabled_region,
+	test_read_of_access_disabled_region_with_page_already_mapped,
+	test_write_of_write_disabled_region,
+	test_write_of_write_disabled_region_with_page_already_mapped,
+	test_write_of_access_disabled_region,
+	test_write_of_access_disabled_region_with_page_already_mapped,
+	test_kernel_write_of_access_disabled_region,
+	test_kernel_write_of_write_disabled_region,
+	test_kernel_gup_of_access_disabled_region,
+	test_kernel_gup_write_to_write_disabled_region,
+	test_executing_on_unreadable_memory,
+	test_implicit_mprotect_exec_only_memory,
+	test_mprotect_with_pkey_0,
+	test_ptrace_of_child,
+	test_pkey_syscalls_on_non_allocated_pkey,
+	test_pkey_syscalls_bad_args,
+	test_pkey_alloc_exhaust,
+	test_pkey_alloc_free_attach_pkey0,
+};
+
+void run_tests_once(void)
+{
+	int *ptr;
+	int prot = PROT_READ|PROT_WRITE;
+
+	for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
+		int pkey;
+		int orig_pkey_faults = pkey_faults;
+
+		dprintf1("======================\n");
+		dprintf1("test %d preparing...\n", test_nr);
+
+		tracing_on();
+		pkey = alloc_random_pkey();
+		dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
+		ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
+		dprintf1("test %d starting...\n", test_nr);
+		pkey_tests[test_nr](ptr, pkey);
+		dprintf1("freeing test memory: %p\n", ptr);
+		free_pkey_malloc(ptr);
+		sys_pkey_free(pkey);
+
+		dprintf1("pkey_faults: %d\n", pkey_faults);
+		dprintf1("orig_pkey_faults: %d\n", orig_pkey_faults);
+
+		tracing_off();
+		close_test_fds();
+
+		printf("test %2d PASSED (iteration %d)\n", test_nr, iteration_nr);
+		dprintf1("======================\n\n");
+	}
+	iteration_nr++;
+}
+
+void pkey_setup_shadow(void)
+{
+	shadow_pkey_reg = __read_pkey_reg();
+}
+
+int main(void)
+{
+	int nr_iterations = 22;
+	int pkeys_supported = is_pkeys_supported();
+
+	srand((unsigned int)time(NULL));
+
+	setup_handlers();
+
+	printf("has pkeys: %d\n", pkeys_supported);
+
+	if (!pkeys_supported) {
+		int size = PAGE_SIZE;
+		int *ptr;
+
+		printf("running PKEY tests for unsupported CPU/OS\n");
+
+		ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
+		assert(ptr != (void *)-1);
+		test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
+		exit(0);
+	}
+
+	pkey_setup_shadow();
+	printf("startup pkey_reg: %016llx\n", read_pkey_reg());
+	setup_hugetlbfs();
+
+	while (nr_iterations-- > 0)
+		run_tests_once();
+
+	printf("done (all tests OK)\n");
+	return 0;
+}
diff --git a/tools/testing/selftests/vm/run_vmtests b/tools/testing/selftests/vm/run_vmtests
index 951c507..a3f4f30 100755
--- a/tools/testing/selftests/vm/run_vmtests
+++ b/tools/testing/selftests/vm/run_vmtests
@@ -58,6 +58,14 @@
 	exit 1
 fi
 
+# Filter 64-bit architectures: VADDR64=1 only if $ARCH is exactly one of them
+ARCH64STR="arm64 ia64 mips64 parisc64 ppc64 ppc64le riscv64 s390x sh64 sparc64 x86_64"
+if [ -z "$ARCH" ]; then
+  ARCH=`uname -m 2>/dev/null | sed -e 's/aarch64.*/arm64/'`
+fi
+VADDR64=0
+case " $ARCH64STR " in *" $ARCH "*) VADDR64=1 ;; esac
+
 mkdir $mnt
 mount -t hugetlbfs none $mnt
 
@@ -104,6 +112,39 @@
 echo "      https://github.com/libhugetlbfs/libhugetlbfs.git for"
 echo "      hugetlb regression testing."
 
+echo "---------------------------"
+echo "running map_fixed_noreplace"
+echo "---------------------------"
+./map_fixed_noreplace
+if [ $? -ne 0 ]; then
+	echo "[FAIL]"
+	exitcode=1
+else
+	echo "[PASS]"
+fi
+
+echo "--------------------------------------------"
+echo "running 'gup_benchmark -U' (normal/slow gup)"
+echo "--------------------------------------------"
+./gup_benchmark -U
+if [ $? -ne 0 ]; then
+	echo "[FAIL]"
+	exitcode=1
+else
+	echo "[PASS]"
+fi
+
+echo "------------------------------------------"
+echo "running gup_benchmark -b (pin_user_pages)"
+echo "------------------------------------------"
+./gup_benchmark -b
+if [ $? -ne 0 ]; then
+	echo "[FAIL]"
+	exitcode=1
+else
+	echo "[PASS]"
+fi
+
 echo "-------------------"
 echo "running userfaultfd"
 echo "-------------------"
@@ -178,6 +219,17 @@
 	echo "[PASS]"
 fi
 
+echo "-------------------------"
+echo "running mlock-random-test"
+echo "-------------------------"
+./mlock-random-test
+if [ $? -ne 0 ]; then
+	echo "[FAIL]"
+	exitcode=1
+else
+	echo "[PASS]"
+fi
+
 echo "--------------------"
 echo "running mlock2-tests"
 echo "--------------------"
@@ -189,6 +241,18 @@
 	echo "[PASS]"
 fi
 
+echo "-----------------"
+echo "running thuge-gen"
+echo "-----------------"
+./thuge-gen
+if [ $? -ne 0 ]; then
+	echo "[FAIL]"
+	exitcode=1
+else
+	echo "[PASS]"
+fi
+
+if [ $VADDR64 -ne 0 ]; then
 echo "-----------------------------"
 echo "running virtual_address_range"
 echo "-----------------------------"
@@ -210,6 +274,7 @@
 else
     echo "[PASS]"
 fi
+fi # VADDR64
 
 echo "------------------------------------"
 echo "running vmalloc stability smoke test"
@@ -227,4 +292,35 @@
 	exitcode=1
 fi
 
+echo "------------------------------------"
+echo "running MREMAP_DONTUNMAP smoke test"
+echo "------------------------------------"
+./mremap_dontunmap
+ret_val=$?
+
+if [ $ret_val -eq 0 ]; then
+	echo "[PASS]"
+elif [ $ret_val -eq $ksft_skip ]; then
+	 echo "[SKIP]"
+	 exitcode=$ksft_skip
+else
+	echo "[FAIL]"
+	exitcode=1
+fi
+
+echo "running HMM smoke test"
+echo "------------------------------------"
+./test_hmm.sh smoke
+ret_val=$?
+
+if [ $ret_val -eq 0 ]; then
+	echo "[PASS]"
+elif [ $ret_val -eq $ksft_skip ]; then
+	echo "[SKIP]"
+	exitcode=$ksft_skip
+else
+	echo "[FAIL]"
+	exitcode=1
+fi
+
 exit $exitcode
diff --git a/tools/testing/selftests/vm/test_hmm.sh b/tools/testing/selftests/vm/test_hmm.sh
new file mode 100755
index 0000000..0647b52
--- /dev/null
+++ b/tools/testing/selftests/vm/test_hmm.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+#
+# Copyright (C) 2018 Uladzislau Rezki (Sony) <urezki@gmail.com>
+#
+# This is a test script for the HMM (heterogeneous memory management)
+# kernel test driver. It is just a kernel module loader: it loads the
+# test_hmm module, creates the /dev/hmm_dmirror* device nodes, runs the
+# hmm-tests program next to this script, then unloads the module again.
+# Run it as "./test_hmm.sh smoke"; without arguments it prints the usage.
+
+TEST_NAME="test_hmm"
+DRIVER="test_hmm"
+
+# 1 if fails
+exitcode=1
+
+# Kselftest framework requirement - SKIP code is 4.
+ksft_skip=4
+
+check_test_requirements()
+{
+	uid=$(id -u)
+	if [ $uid -ne 0 ]; then
+		echo "$0: Must be run as root"
+		exit $ksft_skip
+	fi
+
+	if ! which modprobe > /dev/null 2>&1; then
+		echo "$0: You need modprobe installed"
+		exit $ksft_skip
+	fi
+
+	if ! modinfo $DRIVER > /dev/null 2>&1; then
+		echo "$0: You must have the following enabled in your kernel:"
+		echo "CONFIG_TEST_HMM=m"
+		exit $ksft_skip
+	fi
+}
+
+load_driver()
+{
+	modprobe $DRIVER > /dev/null 2>&1
+	if [ $? == 0 ]; then
+		major=$(awk "\$2==\"HMM_DMIRROR\" {print \$1}" /proc/devices)
+		mknod /dev/hmm_dmirror0 c $major 0
+		mknod /dev/hmm_dmirror1 c $major 1
+	fi
+}
+
+unload_driver()
+{
+	modprobe -r $DRIVER > /dev/null 2>&1
+	rm -f /dev/hmm_dmirror?
+}
+
+run_smoke()
+{
+	echo "Running smoke test. Note, this test provides basic coverage."
+
+	load_driver
+	trap unload_driver EXIT	# still unload when exiting with a failure status
+	"$(dirname "${BASH_SOURCE[0]}")"/hmm-tests || exit $?	# propagate FAIL/SKIP
+}
+
+usage()
+{
+	echo -n "Usage: $0"
+	echo
+	echo "Example usage:"
+	echo
+	echo "# Shows help message"
+	echo "./${TEST_NAME}.sh"
+	echo
+	echo "# Smoke testing"
+	echo "./${TEST_NAME}.sh smoke"
+	echo
+	exit 0
+}
+
+function run_test()
+{
+	if [ $# -eq 0 ]; then
+		usage
+	else
+		if [ "$1" = "smoke" ]; then
+			run_smoke
+		else
+			usage
+		fi
+	fi
+}
+
+check_test_requirements
+run_test $@
+
+exit 0
diff --git a/tools/testing/selftests/vm/userfaultfd.c b/tools/testing/selftests/vm/userfaultfd.c
index 9ba7fef..034245e 100644
--- a/tools/testing/selftests/vm/userfaultfd.c
+++ b/tools/testing/selftests/vm/userfaultfd.c
@@ -46,6 +46,7 @@
 #include <signal.h>
 #include <poll.h>
 #include <string.h>
+#include <linux/mman.h>
 #include <sys/mman.h>
 #include <sys/syscall.h>
 #include <sys/ioctl.h>
@@ -54,6 +55,7 @@
 #include <linux/userfaultfd.h>
 #include <setjmp.h>
 #include <stdbool.h>
+#include <assert.h>
 
 #include "../kselftest.h"
 
@@ -76,6 +78,8 @@
 #define ALARM_INTERVAL_SECS 10
 static volatile bool test_uffdio_copy_eexist = true;
 static volatile bool test_uffdio_zeropage_eexist = true;
+/* Whether to test uffd write-protection */
+static bool test_uffdio_wp = false;
 
 static bool map_shared;
 static int huge_fd;
@@ -86,6 +90,13 @@
 static char *zeropage;
 pthread_attr_t attr;
 
+/* Userfaultfd test statistics */
+struct uffd_stats {
+	int cpu;
+	unsigned long missing_faults;
+	unsigned long wp_faults;
+};
+
 /* pthread_mutex_t starts at page offset 0 */
 #define area_mutex(___area, ___nr)					\
 	((pthread_mutex_t *) ((___area) + (___nr)*page_size))
@@ -125,6 +136,37 @@
 	exit(1);
 }
 
+/*
+ * Zero the fault counters of n_cpus entries; set each ->cpu to its index.
+ */
+static void uffd_stats_reset(struct uffd_stats *uffd_stats,
+			     unsigned long n_cpus)
+{
+	unsigned long i;	/* same type as n_cpus: no signed/unsigned compare */
+
+	for (i = 0; i < n_cpus; i++)
+		uffd_stats[i] = (struct uffd_stats) { .cpu = i };
+}
+
+static void uffd_stats_report(struct uffd_stats *stats, int n_cpus)
+{
+	int i;
+	unsigned long long miss_total = 0, wp_total = 0;
+
+	for (i = 0; i < n_cpus; i++) {
+		miss_total += stats[i].missing_faults;
+		wp_total += stats[i].wp_faults;
+	}
+
+	printf("userfaults: %llu missing (", miss_total);
+	for (i = 0; i < n_cpus; i++)
+		printf("%lu+", stats[i].missing_faults);
+	printf("\b), %llu wp (", wp_total);
+	for (i = 0; i < n_cpus; i++)
+		printf("%lu+", stats[i].wp_faults);
+	printf("\b)\n");
+}
+
 static int anon_release_pages(char *rel_area)
 {
 	int ret = 0;
@@ -167,19 +209,19 @@
 	return ret;
 }
 
-
 static void hugetlb_allocate_area(void **alloc_area)
 {
 	void *area_alias = NULL;
 	char **alloc_area_alias;
+
 	*alloc_area = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
 			   (map_shared ? MAP_SHARED : MAP_PRIVATE) |
 			   MAP_HUGETLB,
 			   huge_fd, *alloc_area == area_src ? 0 :
 			   nr_pages * page_size);
 	if (*alloc_area == MAP_FAILED) {
-		fprintf(stderr, "mmap of hugetlbfs file failed\n");
-		*alloc_area = NULL;
+		perror("mmap of hugetlbfs file failed");
+		goto fail;
 	}
 
 	if (map_shared) {
@@ -188,12 +230,11 @@
 				  huge_fd, *alloc_area == area_src ? 0 :
 				  nr_pages * page_size);
 		if (area_alias == MAP_FAILED) {
-			if (munmap(*alloc_area, nr_pages * page_size) < 0)
-				perror("hugetlb munmap"), exit(1);
-			*alloc_area = NULL;
-			return;
+			perror("mmap of hugetlb file alias failed");
+			goto fail_munmap;
 		}
 	}
+
 	if (*alloc_area == area_src) {
 		huge_fd_off0 = *alloc_area;
 		alloc_area_alias = &area_src_alias;
@@ -202,6 +243,16 @@
 	}
 	if (area_alias)
 		*alloc_area_alias = area_alias;
+
+	return;
+
+fail_munmap:
+	if (munmap(*alloc_area, nr_pages * page_size) < 0) {
+		perror("hugetlb munmap");
+		exit(1);
+	}
+fail:
+	*alloc_area = NULL;
 }
 
 static void hugetlb_alias_mapping(__u64 *start, size_t len, unsigned long offset)
@@ -247,10 +298,15 @@
 	void (*alias_mapping)(__u64 *start, size_t len, unsigned long offset);
 };
 
-#define ANON_EXPECTED_IOCTLS		((1 << _UFFDIO_WAKE) | \
+#define SHMEM_EXPECTED_IOCTLS		((1 << _UFFDIO_WAKE) | \
 					 (1 << _UFFDIO_COPY) | \
 					 (1 << _UFFDIO_ZEROPAGE))
 
+#define ANON_EXPECTED_IOCTLS		((1 << _UFFDIO_WAKE) | \
+					 (1 << _UFFDIO_COPY) | \
+					 (1 << _UFFDIO_ZEROPAGE) | \
+					 (1 << _UFFDIO_WRITEPROTECT))
+
 static struct uffd_test_ops anon_uffd_test_ops = {
 	.expected_ioctls = ANON_EXPECTED_IOCTLS,
 	.allocate_area	= anon_allocate_area,
@@ -259,7 +315,7 @@
 };
 
 static struct uffd_test_ops shmem_uffd_test_ops = {
-	.expected_ioctls = ANON_EXPECTED_IOCTLS,
+	.expected_ioctls = SHMEM_EXPECTED_IOCTLS,
 	.allocate_area	= shmem_allocate_area,
 	.release_pages	= shmem_release_pages,
 	.alias_mapping = noop_alias_mapping,
@@ -283,6 +339,22 @@
 	return 0;
 }
 
+/* Set (wp == true) or clear (wp == false) write protection on a range */
+static void wp_range(int ufd, __u64 start, __u64 len, bool wp)
+{
+	struct uffdio_writeprotect prms = { 0 };
+
+	prms.range.start = start;
+	prms.range.len = len;
+	/* mode 0: undo write-protect, do wakeup after that */
+	prms.mode = wp ? UFFDIO_WRITEPROTECT_MODE_WP : 0;
+	if (ioctl(ufd, UFFDIO_WRITEPROTECT, &prms)) {
+		fprintf(stderr, "%s WP failed for address 0x%Lx\n",
+			wp ? "set" : "clear", start);
+		exit(1);
+	}
+}
+
 static void *locking_thread(void *arg)
 {
 	unsigned long cpu = (unsigned long) arg;
@@ -300,8 +372,10 @@
 			seed += cpu;
 		bzero(&rand, sizeof(rand));
 		bzero(&randstate, sizeof(randstate));
-		if (initstate_r(seed, randstate, sizeof(randstate), &rand))
-			fprintf(stderr, "srandom_r error\n"), exit(1);
+		if (initstate_r(seed, randstate, sizeof(randstate), &rand)) {
+			fprintf(stderr, "srandom_r error\n");
+			exit(1);
+		}
 	} else {
 		page_nr = -bounces;
 		if (!(bounces & BOUNCE_RACINGFAULTS))
@@ -310,12 +384,16 @@
 
 	while (!finished) {
 		if (bounces & BOUNCE_RANDOM) {
-			if (random_r(&rand, &rand_nr))
-				fprintf(stderr, "random_r 1 error\n"), exit(1);
+			if (random_r(&rand, &rand_nr)) {
+				fprintf(stderr, "random_r 1 error\n");
+				exit(1);
+			}
 			page_nr = rand_nr;
 			if (sizeof(page_nr) > sizeof(rand_nr)) {
-				if (random_r(&rand, &rand_nr))
-					fprintf(stderr, "random_r 2 error\n"), exit(1);
+				if (random_r(&rand, &rand_nr)) {
+					fprintf(stderr, "random_r 2 error\n");
+					exit(1);
+				}
 				page_nr |= (((unsigned long) rand_nr) << 16) <<
 					   16;
 			}
@@ -326,11 +404,13 @@
 		start = time(NULL);
 		if (bounces & BOUNCE_VERIFY) {
 			count = *area_count(area_dst, page_nr);
-			if (!count)
+			if (!count) {
 				fprintf(stderr,
 					"page_nr %lu wrong count %Lu %Lu\n",
 					page_nr, count,
-					count_verify[page_nr]), exit(1);
+					count_verify[page_nr]);
+				exit(1);
+			}
 
 
 			/*
@@ -342,11 +422,12 @@
 			 */
 #if 1
 			if (!my_bcmp(area_dst + page_nr * page_size, zeropage,
-				     page_size))
+				     page_size)) {
 				fprintf(stderr,
 					"my_bcmp page_nr %lu wrong count %Lu %Lu\n",
-					page_nr, count,
-					count_verify[page_nr]), exit(1);
+					page_nr, count, count_verify[page_nr]);
+				exit(1);
+			}
 #else
 			unsigned long loops;
 
@@ -378,7 +459,7 @@
 			fprintf(stderr,
 				"page_nr %lu memory corruption %Lu %Lu\n",
 				page_nr, count,
-				count_verify[page_nr]), exit(1);
+				count_verify[page_nr]); exit(1);
 		}
 		count++;
 		*area_count(area_dst, page_nr) = count_verify[page_nr] = count;
@@ -402,12 +483,14 @@
 				     offset);
 	if (ioctl(ufd, UFFDIO_COPY, uffdio_copy)) {
 		/* real retval in ufdio_copy.copy */
-		if (uffdio_copy->copy != -EEXIST)
+		if (uffdio_copy->copy != -EEXIST) {
 			fprintf(stderr, "UFFDIO_COPY retry error %Ld\n",
-				uffdio_copy->copy), exit(1);
+				uffdio_copy->copy);
+			exit(1);
+		}
 	} else {
 		fprintf(stderr,	"UFFDIO_COPY retry unexpected %Ld\n",
-			uffdio_copy->copy), exit(1);
+			uffdio_copy->copy); exit(1);
 	}
 }
 
@@ -415,22 +498,28 @@
 {
 	struct uffdio_copy uffdio_copy;
 
-	if (offset >= nr_pages * page_size)
-		fprintf(stderr, "unexpected offset %lu\n",
-			offset), exit(1);
+	if (offset >= nr_pages * page_size) {
+		fprintf(stderr, "unexpected offset %lu\n", offset);
+		exit(1);
+	}
 	uffdio_copy.dst = (unsigned long) area_dst + offset;
 	uffdio_copy.src = (unsigned long) area_src + offset;
 	uffdio_copy.len = page_size;
-	uffdio_copy.mode = 0;
+	if (test_uffdio_wp)
+		uffdio_copy.mode = UFFDIO_COPY_MODE_WP;
+	else
+		uffdio_copy.mode = 0;
 	uffdio_copy.copy = 0;
 	if (ioctl(ufd, UFFDIO_COPY, &uffdio_copy)) {
 		/* real retval in ufdio_copy.copy */
-		if (uffdio_copy.copy != -EEXIST)
+		if (uffdio_copy.copy != -EEXIST) {
 			fprintf(stderr, "UFFDIO_COPY error %Ld\n",
-				uffdio_copy.copy), exit(1);
+				uffdio_copy.copy);
+			exit(1);
+		}
 	} else if (uffdio_copy.copy != page_size) {
 		fprintf(stderr, "UFFDIO_COPY unexpected copy %Ld\n",
-			uffdio_copy.copy), exit(1);
+			uffdio_copy.copy); exit(1);
 	} else {
 		if (test_uffdio_copy_eexist && retry) {
 			test_uffdio_copy_eexist = false;
@@ -459,44 +548,54 @@
 		if (ret < 0) {
 			if (errno == EAGAIN)
 				return 1;
-			else
-				perror("blocking read error"), exit(1);
+			perror("blocking read error");
 		} else {
-			fprintf(stderr, "short read\n"), exit(1);
+			fprintf(stderr, "short read\n");
 		}
+		exit(1);
 	}
 
 	return 0;
 }
 
-/* Return 1 if page fault handled by us; otherwise 0 */
-static int uffd_handle_page_fault(struct uffd_msg *msg)
+static void uffd_handle_page_fault(struct uffd_msg *msg,
+				   struct uffd_stats *stats)
 {
 	unsigned long offset;
 
-	if (msg->event != UFFD_EVENT_PAGEFAULT)
-		fprintf(stderr, "unexpected msg event %u\n",
-			msg->event), exit(1);
+	if (msg->event != UFFD_EVENT_PAGEFAULT) {
+		fprintf(stderr, "unexpected msg event %u\n", msg->event);
+		exit(1);
+	}
 
-	if (bounces & BOUNCE_VERIFY &&
-	    msg->arg.pagefault.flags & UFFD_PAGEFAULT_FLAG_WRITE)
-		fprintf(stderr, "unexpected write fault\n"), exit(1);
+	if (msg->arg.pagefault.flags & UFFD_PAGEFAULT_FLAG_WP) {
+		wp_range(uffd, msg->arg.pagefault.address, page_size, false);
+		stats->wp_faults++;
+	} else {
+		/* Missing page faults */
+		if (bounces & BOUNCE_VERIFY &&
+		    msg->arg.pagefault.flags & UFFD_PAGEFAULT_FLAG_WRITE) {
+			fprintf(stderr, "unexpected write fault\n");
+			exit(1);
+		}
 
-	offset = (char *)(unsigned long)msg->arg.pagefault.address - area_dst;
-	offset &= ~(page_size-1);
+		offset = (char *)(unsigned long)msg->arg.pagefault.address - area_dst;
+		offset &= ~(page_size-1);
 
-	return copy_page(uffd, offset);
+		if (copy_page(uffd, offset))
+			stats->missing_faults++;
+	}
 }
 
 static void *uffd_poll_thread(void *arg)
 {
-	unsigned long cpu = (unsigned long) arg;
+	struct uffd_stats *stats = (struct uffd_stats *)arg;
+	unsigned long cpu = stats->cpu;
 	struct pollfd pollfd[2];
 	struct uffd_msg msg;
 	struct uffdio_register uffd_reg;
 	int ret;
 	char tmp_chr;
-	unsigned long userfaults = 0;
 
 	pollfd[0].fd = uffd;
 	pollfd[0].events = POLLIN;
@@ -505,28 +604,35 @@
 
 	for (;;) {
 		ret = poll(pollfd, 2, -1);
-		if (!ret)
-			fprintf(stderr, "poll error %d\n", ret), exit(1);
-		if (ret < 0)
-			perror("poll"), exit(1);
+		if (!ret) {
+			fprintf(stderr, "poll error %d\n", ret);
+			exit(1);
+		}
+		if (ret < 0) {
+			perror("poll");
+			exit(1);
+		}
 		if (pollfd[1].revents & POLLIN) {
-			if (read(pollfd[1].fd, &tmp_chr, 1) != 1)
-				fprintf(stderr, "read pipefd error\n"),
-					exit(1);
+			if (read(pollfd[1].fd, &tmp_chr, 1) != 1) {
+				fprintf(stderr, "read pipefd error\n");
+				exit(1);
+			}
 			break;
 		}
-		if (!(pollfd[0].revents & POLLIN))
+		if (!(pollfd[0].revents & POLLIN)) {
 			fprintf(stderr, "pollfd[0].revents %d\n",
-				pollfd[0].revents), exit(1);
+				pollfd[0].revents);
+			exit(1);
+		}
 		if (uffd_read_msg(uffd, &msg))
 			continue;
 		switch (msg.event) {
 		default:
 			fprintf(stderr, "unexpected msg event %u\n",
-				msg.event), exit(1);
+				msg.event); exit(1);
 			break;
 		case UFFD_EVENT_PAGEFAULT:
-			userfaults += uffd_handle_page_fault(&msg);
+			uffd_handle_page_fault(&msg, stats);
 			break;
 		case UFFD_EVENT_FORK:
 			close(uffd);
@@ -537,58 +643,77 @@
 			uffd_reg.range.start = msg.arg.remove.start;
 			uffd_reg.range.len = msg.arg.remove.end -
 				msg.arg.remove.start;
-			if (ioctl(uffd, UFFDIO_UNREGISTER, &uffd_reg.range))
-				fprintf(stderr, "remove failure\n"), exit(1);
+			if (ioctl(uffd, UFFDIO_UNREGISTER, &uffd_reg.range)) {
+				fprintf(stderr, "remove failure\n");
+				exit(1);
+			}
 			break;
 		case UFFD_EVENT_REMAP:
 			area_dst = (char *)(unsigned long)msg.arg.remap.to;
 			break;
 		}
 	}
-	return (void *)userfaults;
+
+	return NULL;
 }
 
 pthread_mutex_t uffd_read_mutex = PTHREAD_MUTEX_INITIALIZER;
 
 static void *uffd_read_thread(void *arg)
 {
-	unsigned long *this_cpu_userfaults;
+	struct uffd_stats *stats = (struct uffd_stats *)arg;
 	struct uffd_msg msg;
 
-	this_cpu_userfaults = (unsigned long *) arg;
-	*this_cpu_userfaults = 0;
-
 	pthread_mutex_unlock(&uffd_read_mutex);
 	/* from here cancellation is ok */
 
 	for (;;) {
 		if (uffd_read_msg(uffd, &msg))
 			continue;
-		(*this_cpu_userfaults) += uffd_handle_page_fault(&msg);
+		uffd_handle_page_fault(&msg, stats);
 	}
-	return (void *)NULL;
+
+	return NULL;
 }
 
 static void *background_thread(void *arg)
 {
 	unsigned long cpu = (unsigned long) arg;
-	unsigned long page_nr;
+	unsigned long page_nr, start_nr, mid_nr, end_nr;
 
-	for (page_nr = cpu * nr_pages_per_cpu;
-	     page_nr < (cpu+1) * nr_pages_per_cpu;
-	     page_nr++)
+	start_nr = cpu * nr_pages_per_cpu;
+	end_nr = (cpu+1) * nr_pages_per_cpu;
+	mid_nr = (start_nr + end_nr) / 2;
+
+	/* Copy the first half of the pages */
+	for (page_nr = start_nr; page_nr < mid_nr; page_nr++)
+		copy_page_retry(uffd, page_nr * page_size);
+
+	/*
+	 * If we need to test uffd-wp, set it up now.  Then we'll have
+	 * at least the first half of the pages mapped already which
+	 * can be write-protected for testing
+	 */
+	if (test_uffdio_wp)
+		wp_range(uffd, (unsigned long)area_dst + start_nr * page_size,
+			nr_pages_per_cpu * page_size, true);
+
+	/*
+	 * Continue the 2nd half of the page copying, handling write
+	 * protection faults if any
+	 */
+	for (page_nr = mid_nr; page_nr < end_nr; page_nr++)
 		copy_page_retry(uffd, page_nr * page_size);
 
 	return NULL;
 }
 
-static int stress(unsigned long *userfaults)
+static int stress(struct uffd_stats *uffd_stats)
 {
 	unsigned long cpu;
 	pthread_t locking_threads[nr_cpus];
 	pthread_t uffd_threads[nr_cpus];
 	pthread_t background_threads[nr_cpus];
-	void **_userfaults = (void **) userfaults;
 
 	finished = 0;
 	for (cpu = 0; cpu < nr_cpus; cpu++) {
@@ -597,12 +722,13 @@
 			return 1;
 		if (bounces & BOUNCE_POLL) {
 			if (pthread_create(&uffd_threads[cpu], &attr,
-					   uffd_poll_thread, (void *)cpu))
+					   uffd_poll_thread,
+					   (void *)&uffd_stats[cpu]))
 				return 1;
 		} else {
 			if (pthread_create(&uffd_threads[cpu], &attr,
 					   uffd_read_thread,
-					   &_userfaults[cpu]))
+					   (void *)&uffd_stats[cpu]))
 				return 1;
 			pthread_mutex_lock(&uffd_read_mutex);
 		}
@@ -639,7 +765,8 @@
 				fprintf(stderr, "pipefd write error\n");
 				return 1;
 			}
-			if (pthread_join(uffd_threads[cpu], &_userfaults[cpu]))
+			if (pthread_join(uffd_threads[cpu],
+					 (void *)&uffd_stats[cpu]))
 				return 1;
 		} else {
 			if (pthread_cancel(uffd_threads[cpu]))
@@ -737,17 +864,31 @@
 	}
 
 	for (nr = 0; nr < split_nr_pages; nr++) {
+		int steps = 1;
+		unsigned long offset = nr * page_size;
+
 		if (signal_test) {
 			if (sigsetjmp(*sigbuf, 1) != 0) {
-				if (nr == lastnr) {
+				if (steps == 1 && nr == lastnr) {
 					fprintf(stderr, "Signal repeated\n");
 					return 1;
 				}
 
 				lastnr = nr;
 				if (signal_test == 1) {
-					if (copy_page(uffd, nr * page_size))
-						signalled++;
+					if (steps == 1) {
+						/* This is a MISSING request */
+						steps++;
+						if (copy_page(uffd, offset))
+							signalled++;
+					} else {
+						/* This is a WP request */
+						assert(steps == 2);
+						wp_range(uffd,
+							 (__u64)area_dst +
+							 offset,
+							 page_size, false);
+					}
 				} else {
 					signalled++;
 					continue;
@@ -760,8 +901,13 @@
 			fprintf(stderr,
 				"nr %lu memory corruption %Lu %Lu\n",
 				nr, count,
-				count_verify[nr]), exit(1);
-		}
+				count_verify[nr]);
+	        }
+		/*
+		 * Trigger write protection, if there is any, by writing
+		 * the same value back.
+		 */
+		*area_count(area_dst, nr) = count;
 	}
 
 	if (signal_test)
@@ -772,8 +918,10 @@
 
 	area_dst = mremap(area_dst, nr_pages * page_size,  nr_pages * page_size,
 			  MREMAP_MAYMOVE | MREMAP_FIXED, area_src);
-	if (area_dst == MAP_FAILED)
-		perror("mremap"), exit(1);
+	if (area_dst == MAP_FAILED) {
+		perror("mremap");
+		exit(1);
+	}
 
 	for (; nr < nr_pages; nr++) {
 		count = *area_count(area_dst, nr);
@@ -781,16 +929,23 @@
 			fprintf(stderr,
 				"nr %lu memory corruption %Lu %Lu\n",
 				nr, count,
-				count_verify[nr]), exit(1);
+				count_verify[nr]); exit(1);
 		}
+		/*
+		 * Trigger write protection, if there is any, by writing
+		 * the same value back.
+		 */
+		*area_count(area_dst, nr) = count;
 	}
 
 	if (uffd_test_ops->release_pages(area_dst))
 		return 1;
 
 	for (nr = 0; nr < nr_pages; nr++) {
-		if (my_bcmp(area_dst + nr * page_size, zeropage, page_size))
-			fprintf(stderr, "nr %lu is not zero\n", nr), exit(1);
+		if (my_bcmp(area_dst + nr * page_size, zeropage, page_size)) {
+			fprintf(stderr, "nr %lu is not zero\n", nr);
+			exit(1);
+		}
 	}
 
 	return 0;
@@ -804,12 +959,14 @@
 				     uffdio_zeropage->range.len,
 				     offset);
 	if (ioctl(ufd, UFFDIO_ZEROPAGE, uffdio_zeropage)) {
-		if (uffdio_zeropage->zeropage != -EEXIST)
+		if (uffdio_zeropage->zeropage != -EEXIST) {
 			fprintf(stderr, "UFFDIO_ZEROPAGE retry error %Ld\n",
-				uffdio_zeropage->zeropage), exit(1);
+				uffdio_zeropage->zeropage);
+			exit(1);
+		}
 	} else {
 		fprintf(stderr, "UFFDIO_ZEROPAGE retry unexpected %Ld\n",
-			uffdio_zeropage->zeropage), exit(1);
+			uffdio_zeropage->zeropage); exit(1);
 	}
 }
 
@@ -821,9 +978,10 @@
 
 	has_zeropage = uffd_test_ops->expected_ioctls & (1 << _UFFDIO_ZEROPAGE);
 
-	if (offset >= nr_pages * page_size)
-		fprintf(stderr, "unexpected offset %lu\n",
-			offset), exit(1);
+	if (offset >= nr_pages * page_size) {
+		fprintf(stderr, "unexpected offset %lu\n", offset);
+		exit(1);
+	}
 	uffdio_zeropage.range.start = (unsigned long) area_dst + offset;
 	uffdio_zeropage.range.len = page_size;
 	uffdio_zeropage.mode = 0;
@@ -831,22 +989,26 @@
 	if (ret) {
 		/* real retval in ufdio_zeropage.zeropage */
 		if (has_zeropage) {
-			if (uffdio_zeropage.zeropage == -EEXIST)
-				fprintf(stderr, "UFFDIO_ZEROPAGE -EEXIST\n"),
-					exit(1);
-			else
+			if (uffdio_zeropage.zeropage == -EEXIST) {
+				fprintf(stderr, "UFFDIO_ZEROPAGE -EEXIST\n");
+				exit(1);
+			} else {
 				fprintf(stderr, "UFFDIO_ZEROPAGE error %Ld\n",
-					uffdio_zeropage.zeropage), exit(1);
+					uffdio_zeropage.zeropage);
+				exit(1);
+			}
 		} else {
-			if (uffdio_zeropage.zeropage != -EINVAL)
+			if (uffdio_zeropage.zeropage != -EINVAL) {
 				fprintf(stderr,
 					"UFFDIO_ZEROPAGE not -EINVAL %Ld\n",
-					uffdio_zeropage.zeropage), exit(1);
+					uffdio_zeropage.zeropage);
+				exit(1);
+			}
 		}
 	} else if (has_zeropage) {
 		if (uffdio_zeropage.zeropage != page_size) {
 			fprintf(stderr, "UFFDIO_ZEROPAGE unexpected %Ld\n",
-				uffdio_zeropage.zeropage), exit(1);
+				uffdio_zeropage.zeropage); exit(1);
 		} else {
 			if (test_uffdio_zeropage_eexist && retry) {
 				test_uffdio_zeropage_eexist = false;
@@ -858,7 +1020,7 @@
 	} else {
 		fprintf(stderr,
 			"UFFDIO_ZEROPAGE succeeded %Ld\n",
-			uffdio_zeropage.zeropage), exit(1);
+			uffdio_zeropage.zeropage); exit(1);
 	}
 
 	return 0;
@@ -886,19 +1048,26 @@
 	uffdio_register.range.start = (unsigned long) area_dst;
 	uffdio_register.range.len = nr_pages * page_size;
 	uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING;
-	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
-		fprintf(stderr, "register failure\n"), exit(1);
+	if (test_uffdio_wp)
+		uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
+	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register)) {
+		fprintf(stderr, "register failure\n");
+		exit(1);
+	}
 
 	expected_ioctls = uffd_test_ops->expected_ioctls;
 	if ((uffdio_register.ioctls & expected_ioctls) !=
-	    expected_ioctls)
+	    expected_ioctls) {
 		fprintf(stderr,
-			"unexpected missing ioctl for anon memory\n"),
-			exit(1);
+			"unexpected missing ioctl for anon memory\n");
+		exit(1);
+	}
 
 	if (uffdio_zeropage(uffd, 0)) {
-		if (my_bcmp(area_dst, zeropage, page_size))
-			fprintf(stderr, "zeropage is not zero\n"), exit(1);
+		if (my_bcmp(area_dst, zeropage, page_size)) {
+			fprintf(stderr, "zeropage is not zero\n");
+			exit(1);
+		}
 	}
 
 	close(uffd);
@@ -910,11 +1079,11 @@
 {
 	struct uffdio_register uffdio_register;
 	unsigned long expected_ioctls;
-	unsigned long userfaults;
 	pthread_t uffd_mon;
 	int err, features;
 	pid_t pid;
 	char c;
+	struct uffd_stats stats = { 0 };
 
 	printf("testing events (fork, remap, remove): ");
 	fflush(stdout);
@@ -931,39 +1100,51 @@
 	uffdio_register.range.start = (unsigned long) area_dst;
 	uffdio_register.range.len = nr_pages * page_size;
 	uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING;
-	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
-		fprintf(stderr, "register failure\n"), exit(1);
+	if (test_uffdio_wp)
+		uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
+	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register)) {
+		fprintf(stderr, "register failure\n");
+		exit(1);
+	}
 
 	expected_ioctls = uffd_test_ops->expected_ioctls;
-	if ((uffdio_register.ioctls & expected_ioctls) !=
-	    expected_ioctls)
-		fprintf(stderr,
-			"unexpected missing ioctl for anon memory\n"),
-			exit(1);
+	if ((uffdio_register.ioctls & expected_ioctls) != expected_ioctls) {
+		fprintf(stderr, "unexpected missing ioctl for anon memory\n");
+		exit(1);
+	}
 
-	if (pthread_create(&uffd_mon, &attr, uffd_poll_thread, NULL))
-		perror("uffd_poll_thread create"), exit(1);
+	if (pthread_create(&uffd_mon, &attr, uffd_poll_thread, &stats)) {
+		perror("uffd_poll_thread create");
+		exit(1);
+	}
 
 	pid = fork();
-	if (pid < 0)
-		perror("fork"), exit(1);
+	if (pid < 0) {
+		perror("fork");
+		exit(1);
+	}
 
 	if (!pid)
 		return faulting_process(0);
 
 	waitpid(pid, &err, 0);
-	if (err)
-		fprintf(stderr, "faulting process failed\n"), exit(1);
+	if (err) {
+		fprintf(stderr, "faulting process failed\n");
+		exit(1);
+	}
 
-	if (write(pipefd[1], &c, sizeof(c)) != sizeof(c))
-		perror("pipe write"), exit(1);
-	if (pthread_join(uffd_mon, (void **)&userfaults))
+	if (write(pipefd[1], &c, sizeof(c)) != sizeof(c)) {
+		perror("pipe write");
+		exit(1);
+	}
+	if (pthread_join(uffd_mon, NULL))
 		return 1;
 
 	close(uffd);
-	printf("userfaults: %ld\n", userfaults);
 
-	return userfaults != nr_pages;
+	uffd_stats_report(&stats, 1);
+
+	return stats.missing_faults != nr_pages;
 }
 
 static int userfaultfd_sig_test(void)
@@ -975,6 +1156,7 @@
 	int err, features;
 	pid_t pid;
 	char c;
+	struct uffd_stats stats = { 0 };
 
 	printf("testing signal delivery: ");
 	fflush(stdout);
@@ -990,38 +1172,51 @@
 	uffdio_register.range.start = (unsigned long) area_dst;
 	uffdio_register.range.len = nr_pages * page_size;
 	uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING;
-	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
-		fprintf(stderr, "register failure\n"), exit(1);
+	if (test_uffdio_wp)
+		uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
+	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register)) {
+		fprintf(stderr, "register failure\n");
+		exit(1);
+	}
 
 	expected_ioctls = uffd_test_ops->expected_ioctls;
-	if ((uffdio_register.ioctls & expected_ioctls) !=
-	    expected_ioctls)
-		fprintf(stderr,
-			"unexpected missing ioctl for anon memory\n"),
-			exit(1);
+	if ((uffdio_register.ioctls & expected_ioctls) != expected_ioctls) {
+		fprintf(stderr, "unexpected missing ioctl for anon memory\n");
+		exit(1);
+	}
 
-	if (faulting_process(1))
-		fprintf(stderr, "faulting process failed\n"), exit(1);
+	if (faulting_process(1)) {
+		fprintf(stderr, "faulting process failed\n");
+		exit(1);
+	}
 
 	if (uffd_test_ops->release_pages(area_dst))
 		return 1;
 
-	if (pthread_create(&uffd_mon, &attr, uffd_poll_thread, NULL))
-		perror("uffd_poll_thread create"), exit(1);
+	if (pthread_create(&uffd_mon, &attr, uffd_poll_thread, &stats)) {
+		perror("uffd_poll_thread create");
+		exit(1);
+	}
 
 	pid = fork();
-	if (pid < 0)
-		perror("fork"), exit(1);
+	if (pid < 0) {
+		perror("fork");
+		exit(1);
+	}
 
 	if (!pid)
 		exit(faulting_process(2));
 
 	waitpid(pid, &err, 0);
-	if (err)
-		fprintf(stderr, "faulting process failed\n"), exit(1);
+	if (err) {
+		fprintf(stderr, "faulting process failed\n");
+		exit(1);
+	}
 
-	if (write(pipefd[1], &c, sizeof(c)) != sizeof(c))
-		perror("pipe write"), exit(1);
+	if (write(pipefd[1], &c, sizeof(c)) != sizeof(c)) {
+		perror("pipe write");
+		exit(1);
+	}
 	if (pthread_join(uffd_mon, (void **)&userfaults))
 		return 1;
 
@@ -1032,6 +1227,7 @@
 	close(uffd);
 	return userfaults != 0;
 }
+
 static int userfaultfd_stress(void)
 {
 	void *area;
@@ -1040,7 +1236,7 @@
 	struct uffdio_register uffdio_register;
 	unsigned long cpu;
 	int err;
-	unsigned long userfaults[nr_cpus];
+	struct uffd_stats uffd_stats[nr_cpus];
 
 	uffd_test_ops->allocate_area((void **)&area_src);
 	if (!area_src)
@@ -1121,6 +1317,8 @@
 		uffdio_register.range.start = (unsigned long) area_dst;
 		uffdio_register.range.len = nr_pages * page_size;
 		uffdio_register.mode = UFFDIO_REGISTER_MODE_MISSING;
+		if (test_uffdio_wp)
+			uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
 		if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register)) {
 			fprintf(stderr, "register failure\n");
 			return 1;
@@ -1169,10 +1367,17 @@
 		if (uffd_test_ops->release_pages(area_dst))
 			return 1;
 
+		uffd_stats_reset(uffd_stats, nr_cpus);
+
 		/* bounce pass */
-		if (stress(userfaults))
+		if (stress(uffd_stats))
 			return 1;
 
+		/* Clear all the write protections, if there are any */
+		if (test_uffdio_wp)
+			wp_range(uffd, (unsigned long)area_dst,
+				 nr_pages * page_size, false);
+
 		/* unregister */
 		if (ioctl(uffd, UFFDIO_UNREGISTER, &uffdio_register.range)) {
 			fprintf(stderr, "unregister failure\n");
@@ -1211,10 +1416,7 @@
 		area_src_alias = area_dst_alias;
 		area_dst_alias = tmp_area;
 
-		printf("userfaults:");
-		for (cpu = 0; cpu < nr_cpus; cpu++)
-			printf(" %lu", userfaults[cpu]);
-		printf("\n");
+		uffd_stats_report(uffd_stats, nr_cpus);
 	}
 
 	if (err)
@@ -1254,6 +1456,8 @@
 	if (!strcmp(type, "anon")) {
 		test_type = TEST_ANON;
 		uffd_test_ops = &anon_uffd_test_ops;
+		/* Only enable write-protect test for anonymous test */
+		test_uffdio_wp = true;
 	} else if (!strcmp(type, "hugetlb")) {
 		test_type = TEST_HUGETLB;
 		uffd_test_ops = &hugetlb_uffd_test_ops;
@@ -1266,7 +1470,7 @@
 		test_type = TEST_SHMEM;
 		uffd_test_ops = &shmem_uffd_test_ops;
 	} else {
-		fprintf(stderr, "Unknown test type: %s\n", type), exit(1);
+		fprintf(stderr, "Unknown test type: %s\n", type); exit(1);
 	}
 
 	if (test_type == TEST_HUGETLB)
@@ -1274,12 +1478,15 @@
 	else
 		page_size = sysconf(_SC_PAGE_SIZE);
 
-	if (!page_size)
-		fprintf(stderr, "Unable to determine page size\n"),
-				exit(2);
+	if (!page_size) {
+		fprintf(stderr, "Unable to determine page size\n");
+		exit(2);
+	}
 	if ((unsigned long) area_count(NULL, 0) + sizeof(unsigned long long) * 2
-	    > page_size)
-		fprintf(stderr, "Impossible to run this test\n"), exit(2);
+	    > page_size) {
+		fprintf(stderr, "Impossible to run this test\n");
+		exit(2);
+	}
 }
 
 static void sigalrm(int sig)
@@ -1296,8 +1503,10 @@
 	if (argc < 4)
 		usage();
 
-	if (signal(SIGALRM, sigalrm) == SIG_ERR)
-		fprintf(stderr, "failed to arm SIGALRM"), exit(1);
+	if (signal(SIGALRM, sigalrm) == SIG_ERR) {
+		fprintf(stderr, "failed to arm SIGALRM");
+		exit(1);
+	}
 	alarm(ALARM_INTERVAL_SECS);
 
 	set_test_type(argv[1]);
diff --git a/tools/testing/selftests/vm/write_hugetlb_memory.sh b/tools/testing/selftests/vm/write_hugetlb_memory.sh
new file mode 100644
index 0000000..d3d0d10
--- /dev/null
+++ b/tools/testing/selftests/vm/write_hugetlb_memory.sh
@@ -0,0 +1,23 @@
+#!/bin/sh
+# SPDX-License-Identifier: GPL-2.0
+# Join the given memory cgroup, then run ./write_to_hugetlbfs from inside it.
+# Args: size populate write cgroup path method private want_sleep reserve
+set -e
+
+map_size=$1
+populate_arg=$2
+write_arg=$3
+cgroup_name=$4
+file_path=$5
+alloc_method=$6
+private_arg=$7
+sleep_arg=$8
+reserve_arg=$9
+
+echo "Putting task in cgroup '$cgroup_name'"
+echo $$ > /dev/cgroup/memory/"$cgroup_name"/cgroup.procs
+echo "Method is $alloc_method"
+set +e
+./write_to_hugetlbfs -p "$file_path" -s "$map_size" "$write_arg" \
+      "$populate_arg" -m "$alloc_method" "$private_arg" "$sleep_arg" \
+      "$reserve_arg"
diff --git a/tools/testing/selftests/vm/write_to_hugetlbfs.c b/tools/testing/selftests/vm/write_to_hugetlbfs.c
new file mode 100644
index 0000000..6a2caba
--- /dev/null
+++ b/tools/testing/selftests/vm/write_to_hugetlbfs.c
@@ -0,0 +1,240 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * This program reserves and uses hugetlb memory, supporting a bunch of
+ * scenarios needed by the charged_reserved_hugetlb.sh test.
+ */
+
+#include <err.h>
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <sys/types.h>
+#include <sys/shm.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+
+/* Global definitions. */
+enum method {
+	HUGETLBFS,
+	MMAP_MAP_HUGETLB,
+	SHM,
+	MAX_METHOD
+};
+
+
+/* Global variables. */
+static const char *self;
+static char *shmaddr;
+static int shmid;
+
+/*
+ * Print the option synopsis (program name from 'self') and exit(EXIT_FAILURE).
+ */
+static void exit_usage(void)
+{
+	printf("Usage: %s -p <path to hugetlbfs file> -s <size to map> "
+	       "-m <0=hugetlbfs | 1=mmap(MAP_HUGETLB) | 2=shm> [-l] [-r] "
+	       "[-o] [-w] [-n]\n",
+	       self);
+	exit(EXIT_FAILURE);
+}
+
+void sig_handler(int signo)
+{
+	/* Report the signal; anything other than SIGINT just exits with 2. */
+	printf("Received %d.\n", signo);
+	if (signo != SIGINT)
+		exit(2);
+	printf("Deleting the memory\n");
+	if (shmdt((const void *)shmaddr) != 0) {
+		perror("Detach failure");
+		shmctl(shmid, IPC_RMID, NULL);
+		exit(4);
+	}
+	shmctl(shmid, IPC_RMID, NULL);
+	printf("Done deleting the memory\n");
+	exit(2);
+}
+
+/*
+ * Parse options, map <size> bytes of hugetlb memory via the -m method,
+ * optionally write to it (-w) and/or sleep holding it (-l).
+ */
+int main(int argc, char **argv)
+{
+	int fd = 0, key = 0, c = 0, size = 0;
+	int *ptr = NULL;
+	char path[256] = "";
+	enum method method = MAX_METHOD;
+	int want_sleep = 0, private = 0;
+	int populate = 0, write = 0, reserve = 1;
+
+	if (signal(SIGINT, sig_handler) == SIG_ERR)
+		err(1, "\ncan't catch SIGINT\n");
+
+	/* Parse command-line arguments. */
+	setvbuf(stdout, NULL, _IONBF, 0);
+	self = argv[0];
+
+	while ((c = getopt(argc, argv, "s:p:m:owlrn")) != -1) {
+		switch (c) {
+		case 's':
+			size = atoi(optarg);
+			break;
+		case 'p':
+			/* Keep the last byte of zeroed path[] as NUL. */
+			strncpy(path, optarg, sizeof(path) - 1);
+			break;
+		case 'm':
+			if (atoi(optarg) < 0 || atoi(optarg) >= MAX_METHOD) {
+				errno = EINVAL;
+				perror("Invalid -m.");
+				exit_usage();
+			}
+			method = atoi(optarg);
+			break;
+		case 'o':
+			populate = 1;
+			break;
+		case 'w':
+			write = 1;
+			break;
+		case 'l':
+			want_sleep = 1;
+			break;
+		case 'r':
+			private = 1;
+			break;
+		case 'n':
+			reserve = 0;
+			break;
+		default:
+			errno = EINVAL;
+			perror("Invalid arg");
+			exit_usage();
+		}
+	}
+
+	if (strncmp(path, "", sizeof(path)) != 0) {
+		printf("Writing to this path: %s\n", path);
+	} else {
+		errno = EINVAL;
+		perror("path not found");
+		exit_usage();
+	}
+
+	if (size != 0) {
+		printf("Writing this size: %d\n", size);
+	} else {
+		errno = EINVAL;
+		perror("size not found");
+		exit_usage();
+	}
+
+	if (!populate)
+		printf("Not populating.\n");
+	else
+		printf("Populating.\n");
+
+	if (!write)
+		printf("Not writing to memory.\n");
+
+	if (method == MAX_METHOD) {
+		errno = EINVAL;
+		perror("-m Invalid");
+		exit_usage();
+	} else
+		printf("Using method=%d\n", method);
+
+	if (!private)
+		printf("Shared mapping.\n");
+	else
+		printf("Private mapping.\n");
+
+	if (!reserve)
+		printf("NO_RESERVE mapping.\n");
+	else
+		printf("RESERVE mapping.\n");
+
+	switch (method) {
+	case HUGETLBFS:
+		printf("Allocating using HUGETLBFS.\n");
+		fd = open(path, O_CREAT | O_RDWR, 0777);
+		if (fd == -1)
+			err(1, "Failed to open file.");
+
+		ptr = mmap(NULL, size, PROT_READ | PROT_WRITE,
+			   (private ? MAP_PRIVATE : MAP_SHARED) |
+				   (populate ? MAP_POPULATE : 0) |
+				   (reserve ? 0 : MAP_NORESERVE),
+			   fd, 0);
+
+		if (ptr == MAP_FAILED) {
+			close(fd);
+			err(1, "Error mapping the file");
+		}
+		break;
+	case MMAP_MAP_HUGETLB:
+		printf("Allocating using MAP_HUGETLB.\n");
+		/* fd is -1: MAP_ANONYMOUS is needed for shared mappings too. */
+		ptr = mmap(NULL, size, PROT_READ | PROT_WRITE,
+			   (private ? MAP_PRIVATE : MAP_SHARED) |
+				   MAP_ANONYMOUS | MAP_HUGETLB |
+				   (populate ? MAP_POPULATE : 0) |
+				   (reserve ? 0 : MAP_NORESERVE), -1, 0);
+
+		if (ptr == MAP_FAILED)
+			err(1, "mmap");
+
+		printf("Returned address is %p\n", ptr);
+		break;
+	case SHM:
+		printf("Allocating using SHM.\n");
+		shmid = shmget(key, size,
+			       SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);
+		if (shmid < 0) {
+			shmid = shmget(++key, size,
+				       SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);
+			if (shmid < 0)
+				err(1, "shmget");
+		}
+		printf("shmid: 0x%x, shmget key:%d\n", shmid, key);
+
+		ptr = shmat(shmid, NULL, 0);
+		if (ptr == (int *)-1) {
+			perror("Shared memory attach failure");
+			shmctl(shmid, IPC_RMID, NULL);
+			exit(2);
+		}
+		shmaddr = (char *)ptr;	/* lets sig_handler() detach it */
+		printf("shmaddr: %p\n", ptr);
+
+		break;
+	default:
+		errno = EINVAL;
+		err(1, "Invalid method.");
+	}
+
+	if (write) {
+		printf("Writing to memory.\n");
+		memset(ptr, 1, size);
+	}
+
+	if (want_sleep) {
+		/* Signal to caller that we're done. */
+		printf("DONE\n");
+
+		/* Hold memory until external kill signal is delivered. */
+		while (1)
+			sleep(100);
+	}
+
+	if (method == HUGETLBFS)
+		close(fd);
+
+	return 0;
+}