# Copyright 2025 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2

EAPI=8

# Settings consumed by the Python/distutils eclasses inherited below; they are
# defined ahead of the inherit line so the eclasses see them when sourced.
DISTUTILS_SINGLE_IMPL=1
DISTUTILS_EXT=1
PYTHON_COMPAT=( python3_{10..13} )
DISTUTILS_USE_PEP517=setuptools

# multiprocessing is inherited explicitly because src_compile calls
# makeopts_jobs directly instead of relying on another eclass pulling it in.
inherit cuda distutils-r1 multiprocessing pypi

DESCRIPTION="Flash Attention: Fast and Memory-Efficient Exact Attention (Python component)."
HOMEPAGE="https://github.com/Dao-AILab/flash-attention"

LICENSE="BSD"
SLOT="0"
KEYWORDS="~amd64"
IUSE="cuda rocm"
# NOTE(review): cuda and rocm are not declared mutually exclusive here, while
# src_compile lets rocm override BUILD_TARGET when both are enabled -- confirm
# whether a '?? ( cuda rocm )' constraint is wanted.
REQUIRED_USE="${PYTHON_REQUIRED_USE}"

# Runtime dependencies: the selected Python implementation, einops, PyTorch and
# a caffe2 build matching the enabled GPU backend.
# shellcheck disable=SC2016
RDEPEND="${PYTHON_DEPS}
	$(python_gen_cond_dep '
		sci-ml/einops[${PYTHON_USEDEP}]
	')
	cuda? ( sci-ml/caffe2[cuda,flash] )
	rocm? ( sci-ml/caffe2[rocm] )
	sci-ml/pytorch[${PYTHON_SINGLE_USEDEP}]"
# Build-time dependencies are the runtime set plus cutlass. DEPEND is assigned
# exactly once so that neither part can be overwritten by a later assignment.
DEPEND="${RDEPEND}
	dev-libs/cutlass"
# Tools executed on the build host.
# shellcheck disable=SC2016
BDEPEND="${PYTHON_DEPS}
	dev-build/ninja
	$(python_gen_cond_dep '
		dev-python/psutil[${PYTHON_USEDEP}]
	')"

PATCHES=(
	"${FILESDIR}/${P}-respect-flags.patch"
)

# Emit an advisory for CUDA builds that do not pass --threads to nvcc, then
# hand over to the python-single-r1 setup phase.
pkg_setup() {
	# A single substring test is enough: an unset or empty variable cannot
	# contain "--threads" either, so it takes the warning path as well.
	if use cuda && [[ ${NVCC_PREPEND_FLAGS} != *--threads* ]]; then
		ewarn
		ewarn "If this hangs your system, try adding '--threads 1' to NVCC_PREPEND_FLAGS. Or try"
		ewarn "lowering the number of make jobs. Example:"
		ewarn
		ewarn "  mkdir -p /etc/portage/env /etc/portage/package.env"
		ewarn "  echo 'NVCC_PREPEND_FLAGS=\"--threads 1\"' > /etc/portage/env/flash-attn"
		ewarn "  echo 'sci-ml/flash-attn flash-attn' >> /etc/portage/package.env/flash-attn"
		ewarn
	fi

	python-single-r1_pkg_setup
}

# Prepare the sources: run the CUDA eclass preparation for CUDA builds only,
# then the regular distutils-r1 preparation.
src_prepare() {
	# cuda is an optional USE flag, so the CUDA-specific step is skipped for
	# ROCm-only or GPU-less builds (presumably it needs the CUDA toolkit,
	# which such builds do not pull in -- verify against cuda.eclass).
	if use cuda; then
		cuda_src_prepare
	fi
	distutils-r1_src_prepare
}

# Export the build knobs read by the package's build script and run the
# distutils-r1 compile phase.
src_compile() {
	export FLASH_ATTENTION_FORCE_BUILD=TRUE
	export MAX_JOBS="$(makeopts_jobs)"

	# Backend selection: BUILD_TARGET is left untouched when neither flag is
	# enabled, and rocm is evaluated last so it wins if both are set.
	if use cuda; then
		cuda_add_sandbox
		export BUILD_TARGET=cuda
	fi
	if use rocm; then
		export BUILD_TARGET=rocm
	fi

	distutils-r1_src_compile
}