diff options
| -rw-r--r-- | src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp | 149 | 
1 files changed, 80 insertions, 69 deletions
| diff --git a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp index c8bd7b329..f94c82e21 100644 --- a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp +++ b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp @@ -4,9 +4,9 @@  #include <algorithm>  #include <compare> -#include <map>  #include <optional>  #include <ranges> +#include <queue>  #include <boost/container/flat_set.hpp>  #include <boost/container/small_vector.hpp> @@ -40,15 +40,19 @@ struct Bias {      u32 offset_end;  }; +using boost::container::flat_set; +using boost::container::small_vector;  using StorageBufferSet = -    boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, -                               boost::container::small_vector<StorageBufferAddr, 16>>; -using StorageInstVector = boost::container::small_vector<StorageInst, 24>; -using VisitedBlocks = boost::container::flat_set<IR::Block*, std::less<IR::Block*>, -                                                 boost::container::small_vector<IR::Block*, 4>>; +    flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, small_vector<StorageBufferAddr, 16>>; +using StorageInstVector = small_vector<StorageInst, 24>;  using StorageWritesSet = -    boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, -                               boost::container::small_vector<StorageBufferAddr, 16>>; +    flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, small_vector<StorageBufferAddr, 16>>; + +struct StorageInfo { +    StorageBufferSet set; +    StorageInstVector to_replace; +    StorageWritesSet writes; +};  /// Returns true when the instruction is a global memory instruction  bool IsGlobalMemory(const IR::Inst& inst) { @@ -215,60 +219,72 @@ std::optional<LowAddrInfo> TrackLowAddress(IR::Inst* inst) {      };  } -/// Recursively tries to track the storage buffer address used by a global memory instruction -std::optional<StorageBufferAddr> Track(IR::Block* block, const IR::Value& value, const Bias* bias, -                                       VisitedBlocks& visited) { +/// Tries to get the storage buffer out of a constant buffer read instruction +std::optional<StorageBufferAddr> TryGetStorageBuffer(const IR::Inst* inst, const Bias* bias) { +    if (inst->Opcode() != IR::Opcode::GetCbufU32) { +        return std::nullopt; +    } +    const IR::Value index{inst->Arg(0)}; +    const IR::Value offset{inst->Arg(1)}; +    if (!index.IsImmediate()) { +        // Definitely not a storage buffer if it's read from a non-immediate index +        return std::nullopt; +    } +    if (!offset.IsImmediate()) { +        // TODO: Support SSBO arrays +        return std::nullopt; +    } +    const StorageBufferAddr storage_buffer{ +        .index{index.U32()}, +        .offset{offset.U32()}, +    }; +    if (bias && !MeetsBias(storage_buffer, *bias)) { +        // We have to blacklist some addresses in case we wrongly point to them +        return std::nullopt; +    } +    return storage_buffer; +} + +/// Tries to track the storage buffer address used by a global memory instruction +std::optional<StorageBufferAddr> Track(const IR::Value& value, const Bias* bias) {      if (value.IsImmediate()) { -        // Immediates can't be a storage buffer +        // Nothing to do with immediates          return std::nullopt;      } -    const IR::Inst* const inst{value.InstRecursive()}; -    if (inst->Opcode() == IR::Opcode::GetCbufU32) { -        const IR::Value index{inst->Arg(0)}; -        const IR::Value offset{inst->Arg(1)}; -        if (!index.IsImmediate()) { -            // Definitely not a storage buffer if it's read from a non-immediate index -            return std::nullopt; -        } -        if (!offset.IsImmediate()) { -            // TODO: Support SSBO arrays -            return std::nullopt; -        } -        const StorageBufferAddr storage_buffer{ -            .index{index.U32()}, -            .offset{offset.U32()}, -        }; -        if (bias && !MeetsBias(storage_buffer, *bias)) { -            // We have to blacklist some addresses in case we wrongly point to them -            return std::nullopt; +    // Breadth-first search visiting the right most arguments first +    // Small vector has been determined from shaders in Super Smash Bros. Ultimate +    small_vector<const IR::Inst*, 2> visited; +    std::queue<const IR::Inst*> queue; +    queue.push(value.InstRecursive()); + +    while (!queue.empty()) { +        // Pop one instruction from the queue +        const IR::Inst* const inst{queue.front()}; +        queue.pop(); +        if (const std::optional<StorageBufferAddr> result = TryGetStorageBuffer(inst, bias)) { +            // This is the instruction we were looking for +            return result;          } -        return storage_buffer; -    } -    // Reversed loops are more likely to find the right result -    for (size_t arg = inst->NumArgs(); arg--;) { -        IR::Block* inst_block{block}; -        if (inst->Opcode() == IR::Opcode::Phi) { -            // If we are going through a phi node, mark the current block as visited -            visited.insert(block); -            // and skip already visited blocks to avoid looping forever -            IR::Block* const phi_block{inst->PhiBlock(arg)}; -            if (visited.contains(phi_block)) { -                // Already visited, skip +        // Visit the right most arguments first +        for (size_t arg = inst->NumArgs(); arg--;) { +            const IR::Value arg_value{inst->Arg(arg)}; +            if (arg_value.IsImmediate()) {                  continue;              } -            inst_block = phi_block; -        } -        const std::optional storage_buffer{Track(inst_block, inst->Arg(arg), bias, visited)}; -        if (storage_buffer) { -            return *storage_buffer; +            // Queue instruction if it hasn't been visited +            const IR::Inst* const arg_inst{arg_value.InstRecursive()}; +            if (std::ranges::find(visited, arg_inst) == visited.end()) { +                visited.push_back(arg_inst); +                queue.push(arg_inst); +            }          }      } +    // SSA tree has been traversed and the origin hasn't been found      return std::nullopt;  }  /// Collects the storage buffer used by a global memory instruction and the instruction itself -void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& storage_buffer_set, -                           StorageInstVector& to_replace, StorageWritesSet& writes_set) { +void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageInfo& info) {      // NVN puts storage buffers in a specific range, we have to bias towards these addresses to      // avoid getting false positives      static constexpr Bias nvn_bias{ @@ -284,24 +300,23 @@ void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& s      }      // First try to find storage buffers in the NVN address      const IR::U32 low_addr{low_addr_info->value}; -    VisitedBlocks visited_blocks; -    std::optional storage_buffer{Track(&block, low_addr, &nvn_bias, visited_blocks)}; +    std::optional storage_buffer{Track(low_addr, &nvn_bias)};      if (!storage_buffer) {          // If it fails, track without a bias -        visited_blocks.clear(); -        storage_buffer = Track(&block, low_addr, nullptr, visited_blocks); +        storage_buffer = Track(low_addr, nullptr);          if (!storage_buffer) {              // If that also failed, drop the global memory usage +            // LOG_ERROR              DiscardGlobalMemory(block, inst);              return;          }      }      // Collect storage buffer and the instruction      if (IsGlobalMemoryWrite(inst)) { -        writes_set.insert(*storage_buffer); +        info.writes.insert(*storage_buffer);      } -    storage_buffer_set.insert(*storage_buffer); -    to_replace.push_back(StorageInst{ +    info.set.insert(*storage_buffer); +    info.to_replace.push_back(StorageInst{          .storage_buffer{*storage_buffer},          .inst{&inst},          .block{&block}, @@ -371,33 +386,29 @@ void Replace(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index,  } // Anonymous namespace  void GlobalMemoryToStorageBufferPass(IR::Program& program) { -    StorageBufferSet storage_buffers; -    StorageInstVector to_replace; -    StorageWritesSet writes_set; - +    StorageInfo info;      for (IR::Block* const block : program.post_order_blocks) {          for (IR::Inst& inst : block->Instructions()) {              if (!IsGlobalMemory(inst)) {                  continue;              } -            CollectStorageBuffers(*block, inst, storage_buffers, to_replace, writes_set); +            CollectStorageBuffers(*block, inst, info);          }      } -    Info& info{program.info};      u32 storage_index{}; -    for (const StorageBufferAddr& storage_buffer : storage_buffers) { -        info.storage_buffers_descriptors.push_back({ +    for (const StorageBufferAddr& storage_buffer : info.set) { +        program.info.storage_buffers_descriptors.push_back({              .cbuf_index{storage_buffer.index},              .cbuf_offset{storage_buffer.offset},              .count{1}, -            .is_written{writes_set.contains(storage_buffer)}, +            .is_written{info.writes.contains(storage_buffer)},          });          ++storage_index;      } -    for (const StorageInst& storage_inst : to_replace) { +    for (const StorageInst& storage_inst : info.to_replace) {          const StorageBufferAddr storage_buffer{storage_inst.storage_buffer}; -        const auto it{storage_buffers.find(storage_inst.storage_buffer)}; -        const IR::U32 index{IR::Value{static_cast<u32>(storage_buffers.index_of(it))}}; +        const auto it{info.set.find(storage_inst.storage_buffer)}; +        const IR::U32 index{IR::Value{static_cast<u32>(info.set.index_of(it))}};          IR::Block* const block{storage_inst.block};          IR::Inst* const inst{storage_inst.inst};          const IR::U32 offset{StorageOffset(*block, *inst, storage_buffer)}; | 
