Skip to content

Commit

Permalink
dietgpu float32 compression support: part 1 (#8)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #8

This diff contains some prerequisites for losslessly compressing float32 data. We can now accept float32 dtype data, but right now the kernels that pack the uncompressed portion have not been adjusted to allow storing 3 bytes of data per word, so instead we store the uncompressed 3 bytes as a 4 byte word.

As a result, this diff does not actually compress float32 data, but instead expands it. A subsequent diff will either allow the split/join float kernel to handle 3-byte words, or — perhaps more efficiently — allow interleaving 2 ANS compressors in the same data stream using different table data, which would allow compressing 2-byte words of data with statistics independently computed for each byte.

This diff also removes the "min symbol" usage from the compressor and decompressor. The original idea for this was to perhaps allow for optimizations when there are <= 32 or 64 symbols by storing table data in registers, but this proved impractical. Additionally, only the symbols that are present need be stored in the ANS header data, but entries for all 256 possible symbols are currently stored. We still compute the min and num symbols, which can bound the data being stored, but we aren't doing anything with them. A subsequent diff might allow for storing a truncated symbol table if all 256 possible symbols are not in fact encountered. Removing this min symbol usage provides a small speedup, and will make it easier to interleave 2 ANS compressors in the same stream by eliminating one more thing that would otherwise need to be handled.

Reviewed By: jspark1105

Differential Revision: D34157690

fbshipit-source-id: 64e2c8d6f5ab3acdb3441ed62717cb1638fb21e3
  • Loading branch information
wickedfoo authored and facebook-github-bot committed Mar 9, 2022
1 parent d089eaa commit 8f95d1c
Show file tree
Hide file tree
Showing 11 changed files with 544 additions and 380 deletions.
18 changes: 14 additions & 4 deletions dietgpu/DietGpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ FloatType getFloatTypeFromDtype(at::ScalarType t) {
return FloatType::kFloat16;
case at::ScalarType::BFloat16:
return FloatType::kBFloat16;
case at::ScalarType::Float:
return FloatType::kFloat32;
default:
TORCH_CHECK(t == at::ScalarType::Half || t == at::ScalarType::BFloat16);
TORCH_CHECK(
t == at::ScalarType::Half || t == at::ScalarType::BFloat16 ||
t == at::ScalarType::Float);
return FloatType::kUndefined;
}
}
Expand All @@ -36,8 +40,12 @@ at::ScalarType getDtypeFromFloatType(FloatType ft) {
return at::ScalarType::Half;
case FloatType::kBFloat16:
return at::ScalarType::BFloat16;
case FloatType::kFloat32:
return at::ScalarType::Float;
default:
TORCH_CHECK(ft == FloatType::kFloat16 || ft == FloatType::kBFloat16);
TORCH_CHECK(
ft == FloatType::kFloat16 || ft == FloatType::kBFloat16 ||
ft == FloatType::kFloat32);
return at::ScalarType::Half;
}
}
Expand Down Expand Up @@ -535,7 +543,8 @@ int64_t decompress_data_res(
TORCH_CHECK(tIn.dtype() == torch::kByte);
if (compressAsFloat) {
TORCH_CHECK(
tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16);
tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16 ||
tOut.dtype() == torch::kFloat32);
}

inPtrs[i] = tIn.data_ptr();
Expand Down Expand Up @@ -692,7 +701,8 @@ int64_t decompress_data_split_size(
TORCH_CHECK(tOut.is_contiguous());
if (compressAsFloat) {
TORCH_CHECK(
tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16);
tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16 ||
tOut.dtype() == torch::kFloat32);
}

auto outSize =
Expand Down
10 changes: 5 additions & 5 deletions dietgpu/ans/GpuANSDecode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ __global__ void ansDecodeTable(
auto probs = headerIn->getSymbolProbs();

static_assert(Threads >= kNumSymbols, "");
uint32_t pdf = tid < numSymbols ? probs[tid] : 0;
uint32_t pdf = tid < kNumSymbols ? probs[tid] : 0;
uint32_t cdf = 0;

// Get the CDF from the PDF
Expand All @@ -418,7 +418,7 @@ __global__ void ansDecodeTable(
// Broadcast the pdf/cdf values
__shared__ uint2 smemPdfCdf[kNumSymbols];

if (tid < numSymbols) {
if (tid < kNumSymbols) {
smemPdfCdf[tid] = uint2{pdf, cdf};
}

Expand All @@ -427,16 +427,16 @@ __global__ void ansDecodeTable(
// Build the table for each pdf/cdf bucket
constexpr int kWarpsPerBlock = Threads / kWarpSize;

for (int i = warpId; i < numSymbols; i += kWarpsPerBlock) {
for (int i = warpId; i < kNumSymbols; i += kWarpsPerBlock) {
auto v = smemPdfCdf[i];

auto pdf = v.x;
auto begin = v.y;
auto end = (i + 1) < numSymbols ? (begin + pdf) : totalProb;
auto end = begin + pdf;

for (int j = begin + laneId; j < end; j += kWarpSize) {
table[j] = packDecodeLookup(
i + symbolOffset, // symbol
i, // symbol
pdf, // bucket pdf
j - begin); // within-bucket cdf
}
Expand Down
68 changes: 8 additions & 60 deletions dietgpu/ans/GpuANSEncode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ __device__ uint32_t ansEncodeWarpBlock(
const ANSDecodedT* __restrict__ in,
// Number of ANSDecodedT words in this block
uint32_t inWords,
// Symbol offset
ANSDecodedT minSymbol,
// encoded table in smem
const uint4* __restrict__ table,
// Output for this block
Expand All @@ -161,7 +159,7 @@ __device__ uint32_t ansEncodeWarpBlock(

constexpr int kUnroll = 8;

// Unrolled iterationsxs
// Unrolled iterations
uint32_t limit = roundDown(inWords, kWarpSize * kUnroll);
{
ANSDecodedT sym[kUnroll];
Expand All @@ -172,11 +170,6 @@ __device__ uint32_t ansEncodeWarpBlock(
sym[j] = in[inOffset + j * kWarpSize];
}

#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
sym[j] -= minSymbol;
}

#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
outOffset +=
Expand All @@ -191,7 +184,7 @@ __device__ uint32_t ansEncodeWarpBlock(

// Whole warp iterations
for (; inOffset < limit; inOffset += kWarpSize) {
ANSDecodedT sym = in[inOffset] - minSymbol;
ANSDecodedT sym = in[inOffset];

outOffset +=
encodeOneWarp<ProbBits>(state, sym, outOffset, outWords, table);
Expand All @@ -201,7 +194,7 @@ __device__ uint32_t ansEncodeWarpBlock(
if (limit != inWords) {
// Last iteration may not be a full warp
bool valid = inOffset < inWords;
ANSDecodedT sym = valid ? in[inOffset] - minSymbol : ANSDecodedT(0);
ANSDecodedT sym = valid ? in[inOffset] : ANSDecodedT(0);

outOffset += encodeOnePartialWarp<ProbBits>(
valid, state, sym, outOffset, outWords, table);
Expand All @@ -224,8 +217,6 @@ __device__ uint32_t ansEncodeWarpFullBlock(
uint32_t laneId,
// Input for this block
const ANSDecodedT* __restrict__ in,
// Symbol offset
ANSDecodedT minSymbol,
// encoded table in smem
const uint4* __restrict__ table,
// Output for this block
Expand All @@ -238,14 +229,6 @@ __device__ uint32_t ansEncodeWarpFullBlock(

uint32_t outOffset = 0;

uint32_t minSymbolV = minSymbol;
minSymbolV <<= 8;
minSymbolV |= minSymbol;
minSymbolV <<= 8;
minSymbolV |= minSymbol;
minSymbolV <<= 8;
minSymbolV |= minSymbol;

using VecT = uint32_t;

auto inV = (const VecT*)in;
Expand All @@ -268,11 +251,6 @@ __device__ uint32_t ansEncodeWarpFullBlock(
symV[j] = inV[j * kWarpSize];
}

#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
symV[j] -= minSymbolV;
}

#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
asm volatile("prefetch.global.L2 [%0];" : : "l"(inV + 128 + j * 32));
Expand Down Expand Up @@ -307,10 +285,6 @@ __device__ void ansEncodeBlocksFull(
uint8_t* __restrict__ out,
// output array of per-block sizes of number of ANSEncodedT words per block
uint32_t* __restrict__ compressedWords,
// for the range of ANSDecodedT, the minimum symbol and number of symbols
// that we actually encounter
uint32_t minSymbol,
uint32_t numSymbols,
// the encoding table that we will load into smem
const uint4* __restrict__ table) {
// grid-wide warp id
Expand All @@ -323,7 +297,7 @@ __device__ void ansEncodeBlocksFull(
__shared__ uint4 smemLookup[kNumSymbols];

// we always have at least 256 threads
if (tid < numSymbols) {
if (tid < kNumSymbols) {
smemLookup[tid] = table[tid];
}

Expand All @@ -348,7 +322,7 @@ __device__ void ansEncodeBlocksFull(
assert(isPointerAligned(inBlock, kANSRequiredAlignment));

auto outWords = ansEncodeWarpFullBlock<ProbBits, BlockSize>(
laneId, inBlock, minSymbol, smemLookup, outBlock);
laneId, inBlock, smemLookup, outBlock);

if (laneId == 0) {
// If the bound on max compressed size is not correct, this assert will go
Expand All @@ -374,10 +348,6 @@ __device__ void ansEncodeBlocksPartial(
uint8_t* __restrict__ out,
// output array of per-block sizes of number of ANSEncodedT words per block
uint32_t* __restrict__ compressedWords,
// for the range of ANSDecodedT, the minimum symbol and number of symbols
// that we actually encounter
uint32_t minSymbol,
uint32_t numSymbols,
// the encoding table that we will load into smem
const uint4* __restrict__ table) {
int block = numBlocks - 1;
Expand All @@ -387,7 +357,7 @@ __device__ void ansEncodeBlocksPartial(
__shared__ uint4 smemLookup[kNumSymbols];

// we always have at least 256 threads
if (tid < numSymbols) {
if (tid < kNumSymbols) {
smemLookup[tid] = table[tid];
}

Expand Down Expand Up @@ -417,7 +387,7 @@ __device__ void ansEncodeBlocksPartial(
assert(isPointerAligned(inBlock, kANSRequiredAlignment));

auto outWords = ansEncodeWarpBlock<ProbBits>(
laneId, inBlock, blockSize, minSymbol, smemLookup, outBlock);
laneId, inBlock, blockSize, smemLookup, outBlock);

if (laneId == 0) {
// If the bound on max compressed size is not correct, this assert will go
Expand All @@ -443,9 +413,6 @@ __global__ void ansEncodeBatchFull(
// per batch
// [batch][numBlocks]
uint32_t* __restrict__ compressedWords,
// for the range of ANSDecodedT, the minimum symbol and number of symbols
// that we actually encounter [batch]
const uint2* __restrict__ minAndNumSymbols,
// the encoding table that we will load into smem
// [batch][kNumSymbols]
const uint4* __restrict__ table) {
Expand All @@ -456,19 +423,13 @@ __global__ void ansEncodeBatchFull(
uint32_t curSize = inProvider.getBatchSize(batch);
uint32_t numBlocks = divUp(curSize, BlockSize);

auto minNumSym = minAndNumSymbols[batch];
uint32_t minSymbol = minNumSym.x;
uint32_t numSymbols = minNumSym.y;

ansEncodeBlocksFull<ProbBits, BlockSize>(
(const ANSDecodedT*)inProvider.getBatchStart(batch),
curSize,
numBlocks,
maxCompressedBlockSize,
out + batch * maxNumCompressedBlocks * maxCompressedBlockSize,
compressedWords + batch * maxNumCompressedBlocks,
minSymbol,
numSymbols,
table + batch * kNumSymbols);
}

Expand All @@ -486,9 +447,6 @@ __global__ void ansEncodeBatchPartial(
// per batch
// [batch][numBlocks]
uint32_t* __restrict__ compressedWords,
// for the range of ANSDecodedT, the minimum symbol and number of symbols
// that we actually encounter [batch]
const uint2* __restrict__ minAndNumSymbols,
// the encoding table that we will load into smem
// [batch][kNumSymbols]
const uint4* __restrict__ table) {
Expand All @@ -499,19 +457,13 @@ __global__ void ansEncodeBatchPartial(
uint32_t curSize = inProvider.getBatchSize(batch);
uint32_t numBlocks = divUp(curSize, BlockSize);

auto minNumSym = minAndNumSymbols[batch];
uint32_t minSymbol = minNumSym.x;
uint32_t numSymbols = minNumSym.y;

ansEncodeBlocksPartial<ProbBits, BlockSize>(
(const ANSDecodedT*)inProvider.getBatchStart(batch),
inProvider.getBatchSize(batch),
numBlocks,
maxCompressedBlockSize,
out + batch * maxNumCompressedBlocks * maxCompressedBlockSize,
compressedWords + batch * maxNumCompressedBlocks,
minSymbol,
numSymbols,
table + batch * kNumSymbols);
}

Expand Down Expand Up @@ -571,8 +523,6 @@ __device__ void ansEncodeCoalesce(
}

ANSCoalescedHeader header;
// printf("setting num blocks %u probBits %u minSymbol %u numSym %u\n",
// numBlocks, probBits, minSymbol, numSymbols);
header.setNumBlocks(numBlocks);
header.setUncompressedWords(uncompressedWords);
header.setCompressedWords(totalCompressedWords);
Expand All @@ -588,7 +538,7 @@ __device__ void ansEncodeCoalesce(
auto probsOut = headerOut->getSymbolProbs();

// Write out pdf
for (int i = tid; i < numSymbols; i += Threads) {
for (int i = tid; i < kNumSymbols; i += Threads) {
probsOut[i] = table[i].x;
}
}
Expand Down Expand Up @@ -769,7 +719,6 @@ void ansEncodeBatchDevice(
uncoalescedBlockStride, \
compressedBlocks_dev.data(), \
compressedWords_dev.data(), \
minAndNumSymbols_dev.data(), \
table_dev.data()); \
\
ansEncodeBatchPartial<InProvider, BITS, kDefaultBlockSize> \
Expand All @@ -779,7 +728,6 @@ void ansEncodeBatchDevice(
uncoalescedBlockStride, \
compressedBlocks_dev.data(), \
compressedWords_dev.data(), \
minAndNumSymbols_dev.data(), \
table_dev.data());

switch (probBits) {
Expand Down
Loading

0 comments on commit 8f95d1c

Please sign in to comment.