parent
f1d8b043ea
commit
028f10f0ec
6 changed files with 297 additions and 41 deletions
@ -0,0 +1,112 @@ |
|||||||
|
import sys |
||||||
|
import os |
||||||
|
|
||||||
|
# Put the directory two levels above this file on the import path so the
# `utils` package imported below can be resolved.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
||||||
|
|
||||||
|
from utils.wrapper import StreamDiffusionWrapper |
||||||
|
|
||||||
|
import torch |
||||||
|
|
||||||
|
# from config import Args |
||||||
|
from pydantic import BaseModel, Field |
||||||
|
from PIL import Image |
||||||
|
import math |
||||||
|
|
||||||
|
# Local model directories. The Hugging Face hub ids previously used here were
# "stabilityai/sd-turbo" (base model) and "madebyollin/taesd" (tiny VAE).
base_model = "./models/sd-turbo"
taesd_model = "./models/taesd"

# Prompt used until a caller supplies its own.
default_prompt = (
    "Portrait of The Joker halloween costume, face painting, with , glare pose, "
    "detailed, intricate, full of colour, cinematic lighting, "
    "trending on artstation, 8k, hyperrealistic, focused, extreme details, "
    "unreal engine 5 cinematic, masterpiece"
)
# Negative prompt handed to the stream when it is prepared.
default_negative_prompt = (
    "black and white, blurry, low resolution, pixelated, "
    "pixel art, low quality, low fidelity"
)

# HTML blurb exposed through Pipeline.Info.page_content.
page_content = """<h1 class="text-3xl font-bold">StreamDiffusion</h1>
<h3 class="text-xl font-bold">Image-to-Image SD-Turbo</h3>
<p class="text-sm">
This demo showcases
<a
href="https://github.com/cumulo-autumn/StreamDiffusion"
target="_blank"
class="text-blue-500 underline hover:no-underline">StreamDiffusion
</a>
Image to Image pipeline using
<a
href="https://huggingface.co/stabilityai/sd-turbo"
target="_blank"
class="text-blue-500 underline hover:no-underline">SD-Turbo</a
> with a MJPEG stream server.
</p>
"""
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
    """Image-to-image StreamDiffusion pipeline configured for SD-Turbo.

    Constructing an instance builds a ``StreamDiffusionWrapper`` in
    ``img2img`` mode and prepares it with the module's default prompts;
    ``predict`` then pushes one image at a time through that stream.
    """

    class Info(BaseModel):
        """Descriptive fields for this pipeline: name, input mode, HTML blurb."""

        name: str = "StreamDiffusion img2img"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        """Parameters accepted by ``predict``; every field has a default."""

        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        # No negative_prompt field is exposed here; __init__ passes
        # default_negative_prompt to the stream's prepare() call instead.

        # NOTE(review): the 512 defaults lie outside the min=2 / max=15 hints
        # below — confirm what consumes these extra Field keywords.
        width: int = Field(
            512,
            min=2,
            max=15,
            title="Width",
            disabled=True,
            hide=True,
            id="width",
        )
        height: int = Field(
            512,
            min=2,
            max=15,
            title="Height",
            disabled=True,
            hide=True,
            id="height",
        )

    def __init__(self, device: torch.device, torch_dtype: torch.dtype):
        """Create the wrapped stream and prepare it with the default prompts.

        Args:
            device: Torch device forwarded to ``StreamDiffusionWrapper``.
            torch_dtype: Torch dtype forwarded to ``StreamDiffusionWrapper``.
        """
        # Width and height come from the InputParams defaults.
        defaults = self.InputParams()
        wrapper_options = dict(
            model_id_or_path=base_model,
            use_tiny_vae=True,
            device=device,
            dtype=torch_dtype,
            t_index_list=[35, 45],
            frame_buffer_size=1,
            width=defaults.width,
            height=defaults.height,
            use_lcm_lora=False,
            output_type="pil",
            warmup=10,
            vae_id=taesd_model,
            acceleration="xformers",
            mode="img2img",
            use_denoising_batch=True,
            cfg_type="none",
            enable_similar_image_filter=True,
            similar_image_filter_threshold=0.98,
        )
        self.stream = StreamDiffusionWrapper(**wrapper_options)

        self.last_prompt = default_prompt
        self.stream.prepare(
            prompt=default_prompt,
            negative_prompt=default_negative_prompt,
            num_inference_steps=50,
            guidance_scale=1.2,
        )

    def predict(self, image: Image.Image, params: "Pipeline.InputParams") -> Image.Image:
        """Run ``image`` through the stream with the prompt from ``params``.

        Args:
            image: Input frame; it is passed to the stream's
                ``preprocess_image`` before inference.
            params: Request parameters; only ``params.prompt`` is read.

        Returns:
            Whatever the wrapped stream returns for this frame (the wrapper
            is constructed with ``output_type="pil"``).
        """
        prepared = self.stream.preprocess_image(image)
        return self.stream(image=prepared, prompt=params.prompt)
||||||
@ -0,0 +1,120 @@ |
|||||||
|
from fastapi import FastAPI |
||||||
|
from pydantic import BaseModel |
||||||
|
import datetime |
||||||
|
import torch |
||||||
|
from PIL import Image |
||||||
|
import numpy as np |
||||||
|
import SpoutGL |
||||||
|
from OpenGL.GL import GL_RGBA |
||||||
|
import time |
||||||
|
import img2img |
||||||
|
|
||||||
|
def main():
    """Bridge a Spout video feed through the StreamDiffusion img2img pipeline.

    Steps:
      1. Build ``img2img.Pipeline`` (stored in the module global ``pipeline``);
         if construction raises, the error is printed and the function returns.
      2. Start a FastAPI app on ``0.0.0.0:34800`` in a daemon thread, exposing
         ``GET /health`` and ``POST /api/update/prompt``.
      3. Loop forever: receive a frame from the Spout sender named
         ``SPOUT_RECEIVER_NAME``, run it through the pipeline with the current
         prompt, and publish the result as ``SPOUT_SENDER_NAME``.

    Ctrl+C leaves the loop; the Spout receiver and sender are released in a
    ``finally`` clause either way.
    """
    global pipeline

    import threading

    import uvicorn

    TARGET_FPS = 60
    SPOUT_RECEIVER_NAME = "Spout DX11 Sender"
    SPOUT_SENDER_NAME = "Output - StreamDiffusion"
    WIDTH = 512
    HEIGHT = 512
    PROMPT = "a beautiful landscape painting, trending on artstation, 8k, hyperrealistic"
    # Weight kept from the previous FPS estimate on every update.
    FPS_SMOOTHING = 0.05

    # Prompt fed to the pipeline. The HTTP handler below rebinds it via
    # `nonlocal`; a `global` declaration there would write a module-level name
    # that this function's loop never reads, so updates would be lost.
    current_prompt = PROMPT

    print("Initializing StreamDiffusion pipeline...")
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): float16 is used even on the CPU fallback — confirm the
        # pipeline supports that combination.
        torch_dtype = torch.float16
        pipeline = img2img.Pipeline(device, torch_dtype)
    except Exception as e:
        print(f"Error initializing StreamDiffusion pipeline: {e}")
        return
    print("Pipeline initialized.")

    app = FastAPI()

    @app.get("/health")
    def read_root():
        return {"status": "ok"}

    class PromptUpdate(BaseModel):
        prompt: str

    @app.post("/api/update/prompt")
    async def update_prompt(update: PromptUpdate):
        nonlocal current_prompt
        current_prompt = update.prompt
        print(f"Prompt updated to: {current_prompt}")
        return {"message": "Prompt updated successfully", "new_prompt": current_prompt}

    print(f"Initializing Spout receiver for '{SPOUT_RECEIVER_NAME}'...")
    spout_receiver = SpoutGL.SpoutReceiver()
    spout_receiver.setReceiverName(SPOUT_RECEIVER_NAME)

    print(f"Initializing Spout sender as '{SPOUT_SENDER_NAME}'...")
    spout_sender = SpoutGL.SpoutSender()
    spout_sender.setSenderName(SPOUT_SENDER_NAME)

    # Receive buffer, reused for every frame.
    image_bgra = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)

    config = uvicorn.Config(app, host="0.0.0.0", port=34800, log_level="info")
    server = uvicorn.Server(config)
    threading.Thread(target=server.run, daemon=True).start()
    print("FastAPI server started at http://0.0.0.0:34800")

    fps = 30.0
    # Start timing here so model loading does not count as the first frame.
    last_frame_time = time.perf_counter()

    try:
        print("Starting main loop. Press Ctrl+C to exit.")
        while True:
            received = spout_receiver.receiveImage(image_bgra, GL_RGBA, False, 0)

            if not received:
                # Nothing to process; wait one target-frame interval.
                time.sleep(1.0 / TARGET_FPS)
                continue

            # NOTE(review): the frame is skipped when the receiver reports an
            # update, but the buffer stays WIDTH x HEIGHT — confirm the sender
            # always matches that size.
            if spout_receiver.isUpdated():
                continue

            if spout_receiver.isConnected() and SpoutGL.helpers.isBufferEmpty(image_bgra):
                continue

            # NOTE(review): the buffer is requested as GL_RGBA yet its first
            # three channels are reversed as if it were BGRA (and reversed
            # again on output) — confirm the intended channel order.
            image_rgb_array = image_bgra[:, :, [2, 1, 0]]
            input_image = Image.fromarray(image_rgb_array, 'RGB')

            params = img2img.Pipeline.InputParams(prompt=current_prompt)
            output_image = pipeline.predict(image=input_image, params=params)

            # Reverse the colour channels and add an opaque alpha channel.
            output_bgr_array = np.array(output_image, dtype=np.uint8)[:, :, ::-1]
            buffer = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
            buffer[:, :, :3] = output_bgr_array
            buffer[:, :, 3] = 255

            spout_sender.sendImage(buffer, WIDTH, HEIGHT, GL_RGBA, False, 0)

            # Exponentially smoothed frame rate; skip the update if the clock
            # did not advance, to avoid dividing by zero.
            now = time.perf_counter()
            dt = now - last_frame_time
            last_frame_time = now
            if dt > 0:
                fps = fps * FPS_SMOOTHING + (1 / dt) * (1 - FPS_SMOOTHING)

            print(
                "\033[92m[ STREAM DIFFUSION ]\033[0m "
                + f"Frame processed and sent to Spout: {fps:.2f}",
                end="\r",
                flush=True,
            )
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing Spout resources.")
        spout_receiver.releaseReceiver()
        spout_sender.releaseSender()


if __name__ == "__main__":
    main()
||||||
|
|
||||||
Loading…
Reference in new issue