/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkKdTreeGenerator_hxx
#define itkKdTreeGenerator_hxx


namespace itk
{
namespace Statistics
{
/** Default constructor: no source sample, a fresh subsample object, a
 *  measurement vector size of 0 and a bucket size of 16. */
template <typename TSample>
KdTreeGenerator<TSample>::KdTreeGenerator()
{
  // The subsample object is created here because SetSample() calls into it.
  this->m_Subsample = SubsampleType::New();

  // No input yet: GenerateData() returns early while the source sample is null.
  this->m_SourceSample = nullptr;
  this->m_MeasurementVectorSize = 0;

  // Default bucket size.
  this->m_BucketSize = 16;
}

/** Print the superclass state followed by the source sample (or "not set."),
 *  the bucket size and the measurement vector size. */
template <typename TSample>
void
KdTreeGenerator<TSample>::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  os << indent << "Source Sample: ";
  if (m_SourceSample == nullptr)
  {
    os << "not set." << std::endl;
  }
  else
  {
    os << m_SourceSample << std::endl;
  }

  os << indent << "Bucket Size: " << m_BucketSize << std::endl;
  os << indent << "MeasurementVectorSize: " << m_MeasurementVectorSize << std::endl;
}

/** Set the source sample from which the k-d tree will be generated.
 *
 *  The sample is stored, the internal subsample is pointed at it and
 *  initialized with all of its instances, and the scratch bound/mean vectors
 *  are resized to the sample's measurement vector size.
 *
 *  \param sample  The source sample; must not be null.
 *  \throws ExceptionObject if \c sample is null. */
template <typename TSample>
void
KdTreeGenerator<TSample>::SetSample(TSample * sample)
{
  // The sample is dereferenced below, so reject a null pointer up front
  // (before any member is modified) instead of crashing.
  if (sample == nullptr)
  {
    itkExceptionMacro(<< "The input sample must not be null");
  }

  m_SourceSample = sample;
  m_Subsample->SetSample(sample);
  m_Subsample->InitializeWithAllInstances();
  m_MeasurementVectorSize = sample->GetMeasurementVectorSize();

  // Size the scratch vectors used while searching for the partition dimension.
  NumericTraits<MeasurementVectorType>::SetLength(m_TempLowerBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempUpperBound, m_MeasurementVectorSize);
  NumericTraits<MeasurementVectorType>::SetLength(m_TempMean, m_MeasurementVectorSize);
}

/** Set the bucket size: GenerateTreeLoop() turns any index range holding at
 *  most this many instances into a terminal node. */
template <typename TSample>
void
KdTreeGenerator<TSample>::SetBucketSize(unsigned int size)
{
  this->m_BucketSize = size;
}

/** Build the k-d tree from the current source sample.
 *
 *  Does nothing when no source sample has been set. Otherwise the output tree
 *  is created on first use, bound to the current sample and bucket size, and
 *  its root is replaced by a freshly generated node hierarchy covering every
 *  instance of the subsample.
 *
 *  \throws ExceptionObject if the generator's measurement vector size differs
 *          from the subsample's. */
template <typename TSample>
void
KdTreeGenerator<TSample>::GenerateData()
{
  // Nothing to build until a source sample has been provided.
  if (m_SourceSample == nullptr)
  {
    return;
  }

  if (m_Tree.IsNull())
  {
    m_Tree = KdTreeType::New();
  }

  // Synchronise the tree with the generator on every run, not only when the
  // tree is first created: otherwise a tree reused after SetSample() or
  // SetBucketSize() keeps the previous sample / bucket size while its nodes
  // are rebuilt from the new subsample.
  m_Tree->SetSample(m_SourceSample);
  m_Tree->SetBucketSize(m_BucketSize);

  SubsamplePointer subsample = this->GetSubsample();

  // Sanity check. Verify that the subsample has measurement vectors of the
  // same length as the sample generated by the tree.
  if (this->GetMeasurementVectorSize() != subsample->GetMeasurementVectorSize())
  {
    itkExceptionMacro(<< "Measurement Vector Length mismatch");
  }

  // Start from the widest possible bounds in every dimension.
  MeasurementVectorType lowerBound;
  NumericTraits<MeasurementVectorType>::SetLength(lowerBound, m_MeasurementVectorSize);
  MeasurementVectorType upperBound;
  NumericTraits<MeasurementVectorType>::SetLength(upperBound, m_MeasurementVectorSize);

  for (unsigned int d = 0; d < m_MeasurementVectorSize; ++d)
  {
    lowerBound[d] = NumericTraits<MeasurementType>::NonpositiveMin();
    upperBound[d] = NumericTraits<MeasurementType>::max();
  }

  // Generate the nodes over the index range [0, Size()) and install the root.
  KdTreeNodeType * root = this->GenerateTreeLoop(0, m_Subsample->Size(), lowerBound, upperBound, 0);
  m_Tree->SetRoot(root);
}

/** Create a nonterminal node for the subsample index range
 *  [beginIndex, endIndex).
 *
 *  The dimension with the largest spread (upper bound minus lower bound of the
 *  range) becomes the partition dimension, and the element selected by
 *  NthElement at the middle of the range provides the partition value. The
 *  indices before the median go to the left child, those after it to the right
 *  child, and the median instance itself is stored in the new node.
 *
 *  \param beginIndex  First subsample index of the range.
 *  \param endIndex    One past the last subsample index of the range.
 *  \param lowerBound  Current lower bounds; modified during the recursion and
 *                     restored before returning.
 *  \param upperBound  Current upper bounds; modified during the recursion and
 *                     restored before returning.
 *  \param level       Depth counter forwarded (plus one) to the children.
 *  \return The newly allocated nonterminal node. */
template <typename TSample>
inline typename KdTreeGenerator<TSample>::KdTreeNodeType *
KdTreeGenerator<TSample>::GenerateNonterminalNode(unsigned int            beginIndex,
                                                  unsigned int            endIndex,
                                                  MeasurementVectorType & lowerBound,
                                                  MeasurementVectorType & upperBound,
                                                  unsigned int            level)
{
  using NodeType = typename KdTreeType::KdTreeNodeType;
  using KdTreeNonterminalNodeType = KdTreeNonterminalNode<TSample>;

  SubsamplePointer subsample = this->GetSubsample();

  // Compute per-dimension bounds (and mean) of the range into the scratch
  // vectors, then pick the dimension with the widest spread. Ties go to the
  // highest dimension index because the comparison is ">=".
  Algorithm::FindSampleBoundAndMean<SubsampleType>(
    subsample, beginIndex, endIndex, m_TempLowerBound, m_TempUpperBound, m_TempMean);

  unsigned int    partitionDimension = 0;
  MeasurementType widestSpread = NumericTraits<MeasurementType>::NonpositiveMin();
  for (unsigned int dim = 0; dim < m_MeasurementVectorSize; ++dim)
  {
    const MeasurementType dimSpread = m_TempUpperBound[dim] - m_TempLowerBound[dim];
    if (dimSpread >= widestSpread)
    {
      widestSpread = dimSpread;
      partitionDimension = dim;
    }
  }

  // Find the median element with the NthElement function, which is based on
  // the STL implementation of the QuickSelect algorithm. It takes the median
  // position as an offset relative to beginIndex.
  const unsigned int    medianOffset = (endIndex - beginIndex) / 2;
  const MeasurementType partitionValue =
    Algorithm::NthElement<SubsampleType>(m_Subsample, partitionDimension, beginIndex, endIndex, medianOffset);
  const unsigned int medianIndex = medianOffset + beginIndex;

  // Remember the bounds of the cutting dimension so they can be restored.
  const MeasurementType savedLowerBound = lowerBound[partitionDimension];
  const MeasurementType savedUpperBound = upperBound[partitionDimension];

  // Left child: [beginIndex, medianIndex), upper bound clipped to the cut.
  upperBound[partitionDimension] = partitionValue;
  NodeType * leftChild = this->GenerateTreeLoop(beginIndex, medianIndex, lowerBound, upperBound, level + 1);
  upperBound[partitionDimension] = savedUpperBound;

  // Right child: [medianIndex + 1, endIndex), lower bound clipped to the cut.
  lowerBound[partitionDimension] = partitionValue;
  NodeType * rightChild = this->GenerateTreeLoop(medianIndex + 1, endIndex, lowerBound, upperBound, level + 1);
  lowerBound[partitionDimension] = savedLowerBound;

  // The median instance is held by the nonterminal node itself.
  auto * node = new KdTreeNonterminalNodeType(partitionDimension, partitionValue, leftChild, rightChild);
  node->AddInstanceIdentifier(subsample->GetInstanceIdentifier(medianIndex));

  return node;
}

/** Generate the (sub)tree for the subsample index range
 *  [beginIndex, endIndex).
 *
 *  A range larger than the bucket size is delegated to
 *  GenerateNonterminalNode(). An empty range yields the tree's shared empty
 *  terminal node, and any other range yields a new terminal node holding the
 *  instance identifiers of the range.
 *
 *  \param beginIndex  First subsample index of the range.
 *  \param endIndex    One past the last subsample index of the range.
 *  \param lowerBound  Current lower bounds, forwarded to nonterminal nodes.
 *  \param upperBound  Current upper bounds, forwarded to nonterminal nodes.
 *  \param level       Depth counter, forwarded (plus one) to nonterminal nodes.
 *  \return Pointer to the node that roots the generated (sub)tree. */
template <typename TSample>
inline typename KdTreeGenerator<TSample>::KdTreeNodeType *
KdTreeGenerator<TSample>::GenerateTreeLoop(unsigned int            beginIndex,
                                           unsigned int            endIndex,
                                           MeasurementVectorType & lowerBound,
                                           MeasurementVectorType & upperBound,
                                           unsigned int            level)
{
  // Too many instances for a single bucket: split further.
  if (endIndex - beginIndex > m_BucketSize)
  {
    // NOTE(review): level is incremented here and again when
    // GenerateNonterminalNode() recurses into this function, so it appears to
    // grow by two per tree level -- confirm whether that is intended.
    return this->GenerateNonterminalNode(beginIndex, endIndex, lowerBound, upperBound, level + 1);
  }

  // Empty range: hand back the pointer to the empty terminal node.
  if (beginIndex == endIndex)
  {
    return m_Tree->GetEmptyTerminalNode();
  }

  // Small, non-empty range: collect its instances in a new terminal node.
  auto * terminalNode = new KdTreeTerminalNode<TSample>();
  for (unsigned int index = beginIndex; index < endIndex; ++index)
  {
    terminalNode->AddInstanceIdentifier(this->GetSubsample()->GetInstanceIdentifier(index));
  }
  return terminalNode;
}
} // end of namespace Statistics
} // end of namespace itk

#endif
