Skip to content

Commit c1a2c13

Browse files
committed
Delay image decompression till it is actually used.
1 parent 970ad35 commit c1a2c13

File tree

6 files changed

+79
-29
lines changed

6 files changed

+79
-29
lines changed

clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,10 +1112,11 @@ class BinaryWrapper {
11121112
// If '--offload-compress' option is specified and zstd is not
11131113
// available, throw an error.
11141114
if (OffloadCompressDevImgs && !llvm::compression::zstd::isAvailable()) {
1115-
createStringError(inconvertibleErrorCode(),
1116-
"'--offload-compress' option is specified but zstd "
1117-
"is not available. The device image will not be "
1118-
"compressed.");
1115+
return createStringError(
1116+
inconvertibleErrorCode(),
1117+
"'--offload-compress' option is specified but zstd "
1118+
"is not available. The device image will not be "
1119+
"compressed.");
11191120
}
11201121

11211122
// Don't compress if the user explicitly specifies the binary image

sycl/source/detail/compression.hpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ class ZSTDCompressor {
8181
return dstBuffer;
8282
}
8383

84+
static size_t GetDecompressedSize(const char *src, size_t srcSize) {
85+
size_t dstBufferSize = ZSTD_getFrameContentSize(src, srcSize);
86+
87+
if (dstBufferSize == ZSTD_CONTENTSIZE_UNKNOWN ||
88+
dstBufferSize == ZSTD_CONTENTSIZE_ERROR) {
89+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
90+
"Error determining size of uncompressed data.");
91+
}
92+
return dstBufferSize;
93+
}
94+
8495
static std::unique_ptr<char> DecompressBlob(const char *src, size_t srcSize,
8596
size_t &dstSize) {
8697
auto &instance = GetSingletonInstance();
@@ -101,13 +112,7 @@ class ZSTDCompressor {
101112

102113
// Size of decompressed image can be larger than what we can allocate
103114
// on heap. In that case, we need to use streaming decompression.
104-
auto dstBufferSize = ZSTD_getFrameContentSize(src, srcSize);
105-
106-
if (dstBufferSize == ZSTD_CONTENTSIZE_UNKNOWN ||
107-
dstBufferSize == ZSTD_CONTENTSIZE_ERROR) {
108-
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
109-
"Error determining size of uncompressed data.");
110-
}
115+
auto dstBufferSize = GetDecompressedSize(src, srcSize);
111116

112117
// Allocate buffer for decompressed data.
113118
auto dstBuffer = std::unique_ptr<char>(new char[dstBufferSize]);

sycl/source/detail/device_binary_image.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
170170
// it when invoking the offload wrapper job
171171
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
172172

173+
// For compressed images, we delay determining the format until the image is
174+
// decompressed.
173175
if (Format == SYCL_DEVICE_BINARY_TYPE_NONE)
174176
// try to determine the format; may remain "NONE"
175177
Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
@@ -186,7 +188,6 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
186188
ProgramMetadataUR.push_back(
187189
ur::mapDeviceBinaryPropertyToProgramMetadata(Prop));
188190
}
189-
190191
ExportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS);
191192
ImportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS);
192193
DeviceGlobals.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS);
@@ -235,25 +236,34 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
235236
sycl_device_binary CompressedBin)
236237
: RTDeviceBinaryImage() {
237238

238-
size_t compressedDataSize = static_cast<size_t>(CompressedBin->BinaryEnd -
239-
CompressedBin->BinaryStart);
239+
// 'CompressedBin' is part of the executable image loaded into memory
240+
// which can't be modified easily. So, we need to make a copy of it.
241+
Bin = new sycl_device_binary_struct(*CompressedBin);
242+
243+
// Get the decompressed size of the binary image.
244+
m_ImageSize = ZSTDCompressor::GetDecompressedSize(
245+
reinterpret_cast<const char *>(Bin->BinaryStart),
246+
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart));
247+
248+
init(Bin);
249+
}
250+
251+
void CompressedRTDeviceBinaryImage::Decompress() {
252+
253+
size_t CompressedDataSize =
254+
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
240255

241256
size_t DecompressedSize = 0;
242257
m_DecompressedData = ZSTDCompressor::DecompressBlob(
243-
reinterpret_cast<const char *>(CompressedBin->BinaryStart),
244-
compressedDataSize, DecompressedSize);
258+
reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
259+
DecompressedSize);
245260

246-
Bin = new sycl_device_binary_struct(*CompressedBin);
247261
Bin->BinaryStart =
248262
reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
249263
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
250264

251-
// Set the new format to none and let RT determine the format.
252-
// TODO: Add support for automatically detecting compressed
253-
// binary format.
254-
Bin->Format = SYCL_DEVICE_BINARY_TYPE_NONE;
255-
256-
init(Bin);
265+
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
266+
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
257267
}
258268

259269
CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {

sycl/source/detail/device_binary_image.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ class RTDeviceBinaryImage {
158158
virtual void print() const;
159159
virtual void dump(std::ostream &Out) const;
160160

161-
size_t getSize() const {
161+
// getSize will be overridden in the case of compressed binary images.
162+
// In that case, we return the size of uncompressed data, instead of
163+
// BinaryEnd - BinaryStart.
164+
virtual size_t getSize() const {
162165
assert(Bin && "binary image data not set");
163166
return static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
164167
}
@@ -277,21 +280,31 @@ class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
277280
};
278281

279282
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
280-
// Compressed device binary image. It decompresses the binary image on
281-
// construction and stores the decompressed data as RTDeviceBinaryImage.
283+
// Compressed device binary image. Decompression happens when the image is
284+
// actually used to build a program.
282285
// Also, frees the decompressed data in destructor.
283286
class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
284287
public:
285288
CompressedRTDeviceBinaryImage(sycl_device_binary Bin);
286289
~CompressedRTDeviceBinaryImage() override;
287290

291+
void Decompress();
292+
293+
// We return the size of decompressed data, not the size of compressed data.
294+
size_t getSize() const override {
295+
assert(Bin && "binary image data not set");
296+
return m_ImageSize;
297+
}
298+
299+
bool IsCompressed() const { return m_DecompressedData.get() == nullptr; }
288300
void print() const override {
289301
RTDeviceBinaryImage::print();
290302
std::cerr << " COMPRESSED\n";
291303
}
292304

293305
private:
294306
std::unique_ptr<char> m_DecompressedData;
307+
size_t m_ImageSize;
295308
};
296309
#endif // SYCL_RT_ZSTD_NOT_AVAIABLE
297310

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,12 @@ setSpecializationConstants(const std::shared_ptr<device_image_impl> &InputImpl,
733733
}
734734
}
735735

736+
static inline void CheckAndDecompressImage(RTDeviceBinaryImage *Img) {
737+
if (auto CompImg = dynamic_cast<CompressedRTDeviceBinaryImage *>(Img))
738+
if (CompImg->IsCompressed())
739+
CompImg->Decompress();
740+
}
741+
736742
// When caching is enabled, the returned UrProgram will already have
737743
// its ref count incremented.
738744
ur_program_handle_t ProgramManager::getBuiltURProgram(
@@ -785,6 +791,10 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
785791
collectDeviceImageDepsForImportedSymbols(Img, Device);
786792
DeviceImagesToLink.insert(ImageDeps.begin(), ImageDeps.end());
787793

794+
// Decompress all DeviceImagesToLink
795+
for (RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
796+
CheckAndDecompressImage(BinImg);
797+
788798
std::vector<const RTDeviceBinaryImage *> AllImages;
789799
AllImages.reserve(ImageDeps.size() + 1);
790800
AllImages.push_back(&Img);
@@ -1388,6 +1398,10 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
13881398
Device);
13891399
}
13901400
}
1401+
1402+
// Decompress the image if it is compressed.
1403+
CheckAndDecompressImage(Img);
1404+
13911405
if (Img) {
13921406
CheckJITCompilationForImage(Img, JITCompilationIsRequired);
13931407

@@ -1714,6 +1728,10 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
17141728
[&](auto &CurrentImg) {
17151729
return CurrentImg.first->getFormat() == Img->getFormat();
17161730
});
1731+
1732+
// Check if image is compressed, and decompress it before dumping.
1733+
CheckAndDecompressImage(Img.get());
1734+
17171735
dumpImage(*Img, NeedsSequenceID ? ++SequenceID : 0);
17181736
}
17191737

@@ -2191,6 +2209,9 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
21912209

21922210
auto &[KernelImagesState, KernelImages] = *StateImagesPair;
21932211

2212+
// Check if device image is compressed and decompress it if needed
2213+
CheckAndDecompressImage(BinImage);
2214+
21942215
if (KernelImages.empty()) {
21952216
KernelImagesState = ImgState;
21962217
KernelImages.push_back(BinImage);

sycl/test-e2e/Compression/compression_multiple_tu.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
// REQUIRES: zstd
44

55
// DEFINE: %{fPIC_flag} = %if windows %{%} %else %{-fPIC%}
6-
// RUN: %{build} --offload-compress -DENABLE_KERNEL1 -shared %{fPIC_flag} -o %t_kernel1.so
7-
// RUN: %{build} -DENABLE_KERNEL2 -shared %{fPIC_flag} -o %t_kernel2.so
6+
// RUN: %{build} --offload-compress -DENABLE_KERNEL1 -shared %{fPIC_flag} -o %T/kernel1.so
7+
// RUN: %{build} -DENABLE_KERNEL2 -shared %{fPIC_flag} -o %T/kernel2.so
88

9-
// RUN: %{build} %t_kernel1.so %t_kernel2.so -Wl,-rpath=%T -o %t_compress.out
9+
// RUN: %{build} %t_kernel1.so %t_kernel2.so -o %t_compress.out
1010
// RUN: %{run} %t_compress.out
1111
#if defined(ENABLE_KERNEL1) || defined(ENABLE_KERNEL2)
1212
#include <sycl/builtins.hpp>

0 commit comments

Comments
 (0)