Disables the aotriton download when both the USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION CMake flags are OFF.
Backport of upstream https://github.com/pytorch/pytorch/pull/130197 to PyTorch 2.3.0.
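
For context, a minimal self-contained sketch (not PyTorch code; the helper below is illustrative only) of the guard pattern the patch applies in sdp_utils.cpp: anything that depends on aotriton is compiled only when USE_ROCM is defined together with at least one attention backend, so a build with both backends OFF never references the aotriton headers that were not downloaded.

#include <iostream>

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
// In the real sdp_utils.cpp this branch is where <aotriton/flash.h> is
// included and aotriton::v2::flash::check_gpu() is queried.
static bool aotriton_backend_compiled() { return true; }
#else
// Both backends disabled (or a non-ROCm build): no aotriton dependency at all.
static bool aotriton_backend_compiled() { return false; }
#endif

int main() {
  std::cout << "aotriton-backed attention compiled in: "
            << (aotriton_backend_compiled() ? "yes" : "no") << '\n';
  return 0;
}

The Dependencies.cmake hunk below applies the same idea at configure time: External/aotriton.cmake is only included when USE_FLASH_ATTENTION or USE_MEM_EFF_ATTENTION is enabled.
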
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -24,7 +24,7 @@
 #include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
 
-#if USE_ROCM
+#if defined(USE_ROCM) && (defined(USE_MEM_EFF_ATTENTION) || defined(USE_FLASH_ATTENTION))
 #include <aotriton/flash.h>
 #endif
 
@@ -207,7 +207,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   // Check that the gpu is capable of running flash attention
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
-#if USE_ROCM
+#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -238,7 +238,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
-#if USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -623,7 +623,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
   constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat);
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
   constexpr auto aotriton_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
 #endif
@@ -668,7 +668,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
     }
   }
 
-#ifdef USE_ROCM
+#if defined(USE_ROCM) && defined(USE_MEM_EFF_ATTENTION)
   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
   auto dprop = at::cuda::getCurrentDeviceProperties();
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1095,10 +1095,12 @@ if(USE_ROCM)
       message(STATUS "Disabling Kernel Assert for ROCm")
     endif()
 
-    include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
     if(USE_CUDA)
       caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
     endif()
+    if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+      include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
+    endif()
   else()
     caffe2_update_option(USE_ROCM OFF)
   endif()