astc: Return zero on out of bound bits

Avoid out of bound reads on invalid ASTC textures.
Games can bind invalid textures that make us read or write out of bounds.
This commit is contained in:
ReinUsesLisp 2021-01-15 02:15:04 -03:00
parent 93f7719eed
commit 0ec71b78fb

View file

@ -42,21 +42,24 @@ constexpr u32 Popcnt(u32 n) {
class InputBitStream { class InputBitStream {
public: public:
constexpr explicit InputBitStream(const u8* ptr, std::size_t start_offset = 0) constexpr explicit InputBitStream(std::span<const u8> data, size_t start_offset = 0)
: cur_byte{ptr}, next_bit{start_offset % 8} {} : cur_byte{data.data()}, total_bits{data.size()}, next_bit{start_offset % 8} {}
constexpr std::size_t GetBitsRead() const { constexpr size_t GetBitsRead() const {
return bits_read; return bits_read;
} }
constexpr bool ReadBit() { constexpr bool ReadBit() {
const bool bit = (*cur_byte >> next_bit++) & 1; if (bits_read >= total_bits * 8) {
return 0;
}
const bool bit = ((*cur_byte >> next_bit) & 1) != 0;
++next_bit;
while (next_bit >= 8) { while (next_bit >= 8) {
next_bit -= 8; next_bit -= 8;
cur_byte++; ++cur_byte;
} }
++bits_read;
bits_read++;
return bit; return bit;
} }
@ -79,8 +82,9 @@ public:
private: private:
const u8* cur_byte; const u8* cur_byte;
std::size_t next_bit = 0; size_t total_bits = 0;
std::size_t bits_read = 0; size_t next_bit = 0;
size_t bits_read = 0;
}; };
class OutputBitStream { class OutputBitStream {
@ -200,8 +204,8 @@ using IntegerEncodedVector = boost::container::static_vector<
static void DecodeTritBlock(InputBitStream& bits, IntegerEncodedVector& result, u32 nBitsPerValue) { static void DecodeTritBlock(InputBitStream& bits, IntegerEncodedVector& result, u32 nBitsPerValue) {
// Implement the algorithm in section C.2.12 // Implement the algorithm in section C.2.12
u32 m[5]; std::array<u32, 5> m;
u32 t[5]; std::array<u32, 5> t;
u32 T; u32 T;
// Read the trit encoded block according to // Read the trit encoded block according to
@ -866,7 +870,7 @@ public:
} }
}; };
static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nPartitions, static void DecodeColorValues(u32* out, std::span<u8> data, const u32* modes, const u32 nPartitions,
const u32 nBitsForColorData) { const u32 nBitsForColorData) {
// First figure out how many color values we have // First figure out how many color values we have
u32 nValues = 0; u32 nValues = 0;
@ -898,7 +902,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// We now have enough to decode our integer sequence. // We now have enough to decode our integer sequence.
IntegerEncodedVector decodedColorValues; IntegerEncodedVector decodedColorValues;
InputBitStream colorStream(data); InputBitStream colorStream(data, 0);
DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues); DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues);
// Once we have the decoded values, we need to dequantize them to the 0-255 range // Once we have the decoded values, we need to dequantize them to the 0-255 range
@ -1441,7 +1445,7 @@ static void ComputeEndpos32s(Pixel& ep1, Pixel& ep2, const u32*& colorValues,
static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth, static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
const u32 blockHeight, std::span<u32, 12 * 12> outBuf) { const u32 blockHeight, std::span<u32, 12 * 12> outBuf) {
InputBitStream strm(inBuf.data()); InputBitStream strm(inBuf);
TexelWeightParams weightParams = DecodeBlockInfo(strm); TexelWeightParams weightParams = DecodeBlockInfo(strm);
// Was there an error? // Was there an error?
@ -1619,15 +1623,16 @@ static void DecompressBlock(std::span<const u8, 16> inBuf, const u32 blockWidth,
// Make sure that higher non-texel bits are set to zero // Make sure that higher non-texel bits are set to zero
const u32 clearByteStart = (weightParams.GetPackedBitSize() >> 3) + 1; const u32 clearByteStart = (weightParams.GetPackedBitSize() >> 3) + 1;
if (clearByteStart > 0) { if (clearByteStart > 0 && clearByteStart <= texelWeightData.size()) {
texelWeightData[clearByteStart - 1] &= texelWeightData[clearByteStart - 1] &=
static_cast<u8>((1 << (weightParams.GetPackedBitSize() % 8)) - 1); static_cast<u8>((1 << (weightParams.GetPackedBitSize() % 8)) - 1);
std::memset(texelWeightData.data() + clearByteStart, 0,
std::min(16U - clearByteStart, 16U));
} }
std::memset(texelWeightData.data() + clearByteStart, 0, std::min(16U - clearByteStart, 16U));
IntegerEncodedVector texelWeightValues; IntegerEncodedVector texelWeightValues;
InputBitStream weightStream(texelWeightData.data()); InputBitStream weightStream(texelWeightData);
DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight, DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight,
weightParams.GetNumWeightValues()); weightParams.GetNumWeightValues());