#!/bin/bash
# Trellis2 setup script: installs Python dependencies, patches Trellis2 for
# Turing GPUs and, when present, applies prebuilt SM 7.5 CUDA kernel overrides.
# Every status line it prints carries the "[pre-start]" prefix.
printf '%s\n' "[pre-start] Starting Trellis2 setup..."

# pymeshlab + Qt5 fix
# pip's stderr stays silenced to keep the log quiet, but a failed install is
# reported: the Qt5 library copy that follows depends on this package.
if ! pip install pymeshlab --quiet --root-user-action=ignore 2>/dev/null; then
    echo "[pre-start] WARNING: pip install pymeshlab failed." >&2
fi
# Copy the Qt5 and ICU shared libraries bundled with pymeshlab into a system
# library directory.
#   $1 - directory holding pymeshlab's bundled libraries
#   $2 - destination directory
# Prints the success message only when the libQt5*.so.5 files were actually
# copied; otherwise writes a warning to stderr and returns 1. The ICU
# libraries are optional: their absence is not an error.
copy_pymeshlab_libs() {
    local src_dir=$1 dest_dir=$2
    local lib
    local -a qt_libs=() icu_libs=()

    # Expand the globs by hand so an unmatched pattern is detected instead of
    # being handed to cp as a literal file name.
    for lib in "$src_dir"/libQt5*.so.5; do
        if [ -e "$lib" ]; then
            qt_libs+=("$lib")
        fi
    done
    for lib in "$src_dir"/libicu*.so.*; do
        if [ -e "$lib" ]; then
            icu_libs+=("$lib")
        fi
    done

    if [ "${#qt_libs[@]}" -eq 0 ]; then
        echo "[pre-start] WARNING: no libQt5*.so.5 found in $src_dir; Qt5 libs not copied." >&2
        return 1
    fi
    if ! cp -f "${qt_libs[@]}" "$dest_dir"/; then
        echo "[pre-start] WARNING: copying Qt5 libs to $dest_dir failed." >&2
        return 1
    fi
    if [ "${#icu_libs[@]}" -gt 0 ]; then
        # Best effort, as before: an ICU copy failure only produces a warning.
        cp -f "${icu_libs[@]}" "$dest_dir"/ 2>/dev/null \
            || echo "[pre-start] WARNING: copying ICU libs to $dest_dir failed." >&2
    fi
    echo "[pre-start] Qt5 libs copied from pymeshlab."
}

PMLAB_LIB="/usr/local/lib/python3.12/site-packages/pymeshlab/lib"
# NOTE(review): the SM 7.5 override section of this script uses the lib64
# site-packages tree; fall back to it when the lib tree has no pymeshlab.
# Confirm which tree pip actually installs pymeshlab into on this image.
if [ ! -d "$PMLAB_LIB" ] && [ -d "/usr/local/lib64/python3.12/site-packages/pymeshlab/lib" ]; then
    PMLAB_LIB="/usr/local/lib64/python3.12/site-packages/pymeshlab/lib"
fi
if [ -d "$PMLAB_LIB" ]; then
    copy_pymeshlab_libs "$PMLAB_LIB" /usr/lib64
fi

# Trellis2 patches for Turing GPU (SM 7.5) - bf16 -> fp16
#
# Rewrite the 'bf16' and 'bfloat16' dtype entries of the given Python file so
# they map to torch.float16 instead of torch.bfloat16.
#   $1 - path of the file to patch in place
# Returns 1 (with a warning on stderr) when the file does not exist. Running
# it again on an already patched file changes nothing.
patch_bf16_to_fp16() {
    local target=$1
    if [ ! -f "$target" ]; then
        echo "[pre-start] WARNING: $target not found; bf16 -> fp16 patch skipped." >&2
        return 1
    fi
    # The dot is escaped so only the literal "torch.bfloat16" is rewritten.
    sed -i \
        -e "s/'bf16': torch\.bfloat16/'bf16': torch.float16/" \
        -e "s/'bfloat16': torch\.bfloat16/'bfloat16': torch.float16/" \
        "$target"
}
patch_bf16_to_fp16 /root/ComfyUI/custom_nodes/ComfyUI-Trellis2/trellis2/modules/utils.py

# spconv backend (flex_gemm requires SM 8.0+)
# NOTE(review): the two wheels carry different suffixes (cumm-cu120 vs
# spconv-cu118) - confirm that this pairing is intentional.
pip install \
    --quiet \
    --root-user-action=ignore \
    cumm-cu120 spconv-cu118

# Trellis2 dependencies
# pip's stderr stays silenced for each step, but a non-zero exit status is now
# reported so a missing dependency does not go unnoticed until runtime.
pip install plyfile zstandard --quiet --root-user-action=ignore 2>/dev/null \
    || echo "[pre-start] WARNING: pip install failed: plyfile zstandard" >&2
# Prebuilt Trellis2 wheels; --no-deps keeps pip from touching their dependencies.
pip install --no-deps /root/ComfyUI/custom_nodes/ComfyUI-Trellis2/wheels/Linux/Torch291/*.whl --quiet --root-user-action=ignore 2>/dev/null \
    || echo "[pre-start] WARNING: pip install failed: ComfyUI-Trellis2 Torch291 wheels" >&2
pip install meshlib scipy open3d plotly rembg --quiet --root-user-action=ignore 2>/dev/null \
    || echo "[pre-start] WARNING: pip install failed: meshlib scipy open3d plotly rembg" >&2
# Runs after the installs above so the numpy version constraint is applied last.
pip install "numpy<2.3.0,>=2.2" --quiet --root-user-action=ignore 2>/dev/null \
    || echo "[pre-start] WARNING: pip install failed: numpy<2.3.0,>=2.2" >&2

# OpenCV fix - MUST BE AFTER ALL pip installs (open3d etc. pull in opencv-python)
# Remove every OpenCV distribution first, then install only the headless
# contrib build.
pip uninstall opencv-python opencv-python-headless opencv-contrib-python opencv-contrib-python-headless -y --quiet --root-user-action=ignore 2>/dev/null
# If this install fails, the uninstall above has left no OpenCV package in
# place, so the failure is reported instead of being silently discarded.
if ! pip install opencv-contrib-python-headless --quiet --no-cache-dir --root-user-action=ignore 2>/dev/null; then
    echo "[pre-start] WARNING: pip install opencv-contrib-python-headless failed; OpenCV may be missing." >&2
fi

echo "[pre-start] Dependencies installed."

# =============================================================================
# SM 7.5 (Turing) CUDA kernel overrides for Trellis2
# =============================================================================
# The prebuilt Torch291 wheels from ComfyUI-Trellis2 ship CUDA kernels compiled
# only for SM 8.0+ (Ampere) or SM 12.0 (Blackwell). On Turing GPUs (RTX 6000,
# 2080 Ti, etc. = SM 7.5) this causes "no kernel image available" crashes
# during the Shape Slat decoder phase.
#
# These .so files were rebuilt from source with TORCH_CUDA_ARCH_LIST="7.5".
# The original Python wrapper files from the Torch291 wheels are kept intact —
# only the compiled CUDA .so binaries are replaced.
#
# HOW TO REBUILD (inside the container, e.g. podman exec -it comfyui bash):
#
#   nvdiffrast (provides _nvdiffrast_c.so):
#     git clone https://github.com/NVlabs/nvdiffrast.git /tmp/nvdiffrast_src
#     cd /tmp/nvdiffrast_src
#     TORCH_CUDA_ARCH_LIST="7.5" pip install . --force-reinstall --no-deps --no-build-isolation
#
#   cumesh (provides _C.so + _cubvh.so):
#     git clone --recursive https://github.com/JeffreyXiang/CuMesh.git /tmp/CuMesh
#     cd /tmp/CuMesh
#     # Patch setup.py Linux nvcc flags — add after "-std=c++17":
#     #   "--expt-relaxed-constexpr", "--extended-lambda"
#     # (needed for cubvh submodule lambda expressions)
#     TORCH_CUDA_ARCH_LIST="7.5" pip install . --force-reinstall --no-deps --no-build-isolation
#
#   nvdiffrec_render (provides renderutils/_C.so):
#     git clone https://github.com/NVlabs/nvdiffrec.git /tmp/nvdiffrec
#     cd /tmp/nvdiffrec/render/renderutils
#     # Create a mini setup.py with CUDAExtension for c_src/*.cpp + c_src/*.cu
#     # nvcc flags: -O3 -std=c++17 --extended-lambda --expt-relaxed-constexpr
#     TORCH_CUDA_ARCH_LIST="7.5" python3 setup.py build_ext --inplace
#     cp _C*.so /usr/local/lib64/python3.12/site-packages/nvdiffrec_render/renderutils/
#
#   o_voxel (provides _C.so):
#     git clone --depth 1 https://github.com/microsoft/TRELLIS.2.git /tmp/TRELLIS.2
#     cd /tmp/TRELLIS.2/o-voxel
#     git clone --depth 1 https://gitlab.com/libeigen/eigen.git third_party/eigen
#     TORCH_CUDA_ARCH_LIST="7.5" pip install . --force-reinstall --no-deps --no-build-isolation
#     # IMPORTANT: This builds an older o_voxel version missing some Python files
#     # (e.g. tiled_flexible_dual_grid_to_mesh). Therefore reinstall the original
#     # wheel FIRST, then overwrite ONLY the .so — do not replace Python files!
#
#   flex_gemm CUDA kernels (provides kernels/cuda.cpython-312-*.so):
#     The flex_gemm wheel ships .cu source but NOT the header files (.h/.cuh).
#     Headers were reconstructed from function signatures in the .cu files +
#     copied from the o_voxel source (hash/api.h, hash/hash.cuh are identical).
#
#     cd /tmp && mkdir -p flex_build && cd flex_build
#     FGSRC=/usr/local/lib64/python3.12/site-packages/flex_gemm/kernels/cuda
#     OVSRC=/tmp/TRELLIS.2/o-voxel/src
#
#     # Copy hash headers from o_voxel (same codebase)
#     cp $OVSRC/hash/api.h $FGSRC/hash/api.h
#     cp $OVSRC/hash/hash.cuh $FGSRC/hash/hash.cuh
#
#     # Create grid_sample/api.h with correct signatures (note: Tensor& not Tensor)
#     cat > $FGSRC/grid_sample/api.h << 'EOF'
#     #pragma once
#     #include <torch/extension.h>
#     torch::Tensor hashmap_build_grid_sample_3d_nearest_neighbor_map(
#         torch::Tensor& hashmap_keys, torch::Tensor& hashmap_vals,
#         const torch::Tensor& coords, const torch::Tensor& grid,
#         const int W, const int H, const int D);
#     std::tuple<torch::Tensor, torch::Tensor> hashmap_build_grid_sample_3d_trilinear_neighbor_map_weight(
#         torch::Tensor& hashmap_keys, torch::Tensor& hashmap_vals,
#         const torch::Tensor& coords, const torch::Tensor& grid,
#         const int W, const int H, const int D);
#     EOF
#
#     # Create grid_sample/grid_sample.h (minimal, just needed for include)
#     echo -e '#pragma once\n#include <torch/extension.h>' > $FGSRC/grid_sample/grid_sample.h
#
#     # Create spconv/api.h with correct signatures
#     cat > $FGSRC/spconv/api.h << 'EOF'
#     #pragma once
#     #include <torch/extension.h>
#     torch::Tensor hashmap_build_submanifold_conv_neighbour_map_cuda(
#         torch::Tensor& hashmap_keys, torch::Tensor& hashmap_vals,
#         const torch::Tensor& coords,
#         int W, int H, int D, int Kw, int Kh, int Kd, int Dw, int Dh, int Dd);
#     std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
#         neighbor_map_post_process_for_masked_implicit_gemm_1(const torch::Tensor& neighbor_map);
#     std::tuple<torch::Tensor, torch::Tensor>
#         neighbor_map_post_process_for_masked_implicit_gemm_2(
#         const torch::Tensor& gray_code, const torch::Tensor& sorted_idx, int block_size);
#     EOF
#
#     # Create spconv/neighbor_map.h (minimal)
#     echo -e '#pragma once\n#include <torch/extension.h>' > $FGSRC/spconv/neighbor_map.h
#
#     # Build with setup.py:
#     cat > setup.py << 'EOF'
#     from setuptools import setup
#     from torch.utils.cpp_extension import CUDAExtension, BuildExtension
#     FGSRC = "/usr/local/lib64/python3.12/site-packages/flex_gemm/kernels/cuda"
#     setup(
#         name="flex_gemm_cuda",
#         ext_modules=[CUDAExtension("cuda", sources=[
#             f"{FGSRC}/ext.cpp", f"{FGSRC}/hash/hash.cu",
#             f"{FGSRC}/grid_sample/grid_sample.cu", f"{FGSRC}/spconv/neighbor_map.cu",
#         ], include_dirs=[FGSRC], extra_compile_args={
#             "cxx": ["-O3", "-std=c++17"],
#             "nvcc": ["-O3", "-std=c++17", "--expt-relaxed-constexpr", "--extended-lambda"],
#         })],
#         cmdclass={"build_ext": BuildExtension},
#     )
#     EOF
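#     TORCH_CUDA_ARCH_LIST="7.5" python3 setup.py build_ext --inplace
#     # --inplace leaves the built module in /tmp/flex_build. Copy it over the
#     # wheel's binary so that the podman cp source path listed below exists:
#     cp cuda*.so /usr/local/lib64/python3.12/site-packages/flex_gemm/kernels/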
#
# After building, copy .so files to storage/sm75-overrides/ on the host:
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/o_voxel/_C.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/o_voxel_C.so
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/cumesh/_C.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/cumesh_C.so
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/cumesh/_cubvh.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/cumesh_cubvh.so
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/nvdiffrec_render/renderutils/_C.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/nvdiffrec_render_C.so
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/_nvdiffrast_c.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/nvdiffrast_c.so
#   podman cp comfyui:/usr/local/lib64/python3.12/site-packages/flex_gemm/kernels/cuda.cpython-312-x86_64-linux-gnu.so storage/sm75-overrides/flex_gemm_cuda.so
#
# Mount in podman-compose.yaml:
#   - ./storage/sm75-overrides:/run/sm75-overrides:ro
# =============================================================================
# Copy the rebuilt SM 7.5 .so files over the binaries installed from the wheels.
#   $1 - directory containing the override files
#   $2 - site-packages directory that receives them
# Every copy is checked. The success message is printed only when all of the
# overrides were applied; otherwise a warning with the actual count goes to
# stderr and the function returns 1.
apply_sm75_overrides() {
    local src_dir=$1 site_dir=$2
    local abi=cpython-312-x86_64-linux-gnu
    # Each entry is "<override file>:<destination relative to site-packages>".
    local -a mapping=(
        "o_voxel_C.so:o_voxel/_C.$abi.so"
        "cumesh_C.so:cumesh/_C.$abi.so"
        "cumesh_cubvh.so:cumesh/_cubvh.$abi.so"
        "nvdiffrec_render_C.so:nvdiffrec_render/renderutils/_C.$abi.so"
        "nvdiffrast_c.so:_nvdiffrast_c.$abi.so"
        "flex_gemm_cuda.so:flex_gemm/kernels/cuda.$abi.so"
    )
    local entry src dest
    local applied=0

    for entry in "${mapping[@]}"; do
        src=$src_dir/${entry%%:*}
        dest=$site_dir/${entry#*:}
        if cp -f "$src" "$dest"; then
            applied=$((applied + 1))
        else
            echo "[pre-start] WARNING: could not apply override $src -> $dest" >&2
        fi
    done

    if [ "$applied" -ne "${#mapping[@]}" ]; then
        echo "[pre-start] WARNING: only $applied of ${#mapping[@]} SM 7.5 overrides applied." >&2
        return 1
    fi
    echo "[pre-start] SM 7.5 overrides applied ($applied .so files)."
}

# The overrides are optional: they are applied only when the directory exists.
if [ -d /run/sm75-overrides ]; then
    echo "[pre-start] Applying SM 7.5 CUDA kernel overrides..."
    SITE=/usr/local/lib64/python3.12/site-packages
    apply_sm75_overrides /run/sm75-overrides "$SITE"
fi

echo "[pre-start] Done."
