Add Docker env and web demo #161

Open · wants to merge 1 commit into master
4 changes: 3 additions & 1 deletion .gitignore
@@ -15,4 +15,6 @@
!video/moon_straight-line.mp4
!video/moon.mp4
!video/moon_circle.mp4
!misc/moon_40.gif
!cog.yaml
!predict.py
1 change: 1 addition & 0 deletions BoostingMonocularDepth
Submodule BoostingMonocularDepth added at ecedd0
4 changes: 4 additions & 0 deletions README.md
@@ -1,5 +1,9 @@
# [CVPR 2020] 3D Photography using Context-aware Layered Depth Inpainting

[Demo and Docker image on Replicate](https://replicate.com/vt-vl-lab/3d-photo-inpainting)

<a href="https://replicate.com/vt-vl-lab/3d-photo-inpainting"><img src="https://replicate.com/vt-vl-lab/3d-photo-inpainting/badge"></a>

[![Open 3DPhotoInpainting in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1706ToQrkIZshRSJSHvZ1RuCiM__YX3Bz)

### [[Paper](https://arxiv.org/abs/2004.04727)] [[Project Website](https://shihmengli.github.io/3D-Photo-Inpainting/)] [[Google Colab](https://colab.research.google.com/drive/1706ToQrkIZshRSJSHvZ1RuCiM__YX3Bz)]
34 changes: 34 additions & 0 deletions cog.yaml
@@ -0,0 +1,34 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "11.4"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.8"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "ipython==7.33.0"
    - "vispy==0.6.4"
    - "moviepy==1.0.2"
    - "transforms3d==0.3.1"
    - "networkx==2.3"
    - "pyyaml==5.4.1"
    - "torch==1.9.0"
    - "torchvision==0.10.0"

  # commands run after the environment is set up
  run:
    - "apt-get update && apt-get install -y sed"
    - "apt-get install -y mesa-utils-extra libegl1-mesa-dev libgles2-mesa-dev xvfb"
    - "apt-get install -y libsm6 libxrender1"
    - "pip install scipy matplotlib scikit-image"
    - "pip install pyqt5 pyopengl"
    - "pip install opencv-python-headless"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
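
For context (not part of the diff itself): with this cog.yaml and predict.py at the repository root, the image would typically be built and exercised with the Cog CLI along these lines. The -i names mirror the Input fields declared in predict.py below; the input file name is illustrative.

    cog build
    cog predict -i image_path=@my_photo.jpg -i effect=zoom-in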
236 changes: 236 additions & 0 deletions predict.py
@@ -0,0 +1,236 @@
''' Cog interface for 3D photo inpainting'''
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import argparse
import copy
import glob
import os
import shlex
import subprocess
import sys
import time
from functools import partial

import cv2
import imageio
import numpy as np
import scipy.misc as misc
import torch
import vispy
import yaml
from cog import BasePredictor, Input, Path
from skimage.transform import resize
from tqdm import tqdm
from PIL import Image
from torchvision.transforms.functional import to_pil_image

import MiDaS.MiDaS_utils as MiDaS_utils
from bilateral_filtering import sparse_bilateral_filtering
from boostmonodepth_utils import run_boostmonodepth
from mesh import output_3d_photo, read_ply, write_ply
from MiDaS.monodepth_net import MonoDepthNet
from MiDaS.run import run_depth
from networks import Inpaint_Color_Net, Inpaint_Depth_Net, Inpaint_Edge_Net
from utils import get_MiDaS_samples, read_MiDaS_depth

# start a virtual framebuffer at import time so vispy has a display to render into
cmd = "Xvfb :0 -screen 0 1024x768x24 -ac +extension GLX +render -noreset"
subprocess.Popen(shlex.split(cmd))

os.environ["DISPLAY"] = ":0"

from mesh import Canvas_view


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        pass

    def predict(
        self,
        image_path: Path = Input(description="Input image"),
        effect: str = Input(
            description="Video animation effect", choices=["dolly-zoom-in", "zoom-in", "circle", "swing"]
        ),
    ) -> Path:
        # set configs
        config = yaml.load(open("argument.yml", "r"), Loader=yaml.FullLoader)
        # set for headless rendering
        config["offscreen_rendering"] = True
        config["src_folder"] = "input"
        os.makedirs(config["mesh_folder"], exist_ok=True)
        os.makedirs(config["video_folder"], exist_ok=True)
        os.makedirs(config["depth_folder"], exist_ok=True)

        if config["offscreen_rendering"] is True:
            vispy.use(app="egl")
        config["video_postfix"] = [effect]

        # save the input image as input/image.jpg, where the pipeline expects it
        im = Image.open(str(image_path)).convert("RGB")
        print("Saving input image to input/image.jpg...")
        im.save("input/image.jpg")

        # select trajectory type and shift range based on the chosen video effect
        traj_types_dict = {
            "dolly-zoom-in": "double-straight-line",
            "zoom-in": "double-straight-line",
            "circle": "circle",
            "swing": "circle",
        }
        shift_range_dict = {
            "dolly-zoom-in": [[0.00], [0.00], [-0.05]],
            "zoom-in": [[0.00], [0.00], [-0.05]],
            "circle": [[-0.015], [-0.015], [-0.05]],
            "swing": [[-0.015], [-0.00], [-0.05]],
        }
        config["traj_types"] = [traj_types_dict[effect]]
        config["x_shift_range"], config["y_shift_range"], config["z_shift_range"] = shift_range_dict[effect]

        sample_list = get_MiDaS_samples(config["src_folder"], config["depth_folder"], config, config["specific"])
        normal_canvas, all_canvas = None, None

        if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
            device = config["gpu_ids"]
        else:
            device = "cpu"

        print(f"running on device {device}")

        for idx in tqdm(range(len(sample_list))):
            depth = None
            sample = sample_list[idx]
            print("Current Source ==> ", sample["src_pair_name"])
            mesh_fi = os.path.join(config["mesh_folder"], sample["src_pair_name"] + ".ply")
            image = imageio.imread(sample["ref_img_fi"])

            print(f"Running depth extraction at {time.time()}")
            if config["use_boostmonodepth"] is True:
                run_boostmonodepth(sample["ref_img_fi"], config["src_folder"], config["depth_folder"])
            elif config["require_midas"] is True:
                run_depth(
                    [sample["ref_img_fi"]],
                    config["src_folder"],
                    config["depth_folder"],
                    config["MiDaS_model_ckpt"],
                    MonoDepthNet,
                    MiDaS_utils,
                    target_w=640,
                )

            if "npy" in config["depth_format"]:
                config["output_h"], config["output_w"] = np.load(sample["depth_fi"]).shape[:2]
            else:
                config["output_h"], config["output_w"] = imageio.imread(sample["depth_fi"]).shape[:2]
            frac = config["longer_side_len"] / max(config["output_h"], config["output_w"])
            config["output_h"], config["output_w"] = int(config["output_h"] * frac), int(config["output_w"] * frac)
            config["original_h"], config["original_w"] = config["output_h"], config["output_w"]
            if image.ndim == 2:
                image = image[..., None].repeat(3, -1)
            if (
                np.sum(np.abs(image[..., 0] - image[..., 1])) == 0
                and np.sum(np.abs(image[..., 1] - image[..., 2])) == 0
            ):
                config["gray_image"] = True
            else:
                config["gray_image"] = False
            image = cv2.resize(image, (config["output_w"], config["output_h"]), interpolation=cv2.INTER_AREA)
            depth = read_MiDaS_depth(sample["depth_fi"], 3.0, config["output_h"], config["output_w"])
            mean_loc_depth = depth[depth.shape[0] // 2, depth.shape[1] // 2]
            if not (config["load_ply"] is True and os.path.exists(mesh_fi)):
                vis_photos, vis_depths = sparse_bilateral_filtering(
                    depth.copy(), image.copy(), config, num_iter=config["sparse_iter"], spdb=False
                )
                depth = vis_depths[-1]
                model = None
                torch.cuda.empty_cache()
                print("Start Running 3D_Photo ...")
                print(f"Loading edge model at {time.time()}")
                depth_edge_model = Inpaint_Edge_Net(init_weights=True)
                depth_edge_weight = torch.load(config["depth_edge_model_ckpt"], map_location=torch.device(device))
                depth_edge_model.load_state_dict(depth_edge_weight)
                depth_edge_model = depth_edge_model.to(device)
                depth_edge_model.eval()

                print(f"Loading depth model at {time.time()}")
                depth_feat_model = Inpaint_Depth_Net()
                depth_feat_weight = torch.load(config["depth_feat_model_ckpt"], map_location=torch.device(device))
                depth_feat_model.load_state_dict(depth_feat_weight, strict=True)
                depth_feat_model = depth_feat_model.to(device)
                depth_feat_model.eval()

                print(f"Loading rgb model at {time.time()}")
                rgb_model = Inpaint_Color_Net()
                rgb_feat_weight = torch.load(config["rgb_feat_model_ckpt"], map_location=torch.device(device))
                rgb_model.load_state_dict(rgb_feat_weight)
                rgb_model.eval()
                rgb_model = rgb_model.to(device)
                graph = None

                print(f"Writing depth ply (and basically doing everything) at {time.time()}")
                # the edge model is intentionally passed twice (as the edge model and
                # as its initializer), following the upstream write_ply call
                rt_info = write_ply(
                    image,
                    depth,
                    sample["int_mtx"],
                    mesh_fi,
                    config,
                    rgb_model,
                    depth_edge_model,
                    depth_edge_model,
                    depth_feat_model,
                )

                if rt_info is False:
                    continue
                rgb_model = None
                color_feat_model = None
                depth_edge_model = None
                depth_feat_model = None
                torch.cuda.empty_cache()
            if config["save_ply"] is True or config["load_ply"] is True:
                verts, colors, faces, Height, Width, hFov, vFov = read_ply(mesh_fi)
            else:
                verts, colors, faces, Height, Width, hFov, vFov = rt_info

            print(f"Making video at {time.time()}")
            videos_poses, video_basename = copy.deepcopy(sample["tgts_poses"]), sample["tgt_name"]
            top = config.get("original_h") // 2 - sample["int_mtx"][1, 2] * config["output_h"]
            left = config.get("original_w") // 2 - sample["int_mtx"][0, 2] * config["output_w"]
            down, right = top + config["output_h"], left + config["output_w"]
            border = [int(xx) for xx in [top, down, left, right]]

            output_path = os.path.join(config["video_folder"], video_basename[0] + "_" + effect + ".mp4")
            normal_canvas, all_canvas = output_3d_photo(
                verts.copy(),
                colors.copy(),
                faces.copy(),
                copy.deepcopy(Height),
                copy.deepcopy(Width),
                copy.deepcopy(hFov),
                copy.deepcopy(vFov),
                copy.deepcopy(sample["tgt_pose"]),
                sample["video_postfix"],
                copy.deepcopy(sample["ref_pose"]),
                copy.deepcopy(config["video_folder"]),
                image.copy(),
                copy.deepcopy(sample["int_mtx"]),
                config,
                image,
                videos_poses,
                video_basename,
                config.get("original_h"),
                config.get("original_w"),
                border=border,
                depth=depth,
                normal_canvas=normal_canvas,
                all_canvas=all_canvas,
                mean_loc_depth=mean_loc_depth,
            )

        print(f"Done. Saving to output path: {output_path}")
        return Path(str(output_path))
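
For reviewers who want to exercise the class outside of Cog, a minimal sketch (not part of the PR), assuming the repository root is the working directory, the checkpoints referenced by argument.yml have been downloaded, and Xvfb can start (the module launches it at import time). The file name is illustrative, and plain values stand in for the cog Input descriptors:

    # hypothetical local smoke test
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # currently a no-op; models are loaded inside predict()
    video = predictor.predict(image_path="my_photo.jpg", effect="zoom-in")
    print(f"rendered video at {video}")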
2 changes: 1 addition & 1 deletion utils.py
@@ -884,7 +884,7 @@ def get_MiDaS_samples(image_folder, depth_folder, config, specific=None, aft_cer
        sdict['src_pair_name'] = sdict['tgt_name'][0]

    return samples

def get_valid_size(imap):
    x_max = np.where(imap.sum(1).squeeze() > 0)[0].max() + 1
    x_min = np.where(imap.sum(1).squeeze() > 0)[0].min()