Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:12

0001 /**
0002  * Copyright (c) 2018-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
0016 
0017 namespace gloo {
0018 
0019 class ScatterOptions {
0020  public:
0021   explicit ScatterOptions(const std::shared_ptr<Context>& context)
0022       : context(context), timeout(context->getTimeout()) {}
0023 
0024   template <typename T>
0025   void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
0026     this->elementSize = sizeof(T);
0027     this->in = std::move(bufs);
0028   }
0029 
0030   template <typename T>
0031   void setInputs(std::vector<T*> ptrs, size_t elements) {
0032     setInputs(ptrs.data(), ptrs.size(), elements);
0033   }
0034 
0035   template <typename T>
0036   void setInputs(T** ptrs, size_t len, size_t elements) {
0037     this->elementSize = sizeof(T);
0038     this->in.reserve(len);
0039     for (size_t i = 0; i < len; i++) {
0040       this->in.push_back(
0041           context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
0042     }
0043   }
0044 
0045   template <typename T>
0046   void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
0047     this->elementSize = sizeof(T);
0048     this->out = std::move(buf);
0049   }
0050 
0051   template <typename T>
0052   void setOutput(T* ptr, size_t elements) {
0053     this->elementSize = sizeof(T);
0054     this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
0055   }
0056 
0057   void setRoot(int root) {
0058     this->root = root;
0059   }
0060 
0061   void setTag(uint32_t tag) {
0062     this->tag = tag;
0063   }
0064 
0065   void setTimeout(std::chrono::milliseconds timeout) {
0066     this->timeout = timeout;
0067   }
0068 
0069  protected:
0070   std::shared_ptr<Context> context;
0071 
0072   // Scatter has N input buffers where each one in its
0073   // entirety gets sent to a rank. The input(s) only need to
0074   // be set on the root process.
0075   std::vector<std::unique_ptr<transport::UnboundBuffer>> in;
0076 
0077   // Scatter only has a single output buffer per rank.
0078   std::unique_ptr<transport::UnboundBuffer> out;
0079 
0080   // Number of bytes per element.
0081   size_t elementSize = 0;
0082 
0083   // Rank of process to scatter from.
0084   int root = -1;
0085 
0086   // Tag for this operation.
0087   // Must be unique across operations executing in parallel.
0088   uint32_t tag = 0;
0089 
0090   // End-to-end timeout for this operation.
0091   std::chrono::milliseconds timeout;
0092 
0093   friend void scatter(ScatterOptions&);
0094 };
0095 
0096 void scatter(ScatterOptions& opts);
0097 
0098 } // namespace gloo