@@ -170,6 +170,8 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
170170 // it when invoking the offload wrapper job
171171 Format = static_cast <ur::DeviceBinaryType>(Bin->Format );
172172
173+ // For compressed images, we delay determining the format until the image is
174+ // decompressed.
173175 if (Format == SYCL_DEVICE_BINARY_TYPE_NONE)
174176 // try to determine the format; may remain "NONE"
175177 Format = ur::getBinaryImageFormat (Bin->BinaryStart , getSize ());
@@ -186,7 +188,6 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
186188 ProgramMetadataUR.push_back (
187189 ur::mapDeviceBinaryPropertyToProgramMetadata (Prop));
188190 }
189-
190191 ExportedSymbols.init (Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS);
191192 ImportedSymbols.init (Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS);
192193 DeviceGlobals.init (Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS);
@@ -235,25 +236,34 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
235236 sycl_device_binary CompressedBin)
236237 : RTDeviceBinaryImage() {
237238
238- size_t compressedDataSize = static_cast <size_t >(CompressedBin->BinaryEnd -
239- CompressedBin->BinaryStart );
239+ // 'CompressedBin' is part of the executable image loaded into memory
240+ // which can't be modified easily. So, we need to make a copy of it.
241+ Bin = new sycl_device_binary_struct (*CompressedBin);
242+
243+ // Get the decompressed size of the binary image.
244+ m_ImageSize = ZSTDCompressor::GetDecompressedSize (
245+ reinterpret_cast <const char *>(Bin->BinaryStart ),
246+ static_cast <size_t >(Bin->BinaryEnd - Bin->BinaryStart ));
247+
248+ init (Bin);
249+ }
250+
251+ void CompressedRTDeviceBinaryImage::Decompress () {
252+
253+ size_t CompressedDataSize =
254+ static_cast <size_t >(Bin->BinaryEnd - Bin->BinaryStart );
240255
241256 size_t DecompressedSize = 0 ;
242257 m_DecompressedData = ZSTDCompressor::DecompressBlob (
243- reinterpret_cast <const char *>(CompressedBin ->BinaryStart ),
244- compressedDataSize, DecompressedSize);
258+ reinterpret_cast <const char *>(Bin ->BinaryStart ), CompressedDataSize ,
259+ DecompressedSize);
245260
246- Bin = new sycl_device_binary_struct (*CompressedBin);
247261 Bin->BinaryStart =
248262 reinterpret_cast <const unsigned char *>(m_DecompressedData.get ());
249263 Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
250264
251- // Set the new format to none and let RT determine the format.
252- // TODO: Add support for automatically detecting compressed
253- // binary format.
254- Bin->Format = SYCL_DEVICE_BINARY_TYPE_NONE;
255-
256- init (Bin);
265+ Bin->Format = ur::getBinaryImageFormat (Bin->BinaryStart , getSize ());
266+ Format = static_cast <ur::DeviceBinaryType>(Bin->Format );
257267}
258268
259269CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage () {
0 commit comments