You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
4.3 KiB
120 lines
4.3 KiB
from fastapi import FastAPI
|
|
from pydantic import BaseModel
|
|
import datetime
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
import SpoutGL
|
|
from OpenGL.GL import GL_RGBA
|
|
import time
|
|
import img2img
|
|
|
|
def main():
    """Bridge a Spout video source through StreamDiffusion img2img and back out to Spout.

    Frames are read from the Spout sender named ``SPOUT_RECEIVER_NAME``, run
    through ``img2img.Pipeline.predict`` with the current prompt, and published
    through a Spout sender named ``SPOUT_SENDER_NAME``.

    A FastAPI app is served by uvicorn on 0.0.0.0:34800 in a daemon thread:

    * ``GET /health`` returns ``{"status": "ok"}``.
    * ``POST /api/update/prompt`` with ``{"prompt": "..."}`` replaces the
      prompt used for subsequent frames.

    If the pipeline cannot be constructed, the error is printed and the
    function returns. Otherwise it loops until Ctrl+C, then releases the Spout
    receiver and sender.
    """
    # Module-level global kept for compatibility with the original script.
    global pipeline

    TARGET_FPS = 60  # polling rate while no new frame has been received
    FPS_SMOOTHING = 0.05  # weight of the newest sample in the FPS moving average
    SPOUT_RECEIVER_NAME = "Spout DX11 Sender"
    SPOUT_SENDER_NAME = "Output - StreamDiffusion"
    WIDTH = 512
    HEIGHT = 512
    PROMPT = "a beautiful landscape painting, trending on artstation, 8k, hyperrealistic"

    fps = 30.0  # initial value of the smoothed FPS estimate

    print("Initializing StreamDiffusion pipeline...")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): float16 is also used on the CPU fallback; many CPU ops
        # do not support half precision — confirm img2img.Pipeline handles this.
        torch_dtype = torch.float16
        pipeline = img2img.Pipeline(device, torch_dtype)
    except Exception as e:
        print(f"Error initializing StreamDiffusion pipeline: {e}")
        return

    app = FastAPI()

    @app.get("/health")
    def read_root():
        return {"status": "ok"}

    class PromptUpdate(BaseModel):
        prompt: str

    @app.post("/api/update/prompt")
    async def update_prompt(update: PromptUpdate):
        # `nonlocal`, not `global`: PROMPT is a local of main() and the frame
        # loop below reads that local. A `global` declaration would write a
        # separate module-level name that the loop never sees.
        nonlocal PROMPT
        PROMPT = update.prompt
        print(f"Prompt updated to: {PROMPT}")
        return {"message": "Prompt updated successfully", "new_prompt": PROMPT}

    print("Pipeline initialized.")

    print(f"Initializing Spout receiver for '{SPOUT_RECEIVER_NAME}'...")
    spout_receiver = SpoutGL.SpoutReceiver()
    spout_receiver.setReceiverName(SPOUT_RECEIVER_NAME)

    print(f"Initializing Spout sender as '{SPOUT_SENDER_NAME}'...")
    spout_sender = SpoutGL.SpoutSender()
    spout_sender.setSenderName(SPOUT_SENDER_NAME)

    # Receive buffer, filled in place by receiveImage each iteration.
    image_bgra = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
    # Send buffer, allocated once: colour channels are overwritten per frame
    # and the alpha channel stays fully opaque.
    output_bgra_array = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
    output_bgra_array[:, :, 3] = 255

    import threading

    import uvicorn

    config = uvicorn.Config(app, host="0.0.0.0", port=34800, log_level="info")
    server = uvicorn.Server(config)
    threading.Thread(target=server.run, daemon=True).start()
    print("FastAPI server started at http://0.0.0.0:34800")

    # Monotonic clock, reset here so pipeline/server start-up time is not
    # counted as the first frame's duration.
    last_frame_time = time.perf_counter()
    try:
        print("Starting main loop. Press Ctrl+C to exit.")
        while True:
            received = spout_receiver.receiveImage(image_bgra, GL_RGBA, False, 0)
            if not received:
                time.sleep(1.0 / TARGET_FPS)
                continue

            # NOTE(review): when the sender changes, this frame is skipped but
            # the buffer is not resized; a sender whose size differs from
            # WIDTH x HEIGHT is not handled here.
            if spout_receiver.isUpdated():
                continue
            if spout_receiver.isConnected() and SpoutGL.helpers.isBufferEmpty(image_bgra):
                continue

            # NOTE(review): the buffer is requested as GL_RGBA but treated as
            # BGRA. Channels are swapped here and swapped back on output;
            # confirm against the actual sender's channel order.
            image_rgb_array = image_bgra[:, :, [2, 1, 0]]
            input_image = Image.fromarray(image_rgb_array, "RGB")

            params = img2img.Pipeline.InputParams(prompt=PROMPT)
            output_image = pipeline.predict(image=input_image, params=params)

            # RGB -> BGR into the preallocated send buffer (alpha already 255).
            output_bgra_array[:, :, :3] = np.asarray(output_image, dtype=np.uint8)[:, :, ::-1]
            spout_sender.sendImage(output_bgra_array, WIDTH, HEIGHT, GL_RGBA, False, 0)

            # Exponential moving average of the frame rate.
            now = time.perf_counter()
            dt = now - last_frame_time
            last_frame_time = now
            if dt > 0:
                fps = fps * (1 - FPS_SMOOTHING) + (1.0 / dt) * FPS_SMOOTHING

            print(
                "\033[92m[ STREAM DIFFUSION ]\033[0m "
                + f"Frame processed and sent to Spout: {fps:.2f}",
                end="\r",
                flush=True,
            )
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing Spout resources.")
        # Ask uvicorn to stop; the daemon thread is otherwise only killed at exit.
        server.should_exit = True
        spout_receiver.releaseReceiver()
        spout_sender.releaseSender()
|
|
if __name__ == "__main__":
    # Start the Spout/StreamDiffusion bridge only when run as a script,
    # not when this module is imported.
    main()
|
|
|
|
|