Add arg() to tensor
diff --git a/unsupported/Eigen/CXX11/src/Tensor/README.md b/unsupported/Eigen/CXX11/src/Tensor/README.md
index 2e785ef..e4b5e2e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/README.md
+++ b/unsupported/Eigen/CXX11/src/Tensor/README.md
@@ -886,6 +886,23 @@
Returns a tensor of the same type and dimensions as the original tensor
containing the absolute values of the original tensor.
+### <Operation> arg()
+
+Returns a tensor with the same dimensions as the original tensor
+containing the complex argument (phase angle) of the values of the
+original tensor.
+
+### <Operation> real()
+
+Returns a tensor with the same dimensions as the original tensor
+containing the real part of the complex values of the original tensor.
+
+### <Operation> imag()
+
+Returns a tensor with the same dimensions as the orginal tensor
+containing the imaginary part of the complex values of the original
+tensor.
+
### <Operation> pow(Scalar exponent)
Returns a tensor of the same type and dimensions as the original tensor
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 8eaf96a..3e07a6f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -312,6 +312,12 @@
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived>
+ arg() const {
+ return unaryExpr(internal::scalar_arg_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clamp_op<Scalar>, const Derived>
clip(Scalar min, Scalar max) const {
return unaryExpr(internal::scalar_clamp_op<Scalar>(min, max));
diff --git a/unsupported/test/cxx11_tensor_of_complex.cpp b/unsupported/test/cxx11_tensor_of_complex.cpp
index 99e18076..b2f5994 100644
--- a/unsupported/test/cxx11_tensor_of_complex.cpp
+++ b/unsupported/test/cxx11_tensor_of_complex.cpp
@@ -47,6 +47,20 @@
}
}
+static void test_arg()
+{
+ Tensor<std::complex<float>, 1> data1(3);
+ Tensor<std::complex<double>, 1> data2(3);
+ data1.setRandom();
+ data2.setRandom();
+
+ Tensor<float, 1> arg1 = data1.arg();
+ Tensor<double, 1> arg2 = data2.arg();
+ for (int i = 0; i < 3; ++i) {
+ VERIFY_IS_APPROX(arg1(i), std::arg(data1(i)));
+ VERIFY_IS_APPROX(arg2(i), std::arg(data2(i)));
+ }
+}
static void test_conjugate()
{
@@ -98,6 +112,7 @@
{
CALL_SUBTEST(test_additions());
CALL_SUBTEST(test_abs());
+ CALL_SUBTEST(test_arg());
CALL_SUBTEST(test_conjugate());
CALL_SUBTEST(test_contractions());
}