You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

134 lines
4.9 KiB

5 months ago
from fastapi import FastAPI
from pydantic import BaseModel
import datetime
import torch
from PIL import Image
import numpy as np
import SpoutGL
5 months ago
from OpenGL.GL import GL_RGBA, GL_BGRA
5 months ago
import time
import img2img
5 months ago
from multiprocessing import Queue
5 months ago
def _frame_to_input_image(frame_rgba):
    """Convert the Spout receive buffer into the PIL image fed to the pipeline.

    ``frame_rgba`` is the ``(height, width, 4)`` uint8 buffer that
    ``receiveImage(..., GL_RGBA, ...)`` fills, so its channels are R, G, B, A.
    The original code took channels ``[2, 1, 0]``, which on an RGBA buffer
    swaps red and blue; the first three channels are used as-is instead.
    """
    rgb = frame_rgba[:, :, :3]
    # NOTE(review): kept from the original code: maps each channel v to
    # (v + 1) / 2, i.e. roughly halves brightness. It looks like a leftover
    # [-1, 1] -> [0, 1] normalisation; confirm whether it is wanted at all.
    # The result is cast back to uint8: the original handed a float64 array to
    # Image.fromarray with an explicit "RGB" mode, which makes Pillow read the
    # raw float bytes as 8-bit pixels.
    rgb = ((rgb + 1.0) / 2.0).astype(np.uint8)
    # A uint8 (H, W, 3) array is inferred as mode "RGB"; passing `mode`
    # explicitly is deprecated in recent Pillow releases.
    return Image.fromarray(rgb)


def _drain_latest_prompt(prompt_queue, current):
    """Return the newest non-empty prompt waiting in ``prompt_queue``.

    Every queued item is consumed, so a burst of HTTP updates is applied at
    once instead of one prompt per rendered frame. Empty prompts are ignored.
    Returns ``current`` when nothing usable is queued.
    """
    import queue

    latest = current
    while True:
        try:
            candidate = prompt_queue.get_nowait()
        except queue.Empty:
            return latest
        if candidate:
            print(f"Received new prompt from queue: {candidate}")
            latest = candidate


def main():
    """Run the Spout -> StreamDiffusion -> Spout bridge.

    Receives frames from the Spout sender "NoiseSender", runs each through
    ``img2img.Pipeline`` with the current prompt, and publishes the result as
    the Spout sender "StreamDiffusionSender". A FastAPI server on port 34800
    (daemon thread) exposes ``GET /health`` and ``POST /api/update/prompt``
    so the prompt can be changed while the loop runs.

    Runs until Ctrl+C. Returns early (after printing the error) if the
    pipeline or HTTP app fails to initialise. Spout resources are released on
    exit from the render loop.
    """
    import queue
    import threading

    import uvicorn

    TARGET_FPS = 60

    SPOUT_RECEIVER_NAME = "NoiseSender"
    SPOUT_SENDER_NAME = "StreamDiffusionSender"

    WIDTH = 512
    HEIGHT = 512
    DEFAULT_PROMPT = "a beautiful landscape painting, trending on artstation, 8k, hyperrealistic"
    # Weight given to the previous FPS value in the running average below.
    FPS_SMOOTHING = 0.05
    fps = 30.0
    current_prompt = DEFAULT_PROMPT

    # The HTTP handler runs on the uvicorn thread; prompts are handed to the
    # render loop through this thread-safe queue.
    prompt_queue = queue.Queue()

    print("Initializing StreamDiffusion pipeline...")
    global pipeline
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): float16 is used even on CPU; confirm img2img.Pipeline
        # supports half precision there.
        torch_dtype = torch.float16
        pipeline = img2img.Pipeline(device, torch_dtype)
        app = FastAPI()

        @app.get("/health")
        def read_root():
            return {"status": "ok"}

        class PromptUpdate(BaseModel):
            prompt: str

        @app.post("/api/update/prompt")
        async def update_prompt(update: PromptUpdate):
            # NOTE(review): this rebinds a *module-level* PROMPT, not a local
            # of main(); the render loop only sees the change via prompt_queue.
            # Kept so the module-level name is still updated as before.
            global PROMPT
            PROMPT = update.prompt
            prompt_queue.put(PROMPT)
            print(f"Prompt updated to: {PROMPT}")
            return {"message": "Prompt updated successfully", "new_prompt": PROMPT}

        print("Pipeline initialized.")
    except Exception as e:
        print(f"Error initializing StreamDiffusion pipeline: {e}")
        return

    print(f"Initializing Spout receiver for '{SPOUT_RECEIVER_NAME}'...")
    spout_receiver = SpoutGL.SpoutReceiver()
    spout_receiver.setReceiverName(SPOUT_RECEIVER_NAME)
    print(f"Initializing Spout sender as '{SPOUT_SENDER_NAME}'...")
    spout_sender = SpoutGL.SpoutSender()
    spout_sender.setSenderName(SPOUT_SENDER_NAME)

    # receiveImage() below is called with GL_RGBA, so this holds R, G, B, A.
    # NOTE(review): the size is fixed at WIDTH x HEIGHT; a sender of another
    # size (signalled by isUpdated()) is not handled by resizing the buffer.
    frame_rgba = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
    # Reused for every sent frame: RGB is overwritten each time, alpha stays 255.
    output_rgba = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
    output_rgba[:, :, 3] = 255

    config = uvicorn.Config(app, host="0.0.0.0", port=34800, log_level="info")
    server = uvicorn.Server(config)
    threading.Thread(target=server.run, daemon=True).start()
    print("FastAPI server started at http://0.0.0.0:34800")

    try:
        print("Starting main loop. Press Ctrl+C to exit.")
        # Monotonic clock for frame timing; wall-clock time can jump.
        last_frame_time = time.perf_counter()
        while True:
            received = spout_receiver.receiveImage(frame_rgba, GL_RGBA, False, 0)
            if not received:
                time.sleep(1.0 / TARGET_FPS)
                continue
            if spout_receiver.isUpdated():
                continue
            if spout_receiver.isConnected() and SpoutGL.helpers.isBufferEmpty(frame_rgba):
                # Connected but nothing drawn yet: wait instead of busy-spinning.
                time.sleep(1.0 / TARGET_FPS)
                continue

            input_image = _frame_to_input_image(frame_rgba)
            current_prompt = _drain_latest_prompt(prompt_queue, current_prompt)

            params = img2img.Pipeline.InputParams(prompt=current_prompt)
            output_image = pipeline.predict(image=input_image, params=params)

            # np.asarray(output_image) goes straight into channels 0-2 (sent as
            # GL_RGBA), so the output is assumed to be an RGB image of exactly
            # HEIGHT x WIDTH — TODO confirm against img2img.Pipeline.predict.
            output_rgba[:, :, :3] = np.asarray(output_image)
            spout_sender.sendImage(output_rgba, WIDTH, HEIGHT, GL_RGBA, False, 0)

            now = time.perf_counter()
            dt = now - last_frame_time
            last_frame_time = now
            if dt > 0:
                # NOTE(review): the newest sample gets weight 1 - FPS_SMOOTHING
                # (0.95), so the reading is barely smoothed; swap the weights
                # if a steadier display is wanted.
                fps = fps * FPS_SMOOTHING + (1.0 / dt) * (1.0 - FPS_SMOOTHING)
            print(
                "\033[92m[ STREAM DIFFUSION ]\033[0m "
                + f"Frame processed and sent to Spout: {fps:.2f}",
                end="\r",
                flush=True,
            )
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing Spout resources.")
        spout_receiver.releaseReceiver()
        spout_sender.releaseSender()
# Script entry point: run the Spout/StreamDiffusion bridge only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()