Skip to content

Commit

Permalink
Merge pull request swiftlang#64780 from xedin/fix-default-args-with-v…
Browse files Browse the repository at this point in the history
…ariadic-generics

[CSApply] Teach `coerceCallArguments` about variadic generics
  • Loading branch information
xedin authored Mar 31, 2023
2 parents 66f4d5b + 00fbdc7 commit 7064e18
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 8 deletions.
95 changes: 87 additions & 8 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5490,7 +5490,7 @@ Solution::resolveLocatorToDecl(ConstraintLocator *locator) const {
/// index. This looks through inheritance for inherited default args.
static ConcreteDeclRef getDefaultArgOwner(ConcreteDeclRef owner,
unsigned index) {
auto *param = getParameterAt(owner.getDecl(), index);
auto *param = getParameterAt(owner, index);
assert(param);
if (param->getDefaultArgumentKind() == DefaultArgumentKind::Inherited) {
return getDefaultArgOwner(owner.getOverriddenDecl(), index);
Expand Down Expand Up @@ -5851,6 +5851,65 @@ static void applyContextualClosureFlags(
}
}

// For variadic generic declarations we need to compute a substituted
// version of bindings because all of the packs are exploaded in the
// substituted function type.
//
// \code
// func fn<each T>(_: repeat each T) {}
//
// fn("", 42)
// \endcode
//
// The type of `fn` in the call is `(String, Int) -> Void` but bindings
// have only one parameter at index `0` with two argument positions: 0, 1.
static bool shouldSubstituteParameterBindings(ConcreteDeclRef callee) {
auto subst = callee.getSubstitutions();
if (subst.empty())
return false;

auto sig = subst.getGenericSignature();
return llvm::any_of(
sig.getGenericParams(),
[&](const GenericTypeParamType *GP) { return GP->isParameterPack(); });
}

/// Compute parameter binding substitutions by exploding pack expansions
/// into multiple bindings (if they matched more than one argument) and
/// ignoring empty ones.
static void computeParameterBindingsSubstitutions(
ConcreteDeclRef callee, ArrayRef<AnyFunctionType::Param> params,
ArrayRef<ParamBinding> origBindings,
SmallVectorImpl<ParamBinding> &substitutedBindings) {
for (unsigned bindingIdx = 0, numBindings = origBindings.size();
bindingIdx != numBindings; ++bindingIdx) {
if (origBindings[bindingIdx].size() > 1) {
const auto &param = params[substitutedBindings.size()];
if (!param.isVariadic()) {
#ifndef NDEBUG
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
assert(PD && PD->getInterfaceType()->is<PackExpansionType>());
#endif
// Explode binding set to match substituted function parameters.
for (auto argIdx : origBindings[bindingIdx])
substitutedBindings.push_back({argIdx});
continue;
}
}

const auto &bindings = origBindings[bindingIdx];
if (bindings.size() == 0) {
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
// Skip pack expansions with no arguments because they are not
// present in the substituted function type.
if (PD->getInterfaceType()->is<PackExpansionType>())
continue;
}

substitutedBindings.push_back(bindings);
}
}

ArgumentList *ExprRewriter::coerceCallArguments(
ArgumentList *args, AnyFunctionType *funcType, ConcreteDeclRef callee,
ApplyExpr *apply, ConstraintLocatorBuilder locator,
Expand Down Expand Up @@ -5904,9 +5963,18 @@ ArgumentList *ExprRewriter::coerceCallArguments(
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
->second.parameterBindings;
bool shouldSubstituteBindings = shouldSubstituteParameterBindings(callee);

SmallVector<ParamBinding, 4> substitutedBindings;
if (shouldSubstituteBindings) {
computeParameterBindingsSubstitutions(callee, params, parameterBindings,
substitutedBindings);
} else {
substitutedBindings = parameterBindings;
}

SmallVector<Argument, 4> newArgs;
for (unsigned paramIdx = 0, numParams = parameterBindings.size();
for (unsigned paramIdx = 0, numParams = substitutedBindings.size();
paramIdx != numParams; ++paramIdx) {
// Extract the parameter.
const auto &param = params[paramIdx];
Expand All @@ -5920,7 +5988,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(

// The first argument of this vararg parameter may have had a label;
// save its location.
auto &varargIndices = parameterBindings[paramIdx];
auto &varargIndices = substitutedBindings[paramIdx];
SourceLoc labelLoc;
if (!varargIndices.empty())
labelLoc = args->getLabelLoc(varargIndices[0]);
Expand Down Expand Up @@ -5969,11 +6037,22 @@ ArgumentList *ExprRewriter::coerceCallArguments(
}

// Handle default arguments.
if (parameterBindings[paramIdx].empty()) {
if (substitutedBindings[paramIdx].empty()) {
auto paramIdxForDefault = paramIdx;
// If bindings were substituted we need to find "original"
// (or contextless) parameter index for the default argument.
if (shouldSubstituteBindings) {
auto *paramList = getParameterList(callee.getDecl());
assert(paramList);
paramIdxForDefault =
paramList->getOrigParamIndex(callee.getSubstitutions(), paramIdx);
}

auto owner = getDefaultArgOwner(callee, paramIdx);
auto paramTy = param.getParameterType();
auto *defArg = new (ctx) DefaultArgumentExpr(
owner, paramIdx, args->getStartLoc(), paramTy, dc);
owner, paramIdxForDefault, args->getStartLoc(), paramTy, dc);

cs.cacheType(defArg);
newArgs.emplace_back(SourceLoc(), param.getLabel(), defArg);
continue;
Expand All @@ -5982,8 +6061,8 @@ ArgumentList *ExprRewriter::coerceCallArguments(
// Otherwise, we have a plain old ordinary argument.

// Extract the argument used to initialize this parameter.
assert(parameterBindings[paramIdx].size() == 1);
unsigned argIdx = parameterBindings[paramIdx].front();
assert(substitutedBindings[paramIdx].size() == 1);
unsigned argIdx = substitutedBindings[paramIdx].front();
auto arg = args->get(argIdx);
auto *argExpr = arg.getExpr();
auto argType = cs.getType(argExpr);
Expand Down Expand Up @@ -6027,7 +6106,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(
};

if (paramInfo.hasExternalPropertyWrapper(paramIdx)) {
auto *paramDecl = getParameterAt(callee.getDecl(), paramIdx);
auto *paramDecl = getParameterAt(callee, paramIdx);
assert(paramDecl);

auto appliedWrapper = appliedPropertyWrappers[appliedWrapperIndex++];
Expand Down
35 changes: 35 additions & 0 deletions test/Constraints/pack-expansion-expressions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,38 @@ do {
return G<repeat each T>() // Ok
}
}

// Make sure that in-exact matches (that require any sort of conversion or load) on arguments are handled correctly.
do {
var v: Float = 42 // expected-warning {{variable 'v' was never mutated; consider changing to 'let' constant}}

func testOpt<each T>(x: Int?, _: repeat each T) {}
testOpt(x: 42, "", v) // Load + Optional promotion

func testLoad<each T, each U>(t: repeat each T, u: repeat each U) {}
testLoad(t: "", v) // Load + default
testLoad(t: "", v, u: v, 0.0) // Two loads

func testDefaultWithExtra<each T, each U>(t: repeat each T, u: repeat each U, extra: Int?) {}
testDefaultWithExtra(t: "", v, extra: 42)

func defaults1<each T>(x: Int? = nil, _: repeat each T) {}
defaults1("", 3.14) // Ok

func defaults2<each T>(_: repeat each T, x: Int? = nil) {}
defaults2("", 3.14) // Ok

func defaults3<each T, each U>(t: repeat each T, u: repeat each U, extra: Int? = nil) {}
defaults3(t: "", 3.14) // Ok
defaults3(t: "", 3.14, u: 0, v) // Ok
defaults3(t: "", 3.14, u: 0, v, extra: 42) // Ok

struct Defaulted<each T> {
init(t: repeat each T, extra: Int? = nil) {}
init<each U>(t: repeat each T, u: repeat each U, other: Int? = nil) {}
}

_ = Defaulted(t: "a", 0, 1.0) // Ok
_ = Defaulted(t: "b", 0) // Ok
_ = Defaulted(t: "c", 1.0, u: "d", 0) // Ok
}

0 comments on commit 7064e18

Please sign in to comment.