Skip to content

Commit

Permalink
Only split by-ref/by-move futures for async closures
Browse files Browse the repository at this point in the history
  • Loading branch information
compiler-errors committed Mar 19, 2024
1 parent e760daa commit 05116c5
Show file tree
Hide file tree
Showing 33 changed files with 119 additions and 432 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
self.tcx(),
ty::CoroutineArgsParts {
parent_args: args.parent_args(),
kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
kind_ty: Ty::from_coroutine_closure_kind(self.tcx(), args.kind()),
return_ty: user_provided_sig.output(),
tupled_upvars_ty,
// For async closures, none of these can be annotated, so just fill
Expand Down
14 changes: 9 additions & 5 deletions compiler/rustc_hir_typeck/src/callee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
kind: TypeVariableOriginKind::TypeInference,
span: callee_expr.span,
});
// We may actually receive a coroutine back whose kind is different
// from the closure that this dispatched from. This is because when
// we have no captures, we automatically implement `FnOnce`. This
// impl forces the closure kind to `FnOnce` i.e. `u8`.
let kind_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::TypeInference,
span: callee_expr.span,
});
let call_sig = self.tcx.mk_fn_sig(
[coroutine_closure_sig.tupled_inputs_ty],
coroutine_closure_sig.to_coroutine(
self.tcx,
closure_args.parent_args(),
// Inherit the kind ty of the closure, since we're calling this
// coroutine with the most relaxed `AsyncFn*` trait that we can.
// We don't necessarily need to do this here, but it saves us
// computing one more infer var that will get constrained later.
closure_args.kind_ty(),
kind_ty,
self.tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
),
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
},
);

let coroutine_kind_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
});
let coroutine_upvars_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: expr_span,
Expand All @@ -279,7 +283,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
sig.to_coroutine(
tcx,
parent_args,
closure_kind_ty,
coroutine_kind_ty,
tcx.coroutine_for_closure(expr_def_id),
coroutine_upvars_ty,
)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/upvar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self.demand_eqtype(
span,
coroutine_args.as_coroutine().kind_ty(),
Ty::from_closure_kind(self.tcx, closure_kind),
Ty::from_coroutine_closure_kind(self.tcx, closure_kind),
);
}

Expand Down
12 changes: 0 additions & 12 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,6 @@ pub struct CoroutineInfo<'tcx> {
/// using `run_passes`.
pub by_move_body: Option<Body<'tcx>>,

/// The body of the coroutine, modified to take its upvars by mutable ref rather than by
/// immutable ref.
///
/// FIXME(async_closures): This is literally the same body as the parent body. Find a better
/// way to represent the by-mut signature (or cap the closure-kind of the coroutine).
pub by_mut_body: Option<Body<'tcx>>,

/// The layout of a coroutine. This field is populated after the state transform pass.
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,

Expand All @@ -305,7 +298,6 @@ impl<'tcx> CoroutineInfo<'tcx> {
yield_ty: Some(yield_ty),
resume_ty: Some(resume_ty),
by_move_body: None,
by_mut_body: None,
coroutine_drop: None,
coroutine_layout: None,
}
Expand Down Expand Up @@ -628,10 +620,6 @@ impl<'tcx> Body<'tcx> {
self.coroutine.as_ref()?.by_move_body.as_ref()
}

pub fn coroutine_by_mut_body(&self) -> Option<&Body<'tcx>> {
self.coroutine.as_ref()?.by_mut_body.as_ref()
}

#[inline]
pub fn coroutine_kind(&self) -> Option<CoroutineKind> {
self.coroutine.as_ref().map(|coroutine| coroutine.coroutine_kind)
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,10 @@ macro_rules! make_mir_visitor {
ty::InstanceDef::Virtual(_def_id, _) |
ty::InstanceDef::ThreadLocalShim(_def_id) |
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id, target_kind: _ } |
ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id: _def_id,
} |
ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } |
ty::InstanceDef::DropGlue(_def_id, None) => {}

ty::InstanceDef::FnPtrShim(_def_id, ty) |
Expand Down
18 changes: 5 additions & 13 deletions compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,20 @@ pub enum InstanceDef<'tcx> {
/// and dispatch to the `FnMut::call_mut` instance for the closure.
ClosureOnceShim { call_once: DefId, track_caller: bool },

/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or
/// `<[Fn coroutine-closure] as FnMut>::call_mut`.
/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once`
///
/// The body generated here differs significantly from the `ClosureOnceShim`,
/// since we need to generate a distinct coroutine type that will move the
/// closure's upvars *out* of the closure.
ConstructCoroutineInClosureShim {
coroutine_closure_def_id: DefId,
target_kind: ty::ClosureKind,
},
ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId },

/// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
/// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
/// similarly for `AsyncFnMut`.
///
/// This will select the body that is produced by the `ByMoveBody` transform, and thus
/// take and use all of its upvars by-move rather than by-ref.
CoroutineKindShim { coroutine_def_id: DefId, target_kind: ty::ClosureKind },
CoroutineKindShim { coroutine_def_id: DefId },

/// Compiler-generated accessor for thread locals which returns a reference to the thread local
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
Expand Down Expand Up @@ -192,9 +188,8 @@ impl<'tcx> InstanceDef<'tcx> {
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
| ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id: def_id,
target_kind: _,
}
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id, target_kind: _ }
| ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id }
| InstanceDef::DropGlue(def_id, _)
| InstanceDef::CloneShim(def_id, _)
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
Expand Down Expand Up @@ -651,10 +646,7 @@ impl<'tcx> Instance<'tcx> {
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
} else {
Some(Instance {
def: ty::InstanceDef::CoroutineKindShim {
coroutine_def_id,
target_kind: args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
},
def: ty::InstanceDef::CoroutineKindShim { coroutine_def_id },
args,
})
}
Expand Down
17 changes: 16 additions & 1 deletion compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
self.to_coroutine(
tcx,
parent_args,
Ty::from_closure_kind(tcx, goal_kind),
Ty::from_coroutine_closure_kind(tcx, goal_kind),
coroutine_def_id,
tupled_upvars_ty,
)
Expand Down Expand Up @@ -2456,6 +2456,21 @@ impl<'tcx> Ty<'tcx> {
}
}

/// Like [`Ty::to_opt_closure_kind`], but it caps the "maximum" closure kind
/// to `FnMut`. This is because although we have three capability states,
/// `AsyncFn`/`AsyncFnMut`/`AsyncFnOnce`, we only need to distinguish two coroutine
/// bodies: by-ref and by-value.
///
/// This method should be used when constructing a `Coroutine` out of a
/// `CoroutineClosure`, when the `Coroutine`'s `kind` field is being populated
/// directly from the `CoroutineClosure`'s `kind`.
pub fn from_coroutine_closure_kind(tcx: TyCtxt<'tcx>, kind: ty::ClosureKind) -> Ty<'tcx> {
match kind {
ty::ClosureKind::Fn | ty::ClosureKind::FnMut => tcx.types.i16,
ty::ClosureKind::FnOnce => tcx.types.i32,
}
}

/// Fast path helper for testing if a type is `Sized`.
///
/// Returning true means the type is known to be sized. Returning
Expand Down
35 changes: 0 additions & 35 deletions compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
by_move_body.source = mir::MirSource {
instance: InstanceDef::CoroutineKindShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
target_kind: ty::ClosureKind::FnOnce,
},
promoted: None,
};
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);

// If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
// This is actually just a copy of the by-ref body, but with a different self type.
// FIXME(async_closures): We could probably unify this with the by-ref body somehow.
if coroutine_kind == ty::ClosureKind::Fn {
let by_mut_coroutine_ty = Ty::new_coroutine(
tcx,
coroutine_def_id.to_def_id(),
ty::CoroutineArgs::new(
tcx,
ty::CoroutineArgsParts {
parent_args: args.as_coroutine().parent_args(),
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
resume_ty: args.as_coroutine().resume_ty(),
yield_ty: args.as_coroutine().yield_ty(),
return_ty: args.as_coroutine().return_ty(),
witness: args.as_coroutine().witness(),
tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
},
)
.args,
);
let mut by_mut_body = body.clone();
by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
by_mut_body.source = mir::MirSource {
instance: InstanceDef::CoroutineKindShim {
coroutine_def_id: coroutine_def_id.to_def_id(),
target_kind: ty::ClosureKind::FnMut,
},
promoted: None,
};
body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
}
}
}

Expand Down
3 changes: 0 additions & 3 deletions compiler/rustc_mir_transform/src/pass_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ fn run_passes_inner<'tcx>(
if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
}
if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
}
}
}

Expand Down
98 changes: 12 additions & 86 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};

use rustc_index::{Idx, IndexVec};
Expand Down Expand Up @@ -70,39 +70,13 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
}

ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind,
} => match target_kind {
ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
ty::ClosureKind::FnMut => {
// No need to optimize the body, it has already been optimized
// since we steal it from the `AsyncFn::call` body and just fix
// the return type.
return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
}
ty::ClosureKind::FnOnce => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}
},
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}

ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind {
ty::ClosureKind::Fn => unreachable!(),
ty::ClosureKind::FnMut => {
return tcx
.optimized_mir(coroutine_def_id)
.coroutine_by_mut_body()
.unwrap()
.clone();
}
ty::ClosureKind::FnOnce => {
return tcx
.optimized_mir(coroutine_def_id)
.coroutine_by_move_body()
.unwrap()
.clone();
}
},
ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
}

ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
Expand All @@ -123,21 +97,11 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() {
coroutine_body.coroutine_drop().unwrap()
} else {
match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() {
ty::ClosureKind::Fn => {
unreachable!()
}
ty::ClosureKind::FnMut => coroutine_body
.coroutine_by_mut_body()
.unwrap()
.coroutine_drop()
.unwrap(),
ty::ClosureKind::FnOnce => coroutine_body
.coroutine_by_move_body()
.unwrap()
.coroutine_drop()
.unwrap(),
}
assert_eq!(
args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(),
ty::ClosureKind::FnOnce
);
coroutine_body.coroutine_by_move_body().unwrap().coroutine_drop().unwrap()
};

let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
Expand Down Expand Up @@ -1112,7 +1076,6 @@ fn build_construct_coroutine_by_move_shim<'tcx>(

let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnOnce,
});

let body =
Expand All @@ -1121,40 +1084,3 @@ fn build_construct_coroutine_by_move_shim<'tcx>(

body
}

fn build_construct_coroutine_by_mut_shim<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_closure_def_id: DefId,
) -> Body<'tcx> {
let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
bug!();
};
let args = args.as_coroutine_closure();

body.local_decls[RETURN_PLACE].ty =
tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(coroutine_closure_def_id),
ty::ClosureKind::FnMut,
tcx.lifetimes.re_erased,
args.tupled_upvars_ty(),
args.coroutine_captures_by_ref_ty(),
)
}));
body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);

body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnMut,
});

body.pass_count = 0;
dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));

body
}
3 changes: 1 addition & 2 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,8 @@ symbols! {
Break,
C,
CStr,
CallFuture,
CallMutFuture,
CallOnceFuture,
CallRefFuture,
Capture,
Center,
Cleanup,
Expand Down
Loading

0 comments on commit 05116c5

Please sign in to comment.