Skip to content

Commit

Permalink
[examples] fix multiprocessing on Linux
Browse files Browse the repository at this point in the history
- use multiprocessing context to specify the spawn start method, which fixes the "RuntimeError: Cannot re-initialize CUDA in forked subprocess" on Linux (verified with Ubuntu 22.04 and kernel 6.2.0-39)
- call `.join()` to wait for processes to complete (avoids exiting program immediately)
  • Loading branch information
GradientSurfer committed Dec 24, 2023
1 parent 03e2a7f commit a06fbd1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
13 changes: 8 additions & 5 deletions examples/optimal-performance/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import time
import tkinter as tk
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import List, Literal

import fire
Expand Down Expand Up @@ -174,17 +174,20 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(queue, fps_queue, prompt, model_id_or_path, batch_size, acceleration),
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)
13 changes: 8 additions & 5 deletions examples/optimal-performance/single.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys
import time
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import Literal

import fire
Expand Down Expand Up @@ -72,17 +72,20 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(queue, fps_queue, prompt, model_id_or_path, acceleration),
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)
14 changes: 8 additions & 6 deletions examples/screen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import time
import threading
from multiprocessing import Process, Queue
from multiprocessing import Process, Queue, get_context
from typing import List, Literal, Dict, Optional
import torch
import PIL.Image
Expand Down Expand Up @@ -216,10 +216,10 @@ def main(
Main function to start the image generation and viewer processes.
"""
monitor = dummy_screen(width, height)

queue = Queue()
fps_queue = Queue()
process1 = Process(
ctx = get_context('spawn')
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
target=image_generation_process,
args=(
queue,
Expand All @@ -246,9 +246,11 @@ def main(
)
process1.start()

process2 = Process(target=receive_images, args=(queue, fps_queue))
process2 = ctx.Process(target=receive_images, args=(queue, fps_queue))
process2.start()

process1.join()
process2.join()

if __name__ == "__main__":
fire.Fire(main)

0 comments on commit a06fbd1

Please sign in to comment.