import argparse
import sys
import traceback
from io import BytesIO
from pathlib import Path
import light_side as ls
import numpy as np
import torch
import uvicorn
from PIL import Image
from starlette.responses import RedirectResponse, StreamingResponse
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
tags_metadata = [
{
"name": "Enhance",
"description": "Enhancer dark images to light images",
"externalDocs": {
"description": "External Docs for Library: ",
"url": "https://light-side.readthedocs.io/",
},
},
]
app = FastAPI(
title="Light Side", swagger_ui_parameters={"defaultModelsExpandDepth": -1}
)
def custom_openapi():
openapi_schema = get_openapi(
title="Light Side API",
version=ls.__version__,
description=ls.__description__,
routes=app.routes,
tags=tags_metadata,
license_info={
"name": ls.__license__,
"url": ls.__license_url__,
},
contact={
"name": ls.__author__,
},
)
openapi_schema["info"]["x-logo"] = {
"url": "https://raw.githubusercontent.com/canturan10/light_side/master/src/light_side.png"
}
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
MODEL_LIST = []
for model in ls.available_models():
for version in ls.get_model_versions(model):
MODEL_LIST.append(f"{model}-{version}")
@app.on_event("startup")
def load_artifacts():
app.state.models = {}
for model in ls.available_models():
for version in ls.get_model_versions(model):
app.state.models[f"{model}-{version}"] = ls.Enhancer.from_pretrained(
model,
version,
)
app.state.models[f"{model}-{version}"].eval()
app.state.models[f"{model}-{version}"].to(
"cuda" if torch.cuda.is_available() else "cpu"
)
@app.on_event("shutdown")
def empty_cache():
# clear Cuda memory
torch.cuda.empty_cache()
def read_imagefile(data) -> Image.Image:
image = Image.open(BytesIO(data))
return image
@app.get("/", include_in_schema=False)
def main():
return RedirectResponse(url="/docs")
@app.post("/enhance/", tags=["Enhance"], summary="Low Light Image Enhancement")
async def enhance(
file: UploadFile = File(...),
model: str = Query(MODEL_LIST[0], enum=MODEL_LIST),
):
if file.content_type.startswith("image/") is False:
raise HTTPException(
status_code=400,
detail=f"File '{file.filename}' is not an image.",
)
try:
content = await file.read()
image = np.array(read_imagefile(content).convert("RGB"))
results = app.state.models[model].predict(image)
enhanced_image = Image.fromarray(results[0]["enhanced"])
buf = BytesIO()
enhanced_image.save(buf, format="PNG")
byte_im = buf.getvalue()
return StreamingResponse(BytesIO(byte_im), media_type="image/png")
except Exception:
sys.stdout.flush()
e_info = sys.exc_info()[1]
traceback.print_exc()
return HTTPException(
status_code=500,
detail=str(e_info),
headers={"Content-Type": "text/plain"},
)
if __name__ == "__main__":
from pathlib import Path
parser = argparse.ArgumentParser(description="Runs the API server.")
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to run the API on.",
)
parser.add_argument(
"--port",
help="The port to listen for requests on.",
type=int,
default=8500,
)
parser.add_argument(
"--workers",
help="Number of workers to use.",
type=int,
default=1,
)
parser.add_argument(
"--reload",
help="Reload the model on each request.",
action="store_true",
)
parser.add_argument(
"--use-colors",
help="Enable user-friendly color output.",
action="store_true",
)
args = parser.parse_args()
uvicorn.run(
f"{Path(__file__).stem}:app",
host=args.host,
port=args.port,
workers=args.workers,
reload=args.reload,
use_colors=args.use_colors,
)