Implement batched serial gbtrf #2489

Open
wants to merge 10 commits into
base: develop

Conversation

yasahi-hpc
Contributor

This PR implements the batched serial gbtrf function.

The following files are added:

  1. KokkosBatched_Gbtrf_Serial_Impl.hpp: Internal interfaces
  2. KokkosBatched_Gbtrf_Serial_Internal.hpp: Implementation details
  3. KokkosBatched_Gbtrf.hpp: APIs
  4. Test_Batched_SerialGbtrf.hpp: Unit tests

Detailed description

It computes an LU factorization of a real general M-by-N band matrix A using partial pivoting with row interchanges.
The arguments have the following shapes and meanings.

  • A: (batch_count, ldab, n)
    On entry, the M-by-N matrix A to be factored, in band storage. On exit, the factors L and U from the factorization: U is stored as an upper triangular band matrix with KL+KU superdiagonals in rows 0 to KL+KU,
    and the multipliers used during the factorization are stored in rows KL+KU+1 to 2*KL+KU.
  • IPIV: (batch_count, min(m, n))
    The pivot indices; for 0 <= i < min(M,N), row i of the matrix was interchanged with row IPIV(i).
  • kl: The number of subdiagonals within the band of A. kl >= 0
  • ku: The number of superdiagonals within the band of A. ku >= 0
  • m: The number of rows of the matrix A. (optional)
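
The band layout described above can be sketched in plain C++ as follows. This is only an illustration, not the code in this PR: the `Matrix` alias and the function name `dense_to_band` are hypothetical, and it assumes the LAPACK gbtrf convention with 0-based indices, where `ldab = 2*kl + ku + 1` and `AB(kl + ku + i - j, j) = A(i, j)` inside the band.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical dense row-major matrix type, used for illustration only.
using Matrix = std::vector<std::vector<double>>;

// Pack a dense m-by-n matrix with kl subdiagonals and ku superdiagonals into
// band storage suitable as gbtrf input (assumed LAPACK convention, 0-based).
// The result has 2*kl + ku + 1 rows. The first kl rows are zero-filled here;
// they are workspace for the fill-in created by row interchanges.
Matrix dense_to_band(const Matrix& A, int kl, int ku) {
  const int m = static_cast<int>(A.size());
  const int n = m > 0 ? static_cast<int>(A[0].size()) : 0;
  const int ldab = 2 * kl + ku + 1;
  Matrix AB(ldab, std::vector<double>(n, 0.0));
  for (int j = 0; j < n; ++j) {
    // Only rows inside the band of column j are stored.
    const int i_begin = std::max(0, j - ku);
    const int i_end = std::min(m - 1, j + kl);
    for (int i = i_begin; i <= i_end; ++i) {
      AB[kl + ku + i - j][j] = A[i][j];
    }
  }
  return AB;
}
```

For a 4-by-4 matrix with kl = ku = 2 this produces a 7-by-4 array whose row 4 holds the main diagonal, matching the row ranges given for U and the multipliers above.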

The batch can be parallelized in the following manner. This is efficient only when
A is given in LayoutLeft for GPUs and LayoutRight for CPUs, because the parallelization is over the batch direction.

// m_a: (batch_count, ldab, n), m_ipiv: (batch_count, min(m, n))
const int batch_count = m_a.extent(0);
Kokkos::parallel_for("gbtrf",
    Kokkos::RangePolicy<execution_space>(0, batch_count),  // one iteration per batch entry
    KOKKOS_LAMBDA(const int k) {
        // Extract the k-th band matrix and its pivot vector
        auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL());
        auto ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL());

        KokkosBatched::SerialGbtrf<AlgoTagType>::invoke(aa, ipiv, kl, ku);
    });
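
To illustrate the layout remark, here is a small sketch of how the linear offset of element (b, i, j) differs between the two layouts. It assumes unpadded, contiguous storage for a (batch, rows, cols) array; the function names are hypothetical and this is not how Kokkos exposes strides.

```cpp
#include <cstddef>

// LayoutLeft: the leftmost (batch) index has stride 1, so consecutive batch
// entries of the same matrix element are adjacent in memory.
std::size_t offset_layout_left(std::size_t b, std::size_t i, std::size_t j,
                               std::size_t batch, std::size_t rows) {
  return b + batch * (i + rows * j);
}

// LayoutRight: the rightmost (column) index has stride 1, so each batch
// entry owns one contiguous rows-by-cols block.
std::size_t offset_layout_right(std::size_t b, std::size_t i, std::size_t j,
                                std::size_t rows, std::size_t cols) {
  return j + cols * (i + rows * b);
}
```

With one batch entry per thread, LayoutLeft lets neighbouring GPU threads touch neighbouring addresses, while LayoutRight keeps the whole matrix of one CPU thread in a single contiguous block.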

Tests

  1. Build a band matrix from a random matrix A and copy it to LU. Store A in band storage AB and factorize it with gbtrf. Then convert AB back into full storage and extract L and U. Apply getrf to LU to obtain the reference L and U. Finally, confirm that both sets of L and U are the same.
  2. Simple and small analytical test, i.e. choose A as follows (kl = ku = 2) and confirm that the band-storage result LUB and the pivots piv are as expected.
A = [[ 1., -3., -2.,  0.],
     [-1.,  1., -3., -2.],
     [ 2., -1.,  1., -3.],
     [ 0.,  2., -1.,  1.]]
LUB = [[ 0,    0,    0,    0  ],
       [ 0,    0,    0,   -3  ],
       [ 0,    0,    1,    1.5],
       [ 0,   -1,   -2.5, -3.2],
       [ 2,   -2.5, -3,    5.4],
       [-0.5, -0.2,  1,    0  ],
       [ 0.5, -0.8,  0,    0  ]]
piv = [2, 2, 2, 3]
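
The analytical case can be cross-checked by hand with a dense LU sketch like the one below. It is not the gbtrf implementation in this PR: it is an unblocked dense elimination under the assumption that row interchanges are applied only to the current and later columns, so multipliers stored in earlier columns stay in place. The `Matrix` alias and the name `dense_lu_partial_pivot` are hypothetical.

```cpp
#include <cmath>
#include <utility>
#include <vector>

// Hypothetical dense row-major matrix type, used for illustration only.
using Matrix = std::vector<std::vector<double>>;

// In-place dense LU with partial pivoting. On exit the upper triangle holds
// U and the strict lower triangle holds the multipliers. Returns the 0-based
// pivot row selected at each elimination step.
std::vector<int> dense_lu_partial_pivot(Matrix& A) {
  const int m = static_cast<int>(A.size());
  const int n = m > 0 ? static_cast<int>(A[0].size()) : 0;
  const int steps = m < n ? m : n;
  std::vector<int> ipiv(steps);
  for (int k = 0; k < steps; ++k) {
    // Pick the first entry of largest magnitude in column k, rows k..m-1.
    int p = k;
    for (int i = k + 1; i < m; ++i) {
      if (std::abs(A[i][k]) > std::abs(A[p][k])) p = i;
    }
    ipiv[k] = p;
    if (p != k) {
      // Swap only columns k..n-1; earlier multipliers are left untouched.
      for (int j = k; j < n; ++j) std::swap(A[k][j], A[p][j]);
    }
    if (A[k][k] == 0.0) continue;  // zero pivot: nothing to eliminate
    for (int i = k + 1; i < m; ++i) {
      A[i][k] /= A[k][k];  // multiplier
      for (int j = k + 1; j < n; ++j) A[i][j] -= A[i][k] * A[k][j];
    }
  }
  return ipiv;
}
```

Applied to the 4-by-4 matrix A above, entry (i, j) of the result corresponds to LUB(kl + ku + i - j, j) with kl = ku = 2, and the returned pivots correspond to piv.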

@cwpearson cwpearson added the AT2-CI-APPROVAL Approve CI to run at SNL label Jan 28, 2025
@yasahi-hpc yasahi-hpc force-pushed the implement-batched-serial-gbtrf branch 2 times, most recently from 2723819 to 507b3bd Compare February 6, 2025 07:40
@lucbv lucbv self-requested a review February 19, 2025 02:28
@yasahi-hpc yasahi-hpc force-pushed the implement-batched-serial-gbtrf branch from 507b3bd to 1d0c1e2 Compare February 27, 2025 14:43
@yasahi-hpc yasahi-hpc added AT2-CI-APPROVAL Approve CI to run at SNL and removed AT2-CI-APPROVAL Approve CI to run at SNL labels Mar 7, 2025
Contributor

@lucbv lucbv left a comment

Some small clean-up needed but nothing major

auto h_NL1 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), NL1);
auto h_NL2 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), NL2);

RealType eps = 1.0e1 * ats::epsilon();
Contributor

This looks like an arbitrary number that happens to work... how about doing an error analysis to compute the number of round off operations performed in gbtrf?

Contributor Author

Could you please elaborate on this point?
The tolerance is the machine epsilon of fp32 or fp64 multiplied by 10.

Contributor

For example when you perform a gemv operation:

y = beta * y + alpha * A * x

For each value y(i) you perform numCols multiplications and numCols - 1 additions to compute A * x. Then there are two more multiplications for alpha and beta, and one more addition between beta * y and alpha * A * x. In total that is numCols + 2 multiplications and numCols additions, so a check might look like this:

tol = (2 * numCols + 2) * maxVal * Kokkos::ArithTraits<Scalar>::eps()
Kokkos::abs(y(i) - y_ref(i)) < tol

Here maxVal is the maximum value an input can take, which accounts for the possibility of catastrophic cancellation.
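
A minimal standalone sketch of this check, with `std::numeric_limits<Scalar>::epsilon()` standing in for the Kokkos arithmetic traits so that it compiles without Kokkos. The helper names `gemv_tolerance` and `close_to_reference` are hypothetical and not part of Kokkos Kernels.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Tolerance from the operation count above: (numCols + 2) multiplications
// plus numCols additions, scaled by the largest input magnitude.
template <typename Scalar>
Scalar gemv_tolerance(int num_cols, Scalar max_val) {
  return static_cast<Scalar>(2 * num_cols + 2) * max_val *
         std::numeric_limits<Scalar>::epsilon();
}

// Element-wise comparison of a computed vector against a reference.
template <typename Scalar>
bool close_to_reference(const std::vector<Scalar>& y,
                        const std::vector<Scalar>& y_ref, Scalar tol) {
  if (y.size() != y_ref.size()) return false;
  for (std::size_t i = 0; i < y.size(); ++i) {
    if (!(std::abs(y[i] - y_ref[i]) < tol)) return false;
  }
  return true;
}
```

For example, with numCols = 4 and maxVal = 1 the tolerance is 10 times the machine epsilon, and it grows linearly with the number of columns.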

Contributor Author

I understand your point.
The error analysis will be critical when it comes to fp16.

Can I start from simpler cases like gemv and gemm?

auto h_NL_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), NL_ref);
auto h_ipiv_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv_ref);

RealType eps = 1.0e3 * ats::epsilon();
Contributor

This one looks even more arbitrary than the previous one above : (

Yuuichi Asahi added 10 commits March 12, 2025 21:45
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
@yasahi-hpc yasahi-hpc force-pushed the implement-batched-serial-gbtrf branch from 01205a5 to 8e75aff Compare March 12, 2025 13:12
@yasahi-hpc
Contributor Author

@lucbv Thank you for your review.
I have addressed all of the comments except for the error analysis.

For the error analysis, as I also commented in #2530,
I would like to start from the simpler cases.

@yasahi-hpc yasahi-hpc requested a review from lucbv March 12, 2025 14:04
@yasahi-hpc yasahi-hpc added AT2-CI-APPROVAL Approve CI to run at SNL and removed AT2-CI-APPROVAL Approve CI to run at SNL labels Mar 12, 2025
Labels
AT2-CI-APPROVAL Approve CI to run at SNL

3 participants