-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
111 lines (91 loc) · 3.29 KB
/
api.py
File metadata and controls
111 lines (91 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import shutil
import tempfile
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, Response

from model import HDRUNet
# Application object; the startup hook and the endpoint handlers in this file
# register themselves on it through its decorators.
app = FastAPI()
# Initialize model
def load_model(model_path='./pretrained_model.pth'):
    """Build an HDRUNet, load its weights from ``model_path`` and return it.

    Args:
        model_path: Path to a checkpoint file containing a state dict whose
            keys match HDRUNet exactly (loading uses ``strict=True``).

    Returns:
        The HDRUNet instance, switched to evaluation mode. The weights are
        loaded onto the CPU; the caller decides whether to move the model to
        another device.
    """
    model = HDRUNet()
    # map_location="cpu" lets a checkpoint saved from a GPU load on hosts
    # without CUDA; the startup hook moves the model to the GPU when present.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
# Module-wide model handle; stays None until the startup hook has run.
model = None


@app.on_event("startup")
async def startup_event():
    """Load the model once at application startup and publish it globally.

    The model is moved to the GPU when CUDA is available; otherwise it is
    left where ``load_model`` put it.
    """
    global model
    network = load_model()
    model = network
    if torch.cuda.is_available():
        model = network.cuda()
def _read_rgb_float(path):
    """Read the image at ``path`` as an RGB float32 array, or None if unreadable.

    8-bit and 16-bit integer images are scaled to [0, 1]; any other dtype is
    only cast to float32 without rescaling.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None
    # IMREAD_UNCHANGED keeps the file's own channel layout, so grayscale and
    # alpha-carrying images need their own conversion to 3-channel RGB.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if img.dtype == np.uint16:
        return img.astype(np.float32) / 65535.0
    if img.dtype == np.uint8:
        return img.astype(np.float32) / 255.0
    return img.astype(np.float32)


@app.post("/process_hdr/")
async def process_hdr(files: List[UploadFile] = File(...)):
    """
    Process multiple LDR images to create an HDR image.

    Args:
        files: Uploaded LDR images. All of them must decode to the same
            height and width so they can be stacked into one batch.

    Returns:
        On success, a response carrying the Radiance ``result.hdr`` file as
        an attachment. On any failure, a dict of the form
        ``{"error": "<message>"}``.
    """
    if not files:
        return {"error": "No files provided"}
    if model is None:
        return {"error": "Model is not loaded"}
    try:
        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)
            images = []
            for index, file in enumerate(files):
                # Never use the client-supplied name as a path: it may hold
                # directory components ("../x") or repeat across uploads.
                # Only its extension is kept, under an index-based name.
                suffix = Path(file.filename or "").suffix
                file_path = temp_dir / f"input_{index}{suffix}"
                with file_path.open("wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                img = _read_rgb_float(str(file_path))
                if img is None:
                    return {"error": f"Failed to read image: {file.filename}"}
                images.append(img)

            if len({img.shape for img in images}) > 1:
                return {"error": "All images must have the same dimensions"}

            # Stack images for batch processing: (N, H, W, C) -> (N, C, H, W)
            input_tensor = torch.from_numpy(np.stack(images)).permute(0, 3, 1, 2).float()
            if torch.cuda.is_available():
                input_tensor = input_tensor.cuda()

            # Process through model
            with torch.no_grad():
                output = model(input_tensor)

            # NOTE(review): only the first item of the batch is returned; the
            # outputs for any additional uploads are discarded. Confirm this
            # is the intended behavior.
            output = output.cpu().numpy()
            output = np.transpose(output[0], (1, 2, 0))
            # The transpose yields a non-contiguous view; hand OpenCV a
            # contiguous float32 array for the HDR writer.
            output = np.ascontiguousarray(output, dtype=np.float32)

            # Save result
            result_path = temp_dir / "result.hdr"
            # imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(str(result_path), cv2.cvtColor(output, cv2.COLOR_RGB2BGR)):
                return {"error": "Failed to write HDR result"}
            # Read the file back while the temporary directory still exists.
            # A FileResponse would only open it after this handler returns,
            # by which time the directory has already been removed.
            payload = result_path.read_bytes()

        # Return the HDR file
        return Response(
            content=payload,
            media_type="image/vnd.radiance",
            headers={"Content-Disposition": 'attachment; filename="result.hdr"'},
        )
    except Exception as e:
        return {"error": str(e)}
@app.get("/health")
async def health_check():
    """
    Report that the service is up and whether the model has been loaded.
    """
    is_loaded = model is not None
    return {"status": "healthy", "model_loaded": is_loaded}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)