Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:11

0001 /**
0002  * Copyright (c) 2018-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "gloo/common/logging.h"
#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
0014 
0015 namespace gloo {
0016 
0017 class AlltoallOptions {
0018  public:
0019   explicit AlltoallOptions(const std::shared_ptr<Context>& context)
0020       : context(context), timeout(context->getTimeout()) {}
0021 
0022   template <typename T>
0023   void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
0024     elementSize = sizeof(T);
0025     in = std::move(buf);
0026   }
0027 
0028   template <typename T>
0029   void setInput(T* ptr, size_t elements) {
0030     elementSize = sizeof(T);
0031     in = context->createUnboundBuffer(ptr, elements * sizeof(T));
0032   }
0033 
0034   template <typename T>
0035   void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
0036     elementSize = sizeof(T);
0037     out = std::move(buf);
0038   }
0039 
0040   template <typename T>
0041   void setOutput(T* ptr, size_t elements) {
0042     elementSize = sizeof(T);
0043     out = context->createUnboundBuffer(ptr, elements * sizeof(T));
0044   }
0045 
0046   void setTag(uint32_t tag) {
0047     this->tag = tag;
0048   }
0049 
0050   void setTimeout(std::chrono::milliseconds timeout) {
0051     GLOO_ENFORCE(timeout.count() > 0);
0052     this->timeout = timeout;
0053   }
0054 
0055  protected:
0056   std::shared_ptr<Context> context;
0057   std::unique_ptr<transport::UnboundBuffer> in;
0058   std::unique_ptr<transport::UnboundBuffer> out;
0059 
0060   // Number of bytes per element.
0061   size_t elementSize = 0;
0062 
0063   // Tag for this operation.
0064   // Must be unique across operations executing in parallel.
0065   uint32_t tag = 0;
0066 
0067   // End-to-end timeout for this operation.
0068   std::chrono::milliseconds timeout;
0069 
0070   friend void alltoall(AlltoallOptions&);
0071 };
0072 
0073 void alltoall(AlltoallOptions& opts);
0074 
0075 } // namespace gloo