Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clang-format #680

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/models/debugging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void DumpValues(std::ostream& stream, ONNXTensorElementDataType type, const void

switch (type) {
case Ort::TypeToTensorType<bool>::type:
DumpSpan(stream, std::span<const bool>{reinterpret_cast<const bool*>(p_values_raw), count});
DumpSpan(stream, std::span<const bool>(reinterpret_cast<const bool*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<int8_t>::type:
Expand Down
80 changes: 65 additions & 15 deletions src/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,90 @@
#ifndef USE_CXX17
#include <span>
#else

#include <array>
#include <stdexcept>
#include <vector>
namespace std {

namespace std {
namespace generators_span {
template <typename T>
struct span {
span() = default;
span(T* p, size_t length) : p_{p}, length_{length} {}
constexpr span(T* p, size_t length) noexcept : p_{p}, length_{length} {}
constexpr span(const span<T>& s) noexcept : p_{s.p_}, length_{s.length_} {}

span(const span<std::remove_const_t<T> >& s) : p_{const_cast<T*>(s.data())}, length_{s.size()} {}
span(const std::vector<std::remove_const_t<T> >& s) : p_{const_cast<T*>(s.data())}, length_{s.size()} {}
span(std::vector<T>& s) noexcept : p_{s.data()}, length_{s.size()} {}
template <size_t N>
span(const std::array<std::remove_const_t<T>, N>& s) : p_{const_cast<T*>(s.data())}, length_{s.size()} {}
constexpr span(std::array<T, N>& s) noexcept : p_{s.data()}, length_{s.size()} {}

bool empty() const { return length_ == 0; }
constexpr bool empty() const noexcept { return length_ == 0; }

size_t size() const { return length_; }
size_t size_bytes() const { return length_ * sizeof(T); }
constexpr size_t size() const noexcept { return length_; }
constexpr size_t size_bytes() const noexcept { return length_ * sizeof(T); }

T* data() { return p_; }
const T* data() const { return p_; }
constexpr T* data() const noexcept { return p_; }

T& operator[](size_t index) const { return p_[index]; }
constexpr T& operator[](size_t index) const noexcept { return p_[index]; }

T& back() const { return p_[length_ - 1]; }
constexpr T& back() const noexcept { return p_[length_ - 1]; }

T* begin() const { return p_; }
T* end() const { return p_ + length_; }
constexpr T* begin() const noexcept { return p_; }
constexpr T* end() const noexcept { return p_ + length_; }

span subspan(size_t index, size_t length) const { return span(p_ + index, length); }
// Returns a view over `length` elements starting at element `index`.
// Throws std::out_of_range when [index, index + length) does not lie inside this span.
// An empty subspan positioned at the very end (index == size(), length == 0) is accepted,
// as it is for std::span::subspan.
span subspan(size_t index, size_t length) const {
  // Compare against the remaining room instead of computing `index + length`,
  // which could wrap around size_t and let an out-of-range request through.
  if (index > length_ || length > length_ - index)
    throw std::out_of_range("Requested subspan is out of range");

  return span(p_ + index, length);
}

private:
T* p_{};
size_t length_{};
};

// Read-only specialization of span: a pointer plus an element count describing a
// contiguous range whose elements cannot be modified through this view.
// The struct stores only the pointer and the length; it never allocates or frees.
template <class T>
struct span<const T> {
  // Default-constructed view: null pointer, zero length.
  span() = default;

  // View over `length` elements starting at `p`.
  constexpr span(const T* p, size_t length) noexcept : p_{p}, length_{length} {}

  // Copy of another read-only view.
  constexpr span(const span<const T>& s) noexcept : p_{s.p_}, length_{s.length_} {}
  // Implicit conversion from a mutable view to a read-only view of the same elements.
  constexpr span(const span<T>& s) noexcept : p_{s.data()}, length_{s.size()} {}

  // Views over the full contents of a vector / fixed-size array.
  span(const std::vector<T>& s) noexcept : p_{s.data()}, length_{s.size()} {}
  template <size_t N>
  constexpr span(const std::array<T, N>& s) noexcept : p_{s.data()}, length_{s.size()} {}

  // True when the view contains no elements.
  constexpr bool empty() const noexcept { return length_ == 0; }

  // Number of elements, and the same extent measured in bytes.
  constexpr size_t size() const noexcept { return length_; }
  constexpr size_t size_bytes() const noexcept { return length_ * sizeof(T); }

  // Pointer to the first element.
  constexpr const T* data() const noexcept { return p_; }

  // Element access; `index` is not range-checked.
  constexpr const T& operator[](size_t index) const noexcept { return p_[index]; }

  // Last element; there is no check for an empty view.
  constexpr const T& back() const noexcept { return p_[length_ - 1]; }

  // Iteration support (raw pointers serve as iterators).
  constexpr const T* begin() const noexcept { return p_; }
  constexpr const T* end() const noexcept { return p_ + length_; }

  // Returns a view over `length` elements starting at element `index`.
  // Throws std::out_of_range when [index, index + length) does not lie inside this span.
  // An empty subspan positioned at the very end (index == size(), length == 0) is accepted,
  // as it is for std::span::subspan.
  span subspan(size_t index, size_t length) const {
    // Compare against the remaining room instead of computing `index + length`,
    // which could wrap around size_t and let an out-of-range request through.
    if (index > length_ || length > length_ - index)
      throw std::out_of_range("Requested subspan is out of range");

    return span(p_ + index, length);
  }

 private:
  const T* p_{};
  size_t length_{};
};
} // namespace generators_span

using generators_span::span;

} // namespace std

#endif
2 changes: 1 addition & 1 deletion test/sampling_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) {
auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
int vocab_size = 32000; // vocab size of llama
int batch_size = 12;
std::vector<int32_t> input_ids{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; // Needs to match batch_size
std::vector<int32_t> input_ids{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; // Needs to match batch_size
auto params = Generators::CreateGeneratorParams();
params->search.max_length = 10;
params->batch_size = batch_size;
Expand Down