From ed0f9cdcf56618f43311829c7b21c4654bad5926 Mon Sep 17 00:00:00 2001 From: nscipione Date: Fri, 10 Jan 2025 14:57:56 +0000 Subject: [PATCH 1/3] Update header for SYCL complex implementation and fix issue with multi_ptr cast Newest compiler nightly change include path for header, this change requires to update to properly check and include sycl complex headers. multi_ptr cast as it is caused ambigous call to static_cast, this patch work-around the issue overloading load function using raw_pointers. Signed-off-by: nscipione --- onemath/sycl/blas/include/blas_meta.h | 6 ++-- .../blas3/gemm_load_store_complex.hpp | 14 ++++++++++ .../blas/src/operations/blas3/gemm_local.hpp | 28 +++++++++++++++---- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/onemath/sycl/blas/include/blas_meta.h b/onemath/sycl/blas/include/blas_meta.h index 9e813da..03b19c6 100644 --- a/onemath/sycl/blas/include/blas_meta.h +++ b/onemath/sycl/blas/include/blas_meta.h @@ -30,10 +30,10 @@ #ifdef BLAS_ENABLE_COMPLEX #define SYCL_EXT_ONEAPI_COMPLEX #include -#if __has_include() -#include +#if __has_include() +#include #else -#include +#include #endif #endif diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp index 61977df..11934b3 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp @@ -116,12 +116,26 @@ class vec_complex { m_Data = *(Ptr + Offset * NumElements); } + // Load + template + void load(size_t Offset, + const DataT* Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + // Store template void store(size_t Offset, sycl::multi_ptr Ptr) const { *(Ptr + Offset * NumElements) = m_Data; } + + // Store + template + void store(size_t Offset, + DataT* Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } }; /*! @brief Partial specialization of the Packetize class dedicated to diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp index c4c2165..aa5f053 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp @@ -527,12 +527,28 @@ class Gemm( - 0, sycl::multi_ptr(reg)); - out_vec *= alpha_; - - out_vec.template store( - 0, sycl::multi_ptr(out_ptr)); + if constexpr (std::is_same_v< + element_t, + sycl::ext::oneapi::experimental::complex> || + std::is_same_v< + element_t, + sycl::ext::oneapi::experimental::complex>) { + out_vec.template load(0, reg); + out_vec *= alpha_; + + out_vec.template store( + 0, out_ptr); + } else { + out_vec.template load( + 0, sycl::multi_ptr(reg)); + out_vec *= alpha_; + + out_vec.template store( + 0, sycl::multi_ptr(out_ptr)); + } } /*! * @brief Store the computed gemm result to the C matrix From f30d129b09f79fcb5a85f3fa9b58e7064ac3f64b Mon Sep 17 00:00:00 2001 From: nscipione Date: Fri, 10 Jan 2025 15:18:51 +0000 Subject: [PATCH 2/3] Formatting Signed-off-by: nscipione --- .../blas/src/operations/blas3/gemm_load_store_complex.hpp | 6 ++---- onemath/sycl/blas/src/operations/blas3/gemm_local.hpp | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp index 11934b3..291fc16 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp @@ -118,8 +118,7 @@ class vec_complex { // Load template - void load(size_t Offset, - const DataT* Ptr) { + void load(size_t Offset, const DataT *Ptr) { m_Data = *(Ptr + Offset * NumElements); } @@ -132,8 +131,7 @@ class vec_complex { // Store template - void store(size_t Offset, - DataT* Ptr) const { + void store(size_t Offset, DataT *Ptr) const { *(Ptr + Offset * NumElements) = m_Data; } }; diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp index aa5f053..af01e97 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp @@ -537,8 +537,8 @@ class Gemm(0, reg); out_vec *= alpha_; - out_vec.template store( - 0, out_ptr); + out_vec.template store(0, out_ptr); } else { out_vec.template load( From f367f9261a84d2995e9de96da6c780534bb2024c Mon Sep 17 00:00:00 2001 From: nscipione Date: Mon, 13 Jan 2025 10:12:48 +0000 Subject: [PATCH 3/3] Adding comment on why the conditional statement is necessary Signed-off-by: nscipione --- onemath/sycl/blas/src/operations/blas3/gemm_local.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp index af01e97..356da25 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp @@ -527,6 +527,9 @@ class Gemm> ||