diff --git a/kernel/cgroup/cgroup.c b/kernel/cgroup/cgroup.c
index 16fe1c6cad358d3fdd3d1ddde88be1c06002ce35..502769b2683cd689cf4cf35b43a44809e8a6bb4e 100644
--- a/kernel/cgroup/cgroup.c
+++ b/kernel/cgroup/cgroup.c
@@ -5900,17 +5900,20 @@ static struct cgroup *cgroup_get_from_file(struct file *f)
 
 /**
  * cgroup_can_fork - called on a new task before the process is exposed
- * @child: the task in question.
+ * @child: the child process
  *
- * This calls the subsystem can_fork() callbacks. If the can_fork() callback
- * returns an error, the fork aborts with that error code. This allows for
- * a cgroup subsystem to conditionally allow or deny new forks.
+ * This calls the subsystem can_fork() callbacks. If the cgroup_can_fork()
+ * callback returns an error, the fork aborts with that error code. This
+ * allows for a cgroup subsystem to conditionally allow or deny new forks.
  */
 int cgroup_can_fork(struct task_struct *child)
+	__acquires(&cgroup_threadgroup_rwsem) __releases(&cgroup_threadgroup_rwsem)
 {
 	struct cgroup_subsys *ss;
 	int i, j, ret;
 
+	cgroup_threadgroup_change_begin(current);
+
 	do_each_subsys_mask(ss, i, have_canfork_callback) {
 		ret = ss->can_fork(child);
 		if (ret)
@@ -5927,17 +5930,20 @@ int cgroup_can_fork(struct task_struct *child)
 			ss->cancel_fork(child);
 	}
 
+	cgroup_threadgroup_change_end(current);
+
 	return ret;
 }
 
 /**
- * cgroup_cancel_fork - called if a fork failed after cgroup_can_fork()
- * @child: the task in question
- *
- * This calls the cancel_fork() callbacks if a fork failed *after*
- * cgroup_can_fork() succeded.
- */
+  * cgroup_cancel_fork - called if a fork failed after cgroup_can_fork()
+  * @child: the child process
+  *
+  * This calls the cancel_fork() callbacks if a fork failed *after*
+  * cgroup_can_fork() succeded.
+  */
 void cgroup_cancel_fork(struct task_struct *child)
+	__releases(&cgroup_threadgroup_rwsem)
 {
 	struct cgroup_subsys *ss;
 	int i;
@@ -5945,19 +5951,19 @@ void cgroup_cancel_fork(struct task_struct *child)
 	for_each_subsys(ss, i)
 		if (ss->cancel_fork)
 			ss->cancel_fork(child);
+
+	cgroup_threadgroup_change_end(current);
 }
 
 /**
- * cgroup_post_fork - called on a new task after adding it to the task list
- * @child: the task in question
- *
- * Adds the task to the list running through its css_set if necessary and
- * call the subsystem fork() callbacks.  Has to be after the task is
- * visible on the task list in case we race with the first call to
- * cgroup_task_iter_start() - to guarantee that the new task ends up on its
- * list.
+ * cgroup_post_fork - finalize cgroup setup for the child process
+ * @child: the child process
+ *
+ * Attach the child process to its css_set calling the subsystem fork()
+ * callbacks.
  */
 void cgroup_post_fork(struct task_struct *child)
+	__releases(&cgroup_threadgroup_rwsem)
 {
 	struct cgroup_subsys *ss;
 	struct css_set *cset;
@@ -6003,6 +6009,8 @@ void cgroup_post_fork(struct task_struct *child)
 	do_each_subsys_mask(ss, i, have_fork_callback) {
 		ss->fork(child);
 	} while_each_subsys_mask();
+
+	cgroup_threadgroup_change_end(current);
 }
 
 /**
diff --git a/kernel/fork.c b/kernel/fork.c
index 60a1295f4384363ae9b35e589c02e52fbf914a26..9245b6e53f550f122aa7652532b7bfb04a056117 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -2174,7 +2174,6 @@ static __latent_entropy struct task_struct *copy_process(
 	INIT_LIST_HEAD(&p->thread_group);
 	p->task_works = NULL;
 
-	cgroup_threadgroup_change_begin(current);
 	/*
 	 * Ensure that the cgroup subsystem policies allow the new process to be
 	 * forked. It should be noted the the new process's css_set can be changed
@@ -2183,7 +2182,7 @@ static __latent_entropy struct task_struct *copy_process(
 	 */
 	retval = cgroup_can_fork(p);
 	if (retval)
-		goto bad_fork_cgroup_threadgroup_change_end;
+		goto bad_fork_put_pidfd;
 
 	/*
 	 * From this point on we must avoid any synchronous user-space
@@ -2289,7 +2288,6 @@ static __latent_entropy struct task_struct *copy_process(
 
 	proc_fork_connector(p);
 	cgroup_post_fork(p);
-	cgroup_threadgroup_change_end(current);
 	perf_event_fork(p);
 
 	trace_task_newtask(p, clone_flags);
@@ -2301,8 +2299,6 @@ static __latent_entropy struct task_struct *copy_process(
 	spin_unlock(&current->sighand->siglock);
 	write_unlock_irq(&tasklist_lock);
 	cgroup_cancel_fork(p);
-bad_fork_cgroup_threadgroup_change_end:
-	cgroup_threadgroup_change_end(current);
 bad_fork_put_pidfd:
 	if (clone_flags & CLONE_PIDFD) {
 		fput(pidfile);