Skip to content

Commit

Permalink
Add :on_cancel
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Mar 20, 2023
1 parent c31cb55 commit 498b16a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
6 changes: 6 additions & 0 deletions lib/gen_stage.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,12 @@ defmodule GenStage do
* `:stacktrace` - the stacktrace of the function that started the
stream.
* `:on_cancel` - what happens when all consumers cancel. The default
is to keep the stream running. Set it to `:stop` to stop the producer.
To avoid race conditions, it is recommend to only set this option if
`:demand` is set to `:accumulate` and forwarded only after all consumers
subscribe
All other options that would be given for `start_link/3` are
also accepted.
"""
Expand Down
43 changes: 35 additions & 8 deletions lib/gen_stage/streamer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,60 @@ defmodule GenStage.Streamer do
x, {acc, counter} -> {:cont, {[x | acc], counter - 1}}
end)

{:producer, {stack, continuation}, Keyword.take(opts, [:dispatcher, :demand])}
on_cancel =
case Keyword.get(opts, :on_cancel, :continue) do
:continue -> nil
:stop -> %{}
end

{:producer, {stack, continuation, on_cancel}, Keyword.take(opts, [:dispatcher, :demand])}
end

def handle_subscribe(:consumer, _opts, {pid, ref}, {stack, continuation, on_cancel}) do
if on_cancel do
{:automatic, {stack, continuation, Map.put(on_cancel, ref, pid)}}
else
{:automatic, {stack, continuation, on_cancel}}
end
end

def handle_cancel(_reason, {_, ref}, {stack, continuation, on_cancel}) do
case on_cancel do
%{^ref => _} when map_size(on_cancel) == 1 ->
{:stop, :normal, {stack, continuation, Map.delete(on_cancel, ref)}}

%{^ref => _} ->
{:noreply, [], {stack, continuation, Map.delete(on_cancel, ref)}}

_ ->
{:noreply, [], {stack, continuation, on_cancel}}
end
end

def handle_demand(_demand, {stack, continuation}) when is_atom(continuation) do
{:noreply, [], {stack, continuation}}
def handle_demand(_demand, {stack, continuation, on_cancel}) when is_atom(continuation) do
{:noreply, [], {stack, continuation, on_cancel}}
end

def handle_demand(demand, {stack, continuation}) when demand > 0 do
def handle_demand(demand, {stack, continuation, on_cancel}) when demand > 0 do
case continuation.({:cont, {[], demand}}) do
{:suspended, {list, 0}, continuation} ->
{:noreply, :lists.reverse(list), {stack, continuation}}
{:noreply, :lists.reverse(list), {stack, continuation, on_cancel}}

{status, {list, _}} ->
GenStage.async_info(self(), :stop)
{:noreply, :lists.reverse(list), {stack, status}}
{:noreply, :lists.reverse(list), {stack, status, on_cancel}}
end
end

def handle_info(:stop, state) do
{:stop, :normal, state}
end

def handle_info(msg, {stack, continuation}) do
def handle_info(msg, {stack, continuation, on_cancel}) do
log =
~c"** Undefined handle_info in ~tp~n** Unhandled message: ~tp~n** Stream started at:~n~ts"

:error_logger.warning_msg(log, [inspect(__MODULE__), msg, Exception.format_stacktrace(stack)])
{:noreply, [], {stack, continuation}}
{:noreply, [], {stack, continuation, on_cancel}}
end
end
15 changes: 15 additions & 0 deletions test/gen_stage_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,21 @@ defmodule GenStageTest do
assert Process.info(producer, :registered_name) ==
{:registered_name, :gen_stage_from_enumerable}
end

test "accepts a :on_cancel option" do
{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]))
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
assert Process.alive?(pid)

{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]), on_cancel: :continue)
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
assert Process.alive?(pid)

{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]), on_cancel: :stop)
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
ref = Process.monitor(pid)
assert_receive {:DOWN, ^ref, _, _, _}
end
end

describe "subscribe_to names" do
Expand Down

0 comments on commit 498b16a

Please sign in to comment.