Skip to content

Commit

Permalink
Merge pull request apache#10862: [BEAM-9320] Add AlwaysFetched annota…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
reuvenlax committed Feb 16, 2020
1 parent 0bc4b3f commit 52fea0e
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException(
"Access to state not supported in Splittable DoFn");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
Expand Down Expand Up @@ -374,7 +375,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException(
"Cannot access state outside of @ProcessElement and @OnTimer methods.");
}
Expand Down Expand Up @@ -511,7 +512,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException(
"Cannot access state outside of @ProcessElement and @OnTimer methods.");
}
Expand Down Expand Up @@ -745,13 +746,19 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
try {
StateSpec<?> spec =
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
return stepContext
.stateInternals()
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
State state =
stepContext
.stateInternals()
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
if (alwaysFetched) {
return (State) ((ReadableState) state).readLater();
} else {
return state;
}
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -937,13 +944,19 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
try {
StateSpec<?> spec =
(StateSpec<?>) signature.stateDeclarations().get(stateId).field().get(fn);
return stepContext
.stateInternals()
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
State state =
stepContext
.stateInternals()
.state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec));
if (alwaysFetched) {
return (State) ((ReadableState) state).readLater();
} else {
return state;
}
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,43 @@ public interface MultiOutputReceiver {
String value();
}

/////////////////////////////////////////////////////////////////////////////

/**
* Annotation for declaring that a state parameter is always fetched.
*
* <p>A DoFn might not fetch a state value on every element, and for that reason runners may
* choose to defer fetching state until read() is called. Annotating a state parameter with this
* annotation provides a hint to the runner that the state is always fetched. This may cause the
* runner to prefetch the annotated state before calling the {@literal @ProcessElement} or
* {@literal @OnTimer} method, improving performance. This is a performance-only hint - it does
* not change semantics. See the following code for an example:
*
* <pre><code>{@literal new DoFn<KV<Key, Foo>, Baz>()} {
*
* {@literal @StateId("my-state-id")}
* {@literal private final StateSpec<ValueState<MyState>>} myStateSpec =
* StateSpecs.value(new MyStateCoder());
*
* {@literal @ProcessElement}
* public void processElement(
* {@literal @Element InputT element},
* {@literal @AlwaysFetched @StateId("my-state-id") ValueState<MyState> myState}) {
* myState.read();
* myState.write(...);
* }
* }
* </code></pre>
*
* <p>This can only be used on state objects that implement {@link
* org.apache.beam.sdk.state.ReadableState}; using it on any other state type is rejected with an
* {@link IllegalArgumentException} when the DoFn signature is analyzed.
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
@Experimental(Kind.STATE)
public @interface AlwaysFetched {}

/**
* Annotation for declaring and dereferencing timers.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ public Object restriction() {
}

@Override
public org.apache.beam.sdk.state.State state(String stateId) {
public org.apache.beam.sdk.state.State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException("DoFnTester doesn't support state yet");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,10 @@ public StackManipulation dispatch(RestrictionTrackerParameter p) {
public StackManipulation dispatch(StateParameter p) {
return new StackManipulation.Compound(
new TextConstant(p.referent().id()),
IntegerConstant.forValue(p.alwaysFetched()),
MethodInvocation.invoke(
getExtraContextFactoryMethodDescription(STATE_PARAMETER_METHOD, String.class)),
getExtraContextFactoryMethodDescription(
STATE_PARAMETER_METHOD, String.class, boolean.class)),
TypeCasting.to(
new TypeDescription.ForLoadedType(p.referent().stateType().getRawType())));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ interface ArgumentProvider<InputT, OutputT> {
RestrictionTracker<?, ?> restrictionTracker();

/** Returns the state cell for the given {@link StateId}. */
State state(String stateId);
State state(String stateId, boolean alwaysFetched);

/** Returns the timer for the given {@link TimerId}. */
Timer timer(String timerId);
Expand Down Expand Up @@ -313,7 +313,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
throw new UnsupportedOperationException(
String.format("State unsupported in %s", getErrorContext()));
}
Expand Down Expand Up @@ -436,8 +436,8 @@ public Object restriction() {
}

@Override
public State state(String stateId) {
return delegate.state(stateId);
public State state(String stateId, boolean alwaysFetch) {
return delegate.state(stateId, alwaysFetch);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,8 @@ public static RestrictionTrackerParameter restrictionTracker(TypeDescriptor<?> t
}

/** Returns a {@link StateParameter} referring to the given {@link StateDeclaration}. */
public static StateParameter stateParameter(StateDeclaration decl) {
return new AutoValue_DoFnSignature_Parameter_StateParameter(decl);
public static StateParameter stateParameter(StateDeclaration decl, boolean alwaysFetched) {
return new AutoValue_DoFnSignature_Parameter_StateParameter(decl, alwaysFetched);
}

public static TimerParameter timerParameter(TimerDeclaration decl) {
Expand Down Expand Up @@ -756,6 +756,8 @@ public abstract static class StateParameter extends Parameter {
StateParameter() {}

public abstract StateDeclaration referent();

public abstract boolean alwaysFetched();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
Expand Down Expand Up @@ -1258,7 +1259,14 @@ private static Parameter analyzeExtraParameter(
id,
stateDecl.field().getDeclaringClass().getName());

return Parameter.stateParameter(stateDecl);
boolean alwaysFetched = getStateAlwaysFetched(param.getAnnotations());
if (alwaysFetched) {
paramErrors.checkArgument(
ReadableState.class.isAssignableFrom(rawType),
"@AlwaysFetched can only be used on ReadableStates. It cannot be used on %s",
format(stateDecl.stateType()));
}
return Parameter.stateParameter(stateDecl, alwaysFetched);
} else {
paramErrors.throwIllegalArgument("%s is not a valid context parameter.", format(paramT));
// Unreachable
Expand All @@ -1284,6 +1292,11 @@ private static String getStateId(List<Annotation> annotations) {
return stateId != null ? stateId.value() : null;
}

/** Returns {@code true} if the given annotations include {@link DoFn.AlwaysFetched}. */
private static boolean getStateAlwaysFetched(List<Annotation> annotations) {
  return findFirstOfType(annotations, DoFn.AlwaysFetched.class) != null;
}

@Nullable
private static String getFieldAccessId(List<Annotation> annotations) {
DoFn.FieldAccess access = findFirstOfType(annotations, DoFn.FieldAccess.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1716,7 +1716,8 @@ public void testValueStateSimple() {

@ProcessElement
public void processElement(
@StateId(stateId) ValueState<Integer> state, OutputReceiver<Integer> r) {
@AlwaysFetched @StateId(stateId) ValueState<Integer> state,
OutputReceiver<Integer> r) {
Integer currentValue = MoreObjects.firstNonNull(state.read(), 0);
r.output(currentValue);
state.write(currentValue + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public void processElement(
public void testDoFnWithState() throws Exception {
ValueState<Integer> mockState = mock(ValueState.class);
final String stateId = "my-state-id-here";
when(mockArgumentProvider.state(stateId)).thenReturn(mockState);
when(mockArgumentProvider.state(stateId, false)).thenReturn(mockState);

class MockFn extends DoFn<String, String> {
@StateId(stateId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.lang.reflect.Field;
Expand Down Expand Up @@ -836,6 +837,43 @@ public void myProcessElement(
}.getClass());
}

/**
 * Verifies that {@code @AlwaysFetched} on a readable state parameter is accepted and that the
 * resulting {@link StateParameter} reports {@code alwaysFetched() == true}.
 *
 * <p>The parameter must be a {@code ReadableState} (here a {@code ValueState}); a non-readable
 * state would make signature analysis throw before the assertion below could run.
 */
@Test
public void testStateParameterAlwaysFetched() {
  DoFnSignature sig =
      DoFnSignatures.getSignature(
          new DoFn<KV<String, Integer>, Long>() {
            @StateId("my-id")
            private final StateSpec<ValueState<Integer>> myfield =
                StateSpecs.value(VarIntCoder.of());

            @ProcessElement
            public void myProcessElement(
                ProcessContext context,
                @AlwaysFetched @StateId("my-id") ValueState<Integer> one) {}
          }.getClass());
  // Index 0 is the ProcessContext parameter; index 1 is the annotated state parameter.
  StateParameter stateParameter = (StateParameter) sig.processElement().extraParameters().get(1);
  assertTrue(stateParameter.alwaysFetched());
}

/**
* Verifies that analyzing a DoFn whose {@code @AlwaysFetched} state parameter is a {@code
* MapState} fails with an {@link IllegalArgumentException} whose message mentions
* "ReadableStates".
*/
@Test
public void testStateParameterAlwaysFetchNonReadableState() {
// Expectations must be registered before the call that is expected to throw.
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("ReadableStates");
DoFnSignatures.getSignature(
new DoFn<KV<String, Integer>, Long>() {
@StateId("my-id")
private final StateSpec<MapState<Integer, Integer>> myfield =
StateSpecs.map(VarIntCoder.of(), VarIntCoder.of());

@ProcessElement
public void myProcessElement(
ProcessContext context,
@AlwaysFetched @StateId("my-id") MapState<Integer, Integer> one) {}
}.getClass());
}

@Test
public void testStateParameterDuplicate() throws Exception {
thrown.expect(IllegalArgumentException.class);
Expand Down Expand Up @@ -987,7 +1025,7 @@ public void processWithState(ProcessContext c, ValueState<String> state) {}
DoFnSignature.StateDeclaration decl =
sig.stateDeclarations().get(DoFnOverridingAbstractStateUse.STATE_ID);
StateParameter stateParam = (StateParameter) sig.processElement().extraParameters().get(1);

assertFalse(stateParam.alwaysFetched());
assertThat(
decl.field(),
equalTo(DoFnDeclaringStateAndAbstractUse.class.getDeclaredField("myStateSpec")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
Expand Down Expand Up @@ -1067,7 +1068,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId);
checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
StateSpec<?> spec;
Expand All @@ -1076,7 +1077,12 @@ public State state(String stateId) {
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
return spec.bind(stateId, stateAccessor);
State state = spec.bind(stateId, stateAccessor);
if (alwaysFetched) {
return (State) ((ReadableState) state).readLater();
} else {
return state;
}
}

@Override
Expand Down Expand Up @@ -1258,7 +1264,7 @@ public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
}

@Override
public State state(String stateId) {
public State state(String stateId, boolean alwaysFetched) {
StateDeclaration stateDeclaration = context.doFnSignature.stateDeclarations().get(stateId);
checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
StateSpec<?> spec;
Expand All @@ -1267,7 +1273,12 @@ public State state(String stateId) {
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
return spec.bind(stateId, stateAccessor);
State state = spec.bind(stateId, stateAccessor);
if (alwaysFetched) {
return (State) ((ReadableState) state).readLater();
} else {
return state;
}
}

@Override
Expand Down

0 comments on commit 52fea0e

Please sign in to comment.