Fork me on GitHub
 All Classes Files Functions Variables Groups Pages
blas.h
Go to the documentation of this file.
1 /*
2  * Copyright 2008-2014 NVIDIA Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
21 #pragma once
22 
23 #include <cusp/detail/config.h>
24 
25 #include <cusp/complex.h>
26 #include <cusp/detail/type_traits.h>
27 
28 namespace cusp
29 {
30 namespace blas
31 {
32 
/**
 * Execution-policy overload of amax: index of the largest element in an array.
 *
 * NOTE(review): BLAS "amax" conventionally ranks by absolute value — confirm
 * against the implementation.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    input array (read-only)
 * \return the index, as an int
 */
41 template <typename DerivedPolicy,
42  typename ArrayType>
43 int amax(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
44  const ArrayType& x);
/**
 * Index of the largest element in an array.
 *
 * NOTE(review): BLAS "amax" conventionally ranks by absolute value — confirm
 * against the implementation.
 *
 * \param x input array (read-only)
 * \return the index, as an int
 */
79 template <typename ArrayType>
80 int amax(const ArrayType& x);
81 
/**
 * Execution-policy overload of asum: sum of the absolute values of all
 * entries in an array.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    input array (read-only)
 * \return the sum, typed as cusp::norm_type<ArrayType::value_type>::type
 */
83 template <typename DerivedPolicy,
84  typename ArrayType>
85 typename cusp::norm_type<typename ArrayType::value_type>::type
86 asum(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
87  const ArrayType& x);
/**
 * Sum of the absolute values of all entries in an array.
 *
 * \param x input array (read-only)
 * \return the sum, typed as cusp::norm_type<ArrayType::value_type>::type
 */
122 template <typename ArrayType>
123 typename cusp::norm_type<typename ArrayType::value_type>::type
124 asum(const ArrayType& x);
125 
/**
 * Execution-policy overload of axpy: scaled vector addition
 * (y = alpha * x + y).
 *
 * \param exec  execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x     input array (read-only)
 * \param y     array that is both an input and the destination
 * \param alpha scale factor applied to x
 */
127 template <typename DerivedPolicy,
128  typename ArrayType1,
129  typename ArrayType2,
130  typename ScalarType>
131 void axpy(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
132  const ArrayType1& x,
133  ArrayType2& y,
134  const ScalarType alpha);
/**
 * Scaled vector addition (y = alpha * x + y).
 *
 * \param x     input array (read-only)
 * \param y     array that is both an input and the destination
 * \param alpha scale factor applied to x
 */
175 template <typename ArrayType1,
176  typename ArrayType2,
177  typename ScalarType>
178 void axpy(const ArrayType1& x,
179  ArrayType2& y,
180  const ScalarType alpha);
181 
/**
 * Execution-policy overload of axpby: linear combination of two vectors
 * (z = alpha * x + beta * y).
 *
 * \param exec  execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x     first input array (read-only)
 * \param y     second input array (read-only)
 * \param z     destination array
 * \param alpha scale factor applied to x
 * \param beta  scale factor applied to y
 */
183 template <typename DerivedPolicy,
184  typename ArrayType1,
185  typename ArrayType2,
186  typename ArrayType3,
187  typename ScalarType1,
188  typename ScalarType2>
189 void axpby(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
190  const ArrayType1& x,
191  const ArrayType2& y,
192  ArrayType3& z,
193  ScalarType1 alpha,
194  ScalarType2 beta);
/**
 * Linear combination of two vectors (z = alpha * x + beta * y).
 *
 * \param x     first input array (read-only)
 * \param y     second input array (read-only)
 * \param z     destination array
 * \param alpha scale factor applied to x
 * \param beta  scale factor applied to y
 */
244 template <typename ArrayType1,
245  typename ArrayType2,
246  typename ArrayType3,
247  typename ScalarType1,
248  typename ScalarType2>
249 void axpby(const ArrayType1& x,
250  const ArrayType2& y,
251  ArrayType3& z,
252  ScalarType1 alpha,
253  ScalarType2 beta);
254 
/**
 * Execution-policy overload of axpbypcz: linear combination of three vectors
 * (output = alpha * x + beta * y + gamma * z).
 *
 * \param exec   execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x      first input array (read-only)
 * \param y      second input array (read-only)
 * \param z      third input array (read-only)
 * \param output destination array (named `w` in the overload without a policy)
 * \param alpha  scale factor applied to x
 * \param beta   scale factor applied to y
 * \param gamma  scale factor applied to z
 */
256 template <typename DerivedPolicy,
257  typename ArrayType1,
258  typename ArrayType2,
259  typename ArrayType3,
260  typename ArrayType4,
261  typename ScalarType1,
262  typename ScalarType2,
263  typename ScalarType3>
264 void axpbypcz(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
265  const ArrayType1& x,
266  const ArrayType2& y,
267  const ArrayType3& z,
268  ArrayType4& output,
269  ScalarType1 alpha,
270  ScalarType2 beta,
271  ScalarType3 gamma);
/**
 * Linear combination of three vectors
 * (w = alpha * x + beta * y + gamma * z).
 *
 * \param x     first input array (read-only)
 * \param y     second input array (read-only)
 * \param z     third input array (read-only)
 * \param w     destination array (named `output` in the execution-policy overload)
 * \param alpha scale factor applied to x
 * \param beta  scale factor applied to y
 * \param gamma scale factor applied to z
 */
330 template <typename ArrayType1,
331  typename ArrayType2,
332  typename ArrayType3,
333  typename ArrayType4,
334  typename ScalarType1,
335  typename ScalarType2,
336  typename ScalarType3>
337 void axpbypcz(const ArrayType1& x,
338  const ArrayType2& y,
339  const ArrayType3& z,
340  ArrayType4& w,
341  ScalarType1 alpha,
342  ScalarType2 beta,
343  ScalarType3 gamma);
344 
/**
 * Execution-policy overload of xmy: elementwise multiplication of two
 * vectors (z[i] = x[i] * y[i]).
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    first input array (read-only)
 * \param y    second input array (read-only)
 * \param z    destination array
 */
346 template <typename DerivedPolicy,
347  typename ArrayType1,
348  typename ArrayType2,
349  typename ArrayType3>
350 void xmy(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
351  const ArrayType1& x,
352  const ArrayType2& y,
353  ArrayType3& z);
/**
 * Elementwise multiplication of two vectors (z[i] = x[i] * y[i]).
 *
 * \param x first input array (read-only)
 * \param y second input array (read-only)
 * \param z destination array
 */
393 template <typename ArrayType1,
394  typename ArrayType2,
395  typename ArrayType3>
396 void xmy(const ArrayType1& x,
397  const ArrayType2& y,
398  ArrayType3& z);
399 
/**
 * Execution-policy overload of copy: vector copy (y = x).
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    source array (read-only)
 * \param y    destination array
 */
401 template <typename DerivedPolicy,
402  typename ArrayType1,
403  typename ArrayType2>
404 void copy(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
405  const ArrayType1& x,
406  ArrayType2& y);
/**
 * Vector copy (y = x).
 *
 * \param x source array (read-only)
 * \param y destination array
 */
441 template <typename ArrayType1,
442  typename ArrayType2>
443 void copy(const ArrayType1& x,
444  ArrayType2& y);
445 
/**
 * Execution-policy overload of dot: dot product (x^T * y).
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    first input array (read-only)
 * \param y    second input array (read-only)
 * \return the product; the return type is taken from x only (ArrayType1::value_type)
 */
447 template <typename DerivedPolicy,
448  typename ArrayType1,
449  typename ArrayType2>
450 typename ArrayType1::value_type
451 dot(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
452  const ArrayType1& x,
453  const ArrayType2& y);
/**
 * Dot product (x^T * y).
 *
 * \param x first input array (read-only)
 * \param y second input array (read-only)
 * \return the product; the return type is taken from x only (ArrayType1::value_type)
 */
488 template <typename ArrayType1,
489  typename ArrayType2>
490 typename ArrayType1::value_type
491 dot(const ArrayType1& x,
492  const ArrayType2& y);
493 
/**
 * Execution-policy overload of dotc: conjugate dot product
 * (conjugate(x)^T * y).
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    first input array (read-only); the operand that is conjugated
 * \param y    second input array (read-only)
 * \return the product; the return type is taken from x only (ArrayType1::value_type)
 */
495 template <typename DerivedPolicy,
496  typename ArrayType1,
497  typename ArrayType2>
498 typename ArrayType1::value_type
499 dotc(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
500  const ArrayType1& x,
501  const ArrayType2& y);
/**
 * Conjugate dot product (conjugate(x)^T * y).
 *
 * \param x first input array (read-only); the operand that is conjugated
 * \param y second input array (read-only)
 * \return the product; the return type is taken from x only (ArrayType1::value_type)
 */
536 template <typename ArrayType1,
537  typename ArrayType2>
538 typename ArrayType1::value_type
539 dotc(const ArrayType1& x,
540  const ArrayType2& y);
541 
/**
 * Execution-policy overload of fill: vector fill (array[i] = alpha).
 *
 * \param exec  execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param array destination array (named `x` in the overload without a policy)
 * \param alpha value assigned to the entries
 */
543 template <typename DerivedPolicy,
544  typename ArrayType,
545  typename ScalarType>
546 void fill(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
547  ArrayType& array,
548  const ScalarType alpha);
/**
 * Vector fill (x[i] = alpha).
 *
 * \param x     destination array
 * \param alpha value assigned to the entries
 */
580 template <typename ArrayType,
581  typename ScalarType>
582 void fill(ArrayType& x,
583  const ScalarType alpha);
584 
/**
 * Execution-policy overload of nrm1: vector 1-norm (sum abs(array[i])).
 *
 * \param exec  execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param array input array (read-only; named `x` in the overload without a policy)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
586 template <typename DerivedPolicy,
587  typename ArrayType>
588 typename cusp::norm_type<typename ArrayType::value_type>::type
589 nrm1(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
590  const ArrayType& array);
/**
 * Vector 1-norm (sum abs(x[i])).
 *
 * \param x input array (read-only)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
623 template <typename ArrayType>
624 typename cusp::norm_type<typename ArrayType::value_type>::type
625 nrm1(const ArrayType& x);
626 
/**
 * Execution-policy overload of nrm2: vector 2-norm.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    input array (read-only)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
628 template <typename DerivedPolicy,
629  typename ArrayType>
630 typename cusp::norm_type<typename ArrayType::value_type>::type
631 nrm2(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
632  const ArrayType& x);
/**
 * Vector 2-norm.
 *
 * \param x input array (read-only)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
665 template <typename ArrayType>
666 typename cusp::norm_type<typename ArrayType::value_type>::type
667 nrm2(const ArrayType& x);
668 
/**
 * Execution-policy overload of nrmmax: vector infinity norm.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    input array (read-only)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
670 template <typename DerivedPolicy,
671  typename ArrayType>
672 typename cusp::norm_type<typename ArrayType::value_type>::type
673 nrmmax(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
674  const ArrayType& x);
/**
 * Vector infinity norm.
 *
 * \param x input array (read-only)
 * \return the norm, typed as cusp::norm_type<ArrayType::value_type>::type
 */
707 template <typename ArrayType>
708 typename cusp::norm_type<typename ArrayType::value_type>::type
709 nrmmax(const ArrayType& x);
710 
/**
 * Execution-policy overload of scal: scale a vector in place
 * (x[i] = alpha * x[i]).
 *
 * \param exec  execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x     array that is scaled in place
 * \param alpha scale factor
 */
712 template <typename DerivedPolicy,
713  typename ArrayType,
714  typename ScalarType>
715 void scal(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
716  ArrayType& x,
717  const ScalarType alpha);
/**
 * Scale a vector in place (x[i] = alpha * x[i]).
 *
 * \param x     array that is scaled in place
 * \param alpha scale factor
 */
752 template <typename ArrayType,
753  typename ScalarType>
754 void scal(ArrayType& x,
755  const ScalarType alpha);
756 
/**
 * Execution-policy overload of gemv: computes a matrix-vector product using
 * a general matrix.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only)
 * \param x    input vector (read-only)
 * \param y    output vector (presumably y = A * x — TODO confirm, including whether y is also read)
 */
758 template <typename DerivedPolicy,
759  typename Array2d1,
760  typename Array1d1,
761  typename Array1d2>
762 void gemv(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
763  const Array2d1& A,
764  const Array1d1& x,
765  Array1d2& y);
/**
 * Computes a matrix-vector product using a general matrix.
 *
 * \param A input matrix (read-only)
 * \param x input vector (read-only)
 * \param y output vector (presumably y = A * x — TODO confirm, including whether y is also read)
 */
815 template<typename Array2d1,
816  typename Array1d1,
817  typename Array1d2>
818 void gemv(const Array2d1& A,
819  const Array1d1& x,
820  Array1d2& y);
821 
/**
 * Execution-policy overload of ger: performs a rank-1 update of a general
 * matrix.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    first input vector (read-only)
 * \param y    second input vector (read-only)
 * \param A    matrix that is updated (presumably with the outer product of x and y — TODO confirm)
 */
823 template <typename DerivedPolicy,
824  typename Array1d1,
825  typename Array1d2,
826  typename Array2d1>
827 void ger(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
828  const Array1d1& x,
829  const Array1d2& y,
830  Array2d1& A);
/**
 * Performs a rank-1 update of a general matrix.
 *
 * \param x first input vector (read-only)
 * \param y second input vector (read-only)
 * \param A matrix that is updated (presumably with the outer product of x and y — TODO confirm)
 */
877 template<typename Array1d1,
878  typename Array1d2,
879  typename Array2d1>
880 void ger(const Array1d1& x,
881  const Array1d2& y,
882  Array2d1& A);
883 
/**
 * Execution-policy overload of symv: computes a matrix-vector product using
 * a symmetric matrix.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only); expected to be symmetric
 * \param x    input vector (read-only)
 * \param y    output vector (presumably y = A * x — TODO confirm)
 */
885 template <typename DerivedPolicy,
886  typename Array2d1,
887  typename Array1d1,
888  typename Array1d2>
889 void symv(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
890  const Array2d1& A,
891  const Array1d1& x,
892  Array1d2& y);
/**
 * Computes a matrix-vector product using a symmetric matrix.
 *
 * \param A input matrix (read-only); expected to be symmetric
 * \param x input vector (read-only)
 * \param y output vector (presumably y = A * x — TODO confirm)
 */
942 template <typename Array2d1,
943  typename Array1d1,
944  typename Array1d2>
945 void symv(const Array2d1& A,
946  const Array1d1& x,
947  Array1d2& y);
948 
/**
 * Execution-policy overload of syr: performs a rank-1 update of a symmetric
 * matrix.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param x    input vector (read-only)
 * \param A    matrix that is updated (presumably with the outer product of x with itself — TODO confirm)
 */
950 template <typename DerivedPolicy,
951  typename Array1d,
952  typename Array2d>
953 void syr(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
954  const Array1d& x,
955  Array2d& A);
/**
 * Performs a rank-1 update of a symmetric matrix.
 *
 * \param x input vector (read-only)
 * \param A matrix that is updated (presumably with the outer product of x with itself — TODO confirm)
 */
997 template <typename Array1d,
998  typename Array2d>
999 void syr(const Array1d& x,
1000  Array2d& A);
1001 
/**
 * Execution-policy overload of trmv: computes a matrix-vector product using
 * a triangular matrix.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only); expected to be triangular
 * \param x    vector passed by non-const reference (presumably overwritten with A * x — TODO confirm)
 */
1003 template <typename DerivedPolicy,
1004  typename Array2d,
1005  typename Array1d>
1006 void trmv(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1007  const Array2d& A,
1008  Array1d& x);
/**
 * Computes a matrix-vector product using a triangular matrix.
 *
 * \param A input matrix (read-only); expected to be triangular
 * \param x vector passed by non-const reference (presumably overwritten with A * x — TODO confirm)
 */
1058 template<typename Array2d,
1059  typename Array1d>
1060 void trmv(const Array2d& A,
1061  Array1d& x);
1062 
/**
 * Execution-policy overload of trsv: solves a triangular matrix equation.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only); expected to be triangular
 * \param x    vector passed by non-const reference (presumably the right-hand side on entry and the solution on exit — TODO confirm)
 */
1064 template <typename DerivedPolicy,
1065  typename Array2d,
1066  typename Array1d>
1067 void trsv(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1068  const Array2d& A,
1069  Array1d& x);
/**
 * Solves a triangular matrix equation.
 *
 * \param A input matrix (read-only); expected to be triangular
 * \param x vector passed by non-const reference (presumably the right-hand side on entry and the solution on exit — TODO confirm)
 */
1114 template<typename Array2d,
1115  typename Array1d>
1116 void trsv(const Array2d& A,
1117  Array1d& x);
1118 
/**
 * Execution-policy overload of gemm: computes a matrix-matrix product with
 * general matrices.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    first input matrix (read-only)
 * \param B    second input matrix (read-only)
 * \param C    output matrix (presumably C = A * B — TODO confirm, including whether C is also read)
 */
1120 template <typename DerivedPolicy,
1121  typename Array2d1,
1122  typename Array2d2,
1123  typename Array2d3>
1124 void gemm(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1125  const Array2d1& A,
1126  const Array2d2& B,
1127  Array2d3& C);
/**
 * Computes a matrix-matrix product with general matrices.
 *
 * \param A first input matrix (read-only)
 * \param B second input matrix (read-only)
 * \param C output matrix (presumably C = A * B — TODO confirm, including whether C is also read)
 */
1172 template<typename Array2d1,
1173  typename Array2d2,
1174  typename Array2d3>
1175 void gemm(const Array2d1& A,
1176  const Array2d2& B,
1177  Array2d3& C);
1178 
/**
 * Execution-policy overload of symm: computes a matrix-matrix product where
 * one input matrix is symmetric.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    first input matrix (read-only)
 * \param B    second input matrix (read-only)
 * \param C    output matrix (presumably C = A * B — TODO confirm)
 *
 * NOTE(review): which of A and B must be symmetric is not visible from this
 * declaration — confirm against the implementation.
 */
1180 template <typename DerivedPolicy,
1181  typename Array2d1,
1182  typename Array2d2,
1183  typename Array2d3>
1184 void symm(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1185  const Array2d1& A,
1186  const Array2d2& B,
1187  Array2d3& C);
/**
 * Computes a matrix-matrix product where one input matrix is symmetric.
 *
 * \param A first input matrix (read-only)
 * \param B second input matrix (read-only)
 * \param C output matrix (presumably C = A * B — TODO confirm)
 *
 * NOTE(review): which of A and B must be symmetric is not visible from this
 * declaration — confirm against the implementation.
 */
1232 template<typename Array2d1,
1233  typename Array2d2,
1234  typename Array2d3>
1235 void symm(const Array2d1& A,
1236  const Array2d2& B,
1237  Array2d3& C);
1238 
/**
 * Execution-policy overload of syrk: performs a symmetric rank-k update.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only)
 * \param B    matrix that is updated (the exact update formula is not visible here — TODO confirm)
 */
1240 template <typename DerivedPolicy,
1241  typename Array2d1,
1242  typename Array2d2>
1243 void syrk(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1244  const Array2d1& A,
1245  Array2d2& B);
/**
 * Performs a symmetric rank-k update.
 *
 * \param A input matrix (read-only)
 * \param B matrix that is updated (the exact update formula is not visible here — TODO confirm)
 */
1290 template<typename Array2d1,
1291  typename Array2d2>
1292 void syrk(const Array2d1& A,
1293  Array2d2& B);
1294 
/**
 * Execution-policy overload of syr2k: performs a symmetric rank-2k update.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    first input matrix (read-only)
 * \param B    second input matrix (read-only)
 * \param C    matrix that is updated (the exact update formula is not visible here — TODO confirm)
 */
1296 template <typename DerivedPolicy,
1297  typename Array2d1,
1298  typename Array2d2,
1299  typename Array2d3>
1300 void syr2k(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1301  const Array2d1& A,
1302  const Array2d2& B,
1303  Array2d3& C);
/**
 * Performs a symmetric rank-2k update.
 *
 * \param A first input matrix (read-only)
 * \param B second input matrix (read-only)
 * \param C matrix that is updated (the exact update formula is not visible here — TODO confirm)
 */
1348 template<typename Array2d1,
1349  typename Array2d2,
1350  typename Array2d3>
1351 void syr2k(const Array2d1& A,
1352  const Array2d2& B,
1353  Array2d3& C);
1354 
/**
 * Execution-policy overload of trmm: computes a matrix-matrix product where
 * one input matrix is triangular.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only); presumably the triangular operand — TODO confirm
 * \param B    matrix passed by non-const reference (presumably both an operand and the destination — TODO confirm)
 */
1356 template <typename DerivedPolicy,
1357  typename Array2d1,
1358  typename Array2d2>
1359 void trmm(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1360  const Array2d1& A,
1361  Array2d2& B);
/**
 * Computes a matrix-matrix product where one input matrix is triangular.
 *
 * \param A input matrix (read-only); presumably the triangular operand — TODO confirm
 * \param B matrix passed by non-const reference (presumably both an operand and the destination — TODO confirm)
 */
1406 template<typename Array2d1,
1407  typename Array2d2>
1408 void trmm(const Array2d1& A,
1409  Array2d2& B);
1410 
/**
 * Execution-policy overload of trsm: solves a triangular matrix equation.
 *
 * \param exec execution policy (presumably selects where/how the operation runs; see cusp/blas/blas.inl)
 * \param A    input matrix (read-only); expected to be triangular
 * \param B    matrix passed by non-const reference (presumably the right-hand sides on entry and the solution on exit — TODO confirm)
 */
1412 template <typename DerivedPolicy,
1413  typename Array2d1,
1414  typename Array2d2>
1415 void trsm(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
1416  const Array2d1& A,
1417  Array2d2& B);
/**
 * Solves a triangular matrix equation.
 *
 * \param A input matrix (read-only); expected to be triangular
 * \param B matrix passed by non-const reference (presumably the right-hand sides on entry and the solution on exit — TODO confirm)
 */
1464 template<typename Array2d1,
1465  typename Array2d2>
1466 void trsm(const Array2d1& A,
1467  Array2d2& B);
1468 
1472 } // end namespace blas
1473 } // end namespace cusp
1474 
1475 #include <cusp/blas/blas.inl>
void ger(const Array1d1 &x, const Array1d2 &y, Array2d1 &A)
Performs a rank-1 update of a general matrix.
void trmv(const Array2d &A, Array1d &x)
Computes a matrix-vector product using a triangular matrix.
void symv(const Array2d1 &A, const Array1d1 &x, Array1d2 &y)
Computes a matrix-vector product using a symmetric matrix.
void symm(const Array2d1 &A, const Array2d2 &B, Array2d3 &C)
Computes a matrix-matrix product where one input matrix is symmetric.
Complex numbers.
void syrk(const Array2d1 &A, Array2d2 &B)
Performs a symmetric rank-k update.
void gemm(const Array2d1 &A, const Array2d2 &B, Array2d3 &C)
Computes a matrix-matrix product with general matrices.
cusp::norm_type< typename ArrayType::value_type >::type asum(const ArrayType &x)
sum of the absolute values of all entries in the array
cusp::norm_type< typename ArrayType::value_type >::type nrm2(const ArrayType &x)
vector 2-norm (sqrt(sum(x[i] * x[i])))
void scal(ArrayType &x, const ScalarType alpha)
scale vector (x[i] = alpha * x[i])
void axpbypcz(const ArrayType1 &x, const ArrayType2 &y, const ArrayType3 &z, ArrayType4 &w, ScalarType1 alpha, ScalarType2 beta, ScalarType3 gamma)
compute linear combination of three vectors (output = alpha * x + beta * y + gamma * z) ...
void syr(const Array1d &x, Array2d &A)
Performs a rank-1 update of a symmetric matrix.
ArrayType1::value_type dotc(const ArrayType1 &x, const ArrayType2 &y)
conjugate dot product (conjugate(x)^T * y)
cusp::norm_type< typename ArrayType::value_type >::type nrm1(const ArrayType &x)
vector 1-norm (sum abs(x[i]))
void trsv(const Array2d &A, Array1d &x)
Solve a triangular matrix equation.
ArrayType1::value_type dot(const ArrayType1 &x, const ArrayType2 &y)
dot product (x^T * y)
void axpby(const ArrayType1 &x, const ArrayType2 &y, ArrayType3 &z, ScalarType1 alpha, ScalarType2 beta)
compute linear combination of two vectors (z = alpha * x + beta * y)
void copy(const ArrayType1 &x, ArrayType2 &y)
vector copy (y = x)
void axpy(const ArrayType1 &x, ArrayType2 &y, const ScalarType alpha)
scaled vector addition (y = alpha * x + y)
void syr2k(const Array2d1 &A, const Array2d2 &B, Array2d3 &C)
Performs a symmetric rank-2k update.
int amax(const ArrayType &x)
index of the largest element in an array
cusp::norm_type< typename ArrayType::value_type >::type nrmmax(const ArrayType &x)
vector infinity norm
void gemv(const Array2d1 &A, const Array1d1 &x, Array1d2 &y)
Computes a matrix-vector product using a general matrix.
void fill(ArrayType &x, const ScalarType alpha)
vector fill (x[i] = alpha)
void trsm(const Array2d1 &A, Array2d2 &B)
Solve a triangular matrix equation.
void trmm(const Array2d1 &A, Array2d2 &B)
Computes a matrix-matrix product where one input matrix is triangular.
void xmy(const ArrayType1 &x, const ArrayType2 &y, ArrayType3 &z)
elementwise multiplication of two vectors (z[i] = x[i] * y[i])