Disables the aotriton download when both the USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION CMake flags are OFF.
Backports upstream PR https://github.com/pytorch/pytorch/pull/130197 to 2.3.0.
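
The same guard pattern is applied at every aotriton-dependent site: the code is
compiled only when ROCm is enabled together with at least one attention backend,
so a build with both backends OFF never references the aotriton headers and the
download can be skipped entirely. A minimal illustrative sketch of the guard
semantics follows (not part of the patch itself; macro names as used above):

    // `#if USE_ROCM` evaluates the macro's *value*: an undefined USE_ROCM is
    // treated as 0, and a macro defined to an empty token sequence
    // (`#define USE_ROCM`) makes the #if a syntax error. `defined(USE_ROCM)`
    // only tests whether the macro exists, which is what the build system
    // actually communicates, and composes cleanly with the backend flags:
    #if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
    #include <aotriton/flash.h>  // pulled in only when a backend can use it
    #endif
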
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -24,7 +24,7 @@
#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>
-#if USE_ROCM
+#if defined(USE_ROCM) && (defined(USE_MEM_EFF_ATTENTION) || defined(USE_FLASH_ATTENTION))
#include <aotriton/flash.h>
#endif
@@ -207,7 +207,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
-#if USE_ROCM
+#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -238,7 +238,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
-#if USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -623,7 +623,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
constexpr auto less_than_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat);
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
constexpr auto aotriton_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
#endif
@@ -668,7 +668,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
}
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
#else
auto dprop = at::cuda::getCurrentDeviceProperties();
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1095,10 +1095,12 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm")
endif()
- include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
if(USE_CUDA)
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
endif()
+ if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+ include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
+ endif()
else()
caffe2_update_option(USE_ROCM OFF)
endif()