Skip to content

Commit 7d19509

Browse files
committed
ARROW-13540: [C++] Refactor and clean up sink node
1 parent 554f094 commit 7d19509

File tree

5 files changed

+41
-41
lines changed

5 files changed

+41
-41
lines changed

cpp/src/arrow/compute/exec/options.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
115115
/// \brief Make a node which sorts rows passed through it
116116
///
117117
/// All batches pushed to this node will be accumulated, then sorted, by the given
118-
/// fields. Then sorted batches will be pushed to the next node, along a tag
119-
/// indicating the absolute order of the batches.
118+
/// fields. Then sorted batches will be forwarded to the generator in sorted order.
120119
class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
121120
public:
122121
explicit OrderBySinkNodeOptions(

cpp/src/arrow/compute/exec/plan_test.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "arrow/compute/exec/expression.h"
2626
#include "arrow/compute/exec/options.h"
2727
#include "arrow/compute/exec/test_util.h"
28+
#include "arrow/compute/exec/util.h"
2829
#include "arrow/record_batch.h"
2930
#include "arrow/table.h"
3031
#include "arrow/testing/future_util.h"
@@ -37,6 +38,7 @@
3738
#include "arrow/util/vector.h"
3839

3940
using testing::ElementsAre;
41+
using testing::ElementsAreArray;
4042
using testing::HasSubstr;
4143
using testing::Optional;
4244
using testing::UnorderedElementsAreArray;
@@ -328,7 +330,7 @@ TEST(ExecPlanExecution, SourceOrderBy) {
328330
.AddToPlan(plan.get()));
329331

330332
ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
331-
Finishes(ResultWith(::testing::ElementsAreArray(expected))));
333+
Finishes(ResultWith(ElementsAreArray(expected))));
332334
}
333335
}
334336
}
@@ -414,18 +416,9 @@ TEST(ExecPlanExecution, StressSourceOrderBy) {
414416
// Check that data is sorted appropriately
415417
ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches,
416418
StartAndCollect(plan.get(), sink_gen));
417-
RecordBatchVector batches, original_batches;
418-
for (const auto& batch : exec_batches) {
419-
ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
420-
batches.push_back(std::move(rb));
421-
}
422-
for (const auto& batch : random_data.batches) {
423-
ASSERT_OK_AND_ASSIGN(auto rb, batch.ToRecordBatch(input_schema));
424-
original_batches.push_back(std::move(rb));
425-
}
426-
ASSERT_OK_AND_ASSIGN(auto actual, Table::FromRecordBatches(input_schema, batches));
419+
ASSERT_OK_AND_ASSIGN(auto actual, TableFromExecBatches(input_schema, exec_batches));
427420
ASSERT_OK_AND_ASSIGN(auto original,
428-
Table::FromRecordBatches(input_schema, original_batches));
421+
TableFromExecBatches(input_schema, random_data.batches));
429422
ASSERT_OK_AND_ASSIGN(auto sort_indices, SortIndices(original, options));
430423
ASSERT_OK_AND_ASSIGN(auto expected, Take(original, sort_indices));
431424
AssertTablesEqual(*actual, *expected.table());

cpp/src/arrow/compute/exec/sink_node.cc

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -187,41 +187,34 @@ struct OrderBySinkNode final : public SinkNode {
187187
}
188188

189189
protected:
190-
Result<std::shared_ptr<Table>> SortData() {
191-
std::unique_lock<std::mutex> lock(mutex_);
192-
ARROW_ASSIGN_OR_RAISE(
193-
auto table,
194-
Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
195-
ARROW_ASSIGN_OR_RAISE(auto indices,
196-
SortIndices(table, sort_options_, plan()->exec_context()));
197-
ARROW_ASSIGN_OR_RAISE(auto sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
198-
plan()->exec_context()));
199-
return sorted.table();
200-
}
201-
202-
void Finish() override {
203-
auto maybe_sorted = SortData();
204-
if (ErrorIfNotOk(maybe_sorted.status())) {
205-
producer_.Push(maybe_sorted.status());
206-
SinkNode::Finish();
207-
return;
190+
Status DoFinish() {
191+
Datum sorted;
192+
{
193+
std::unique_lock<std::mutex> lock(mutex_);
194+
ARROW_ASSIGN_OR_RAISE(
195+
auto table,
196+
Table::FromRecordBatches(inputs_[0]->output_schema(), std::move(batches_)));
197+
ARROW_ASSIGN_OR_RAISE(auto indices,
198+
SortIndices(table, sort_options_, plan()->exec_context()));
199+
ARROW_ASSIGN_OR_RAISE(sorted, Take(table, indices, TakeOptions::NoBoundsCheck(),
200+
plan()->exec_context()));
208201
}
209-
auto sorted = maybe_sorted.MoveValueUnsafe();
210-
211-
TableBatchReader reader(*sorted);
202+
TableBatchReader reader(*sorted.table());
212203
while (true) {
213204
std::shared_ptr<RecordBatch> batch;
214-
auto status = reader.ReadNext(&batch);
215-
if (!status.ok()) {
216-
producer_.Push(std::move(status));
217-
SinkNode::Finish();
218-
return;
219-
}
205+
RETURN_NOT_OK(reader.ReadNext(&batch));
220206
if (!batch) break;
221207
bool did_push = producer_.Push(ExecBatch(*batch));
222208
if (!did_push) break; // producer_ was Closed already
223209
}
210+
return Status::OK();
211+
}
224212

213+
void Finish() override {
214+
Status st = DoFinish();
215+
if (ErrorIfNotOk(st)) {
216+
producer_.Push(std::move(st));
217+
}
225218
SinkNode::Finish();
226219
}
227220

cpp/src/arrow/compute/exec/util.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "arrow/compute/exec/util.h"
1919

2020
#include "arrow/compute/exec/exec_plan.h"
21+
#include "arrow/table.h"
2122
#include "arrow/util/bit_util.h"
2223
#include "arrow/util/bitmap_ops.h"
2324
#include "arrow/util/ubsan.h"
@@ -296,5 +297,15 @@ Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inpu
296297
return Status::OK();
297298
}
298299

300+
Result<std::shared_ptr<Table>> TableFromExecBatches(
301+
const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches) {
302+
RecordBatchVector batches;
303+
for (const auto& batch : exec_batches) {
304+
ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema));
305+
batches.push_back(std::move(rb));
306+
}
307+
return Table::FromRecordBatches(schema, batches);
308+
}
309+
299310
} // namespace compute
300311
} // namespace arrow

cpp/src/arrow/compute/exec/util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ ARROW_EXPORT
188188
Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
189189
int expected_num_inputs, const char* kind_name);
190190

191+
ARROW_EXPORT
192+
Result<std::shared_ptr<Table>> TableFromExecBatches(
193+
const std::shared_ptr<Schema>& schema, const std::vector<ExecBatch>& exec_batches);
194+
191195
class AtomicCounter {
192196
public:
193197
AtomicCounter() = default;

0 commit comments

Comments
 (0)