SourceXtractorPlusPlus  0.15
Please provide a description of the project.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
DFT.h
Go to the documentation of this file.
1 
17 /*
18  * @file SEFramework/Convolution/DFT.h
19  * @date 17/09/18
20  * @author aalvarez
21  */
22 
23 #ifndef _SEFRAMEWORK_CONVOLUTION_DFT_H
24 #define _SEFRAMEWORK_CONVOLUTION_DFT_H
25 
31 #include "SEFramework/FFT/FFT.h"
32 
33 #include <fftw3.h>
34 
35 
36 namespace SourceXtractor {
37 
45 template<typename T = SeFloat, class TPadding = PaddedImage<T, Reflect101Coordinates>>
47 public:
48  typedef T real_t;
49  typedef typename FFT<T>::complex_t complex_t;
50 
57  private:
62 
63  friend class DFTConvolution<T, TPadding>;
64  };
65 
72  : m_kernel{img} {
73  }
74 
78  virtual ~DFTConvolution() = default;
79 
85  return m_kernel->getWidth();
86  }
87 
93  return m_kernel->getHeight();
94  }
95 
105  auto context = Euclid::make_unique<ConvolutionContext>();
106 
107  // Dimension of the working padded images
108  context->m_padded_width = model_ptr->getWidth() + m_kernel->getWidth() - 1;
109  context->m_padded_height = model_ptr->getHeight() + m_kernel->getHeight() - 1;
110 
111  // For performance, use a size that is convenient for FFTW
112  context->m_padded_width = fftRoundDimension(context->m_padded_width);
113  context->m_padded_height = fftRoundDimension(context->m_padded_height);
114 
115  // Total number of pixels
116  context->m_total_size = context->m_padded_width * context->m_padded_height;
117 
118  // Pre-allocate buffers for the transformations
119  context->m_real_buffer.resize(context->m_total_size);
120  context->m_complex_buffer.resize(context->m_total_size);
121  context->m_kernel_transform.resize(context->m_total_size);
122 
123  // Since we already have the buffers, get the plans too
124  context->m_fwd_plan = FFT<T>::createForwardPlan(1, context->m_padded_width, context->m_padded_height,
125  context->m_real_buffer, context->m_complex_buffer);
126  context->m_inv_plan = FFT<T>::createInversePlan(1, context->m_padded_width, context->m_padded_height,
127  context->m_complex_buffer, context->m_real_buffer);
128 
129  // Transform here the kernel into frequency space
130  padKernel(context->m_padded_width, context->m_padded_height, context->m_real_buffer.begin());
131  FFT<T>::executeForward(context->m_fwd_plan, context->m_real_buffer, context->m_complex_buffer);
132  std::copy(std::begin(context->m_complex_buffer), std::end(context->m_complex_buffer),
133  std::begin(context->m_kernel_transform));
134 
135  return context;
136  }
137 
150  template <typename ...Args>
153  Args... padding_args) const {
154  assert(image_ptr->getWidth() <= context->m_padded_width);
155  assert(image_ptr->getHeight() <= context->m_padded_height);
156 
157  // Padded image
158  auto padded = TPadding::create(image_ptr,
159  context->m_padded_width, context->m_padded_height,
160  std::forward<Args>(padding_args)...);
161 
162  // Create a matrix with the padded image
163  dumpImage(padded, context->m_real_buffer.begin());
164 
165  // Transform the image
166  FFT<T>::executeForward(context->m_fwd_plan, context->m_real_buffer, context->m_complex_buffer);
167 
168  // Multiply the two DFT
169  for (int i = 0; i < context->m_total_size; ++i) {
170  //context->m_complex_buffer[i] *= context->m_kernel_transform[i];
171 
172  const auto& a = context->m_complex_buffer[i];
173  const auto& b = context->m_kernel_transform[i];
174  float re = a.real() * b.real() - a.imag() * b.imag();
175  float im = a.real() * b.imag() + a.imag() * b.real();
176 
177  context->m_complex_buffer[i] = std::complex<float>(re, im);
178  }
179 
180  // Inverse DFT
181  FFT<T>::executeInverse(context->m_inv_plan, context->m_complex_buffer, context->m_real_buffer);
182 
183  // Copy to the output, removing the pad
184  auto wpad = (context->m_padded_width - image_ptr->getWidth()) / 2;
185  auto hpad = (context->m_padded_height - image_ptr->getHeight()) / 2;
186  for (int y = 0; y < image_ptr->getHeight(); ++y) {
187  for (int x = 0; x < image_ptr->getWidth(); ++x) {
188  image_ptr->setValue(x, y,
189  context->m_real_buffer[x + wpad + (y + hpad) * context->m_padded_width] / context->m_total_size);
190  }
191  }
192  }
193 
206  template <typename ...Args>
207  void convolve(std::shared_ptr<WriteableImage<T>> image_ptr, Args... padding_args) const {
208  auto context = prepare(image_ptr);
209  convolve(image_ptr, context, std::forward(padding_args)...);
210  }
211 
217  return m_kernel;
218  }
219 
220 protected:
221  void padKernel(int width, int height, typename std::vector<T>::iterator out) const {
222  auto padded = PaddedImage<T>::create(m_kernel, width, height);
223  auto center = PixelCoordinate{width / 2, height / 2};
224  if (width % 2 == 0) center.m_x--;
225  if (height % 2 == 0) center.m_y--;
226  auto recenter = RecenterImage<T>::create(padded, center);
227 
228  dumpImage(recenter, out);
229  }
230 
231  void dumpImage(const std::shared_ptr<const Image<T>> &img, typename std::vector<T>::iterator out) const {
232  auto chunk = img->getChunk(0, 0, img->getWidth(), img->getHeight());
233  for (int y = 0; y < chunk->getHeight(); ++y) {
234  for (int x = 0; x < chunk->getWidth(); ++x) {
235  *out++ = chunk->getValue(x, y);
236  }
237  }
238  }
239 
240 private:
242 };
243 
244 } // end SourceXtractor
245 
246 #endif // _SEFRAMEWORK_CONVOLUTION_DFT_H
T copy(T...args)
void convolve(std::shared_ptr< WriteableImage< T >> image_ptr, std::unique_ptr< ConvolutionContext > &context, Args...padding_args) const
Definition: DFT.h:151
static plan_ptr_t createForwardPlan(int howmany, int width, int height, std::vector< T > &in, std::vector< complex_t > &out)
Definition: FFT.cpp:155
void convolve(std::shared_ptr< WriteableImage< T >> image_ptr, Args...padding_args) const
Definition: DFT.h:207
std::vector< complex_t > m_complex_buffer
Definition: DFT.h:59
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
static std::shared_ptr< RecenterImage< T > > create(Args &&...args)
Definition: RecenterImage.h:44
static std::shared_ptr< PaddedImage< T, CoordinateInterpolation > > create(Args &&...args)
Definition: PaddedImage.h:89
static plan_ptr_t createInversePlan(int howmany, int width, int height, std::vector< complex_t > &in, std::vector< T > &out)
Definition: FFT.cpp:199
std::unique_ptr< ConvolutionContext > prepare(const std::shared_ptr< Image< T >> model_ptr) const
Definition: DFT.h:104
STL class.
T end(T...args)
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
void padKernel(int width, int height, typename std::vector< T >::iterator out) const
Definition: DFT.h:221
DFTConvolution(std::shared_ptr< const Image< T >> img)
Definition: DFT.h:71
static void executeInverse(plan_ptr_t &plan, std::vector< complex_t > &in, std::vector< T > &out)
Definition: FFT.cpp:250
void dumpImage(const std::shared_ptr< const Image< T >> &img, typename std::vector< T >::iterator out) const
Definition: DFT.h:231
A pixel coordinate made of two integers m_x and m_y.
STL class.
int fftRoundDimension(int size)
Definition: FFT.cpp:49
std::size_t getHeight() const
Definition: DFT.h:92
STL class.
T begin(T...args)
std::shared_ptr< const Image< T > > m_kernel
Definition: DFT.h:241
virtual ~DFTConvolution()=default
Interface representing an image.
Definition: Image.h:43
std::vector< complex_t > m_kernel_transform
Definition: DFT.h:59
FFT< T >::complex_t complex_t
Definition: DFT.h:49
static void executeForward(plan_ptr_t &plan, std::vector< T > &in, std::vector< complex_t > &out)
Definition: FFT.cpp:244
std::shared_ptr< const Image< T > > getKernel() const
Definition: DFT.h:216
T forward(T...args)
std::size_t getWidth() const
Definition: DFT.h:84