Skip to content

microsoft/compiler: Less bitcasting

Jesse Natalie requested to merge jenatali/mesa:dxil-less-bitcasting into main

I was inspired by seeing !22934 (merged) and decided to try to remove bitcasting from our DXIL. This gets much closer.

Looking at piglit glsl-1.10\execution\glsl-fs-if-nested-loop.shader_test:

Before
define void @main() {
  %1 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
  %2 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
  %3 = bitcast i32 0 to float
  %4 = fcmp olt float %2, %3
  br i1 %4, label %5, label %31

; <label>:5                                       ; preds = %0
  br label %6

; <label>:6                                       ; preds = %21, %5
  %7 = phi i32 [ 0, %5 ], [ %25, %21 ]
  %8 = phi i32 [ -1073741824, %5 ], [ %29, %21 ]
  %9 = bitcast i32 %8 to float
  %10 = fcmp olt float %9, %2
  %11 = xor i1 %10, true
  br i1 %11, label %12, label %13

; <label>:12                                      ; preds = %6
  br label %30

; <label>:13                                      ; preds = %6
  br label %14

; <label>:14                                      ; preds = %13
  %15 = bitcast i32 1084227584 to float
  %16 = fadd fast float %1, %15
  %17 = bitcast i32 %8 to float
  %18 = fcmp une float %17, %16
  br i1 %18, label %19, label %20

; <label>:19                                      ; preds = %14
  br label %30

; <label>:20                                      ; preds = %14
  br label %21

; <label>:21                                      ; preds = %20
  %22 = bitcast i32 %7 to float
  %23 = bitcast i32 1065353216 to float
  %24 = fadd fast float %22, %23
  %25 = bitcast float %24 to i32
  %26 = bitcast i32 %8 to float
  %27 = bitcast i32 1065353216 to float
  %28 = fadd fast float %26, %27
  %29 = bitcast float %28 to i32
  br label %6

; <label>:30                                      ; preds = %19, %12
  br label %32

; <label>:31                                      ; preds = %0
  br label %32

; <label>:32                                      ; preds = %31, %30
  %33 = phi i32 [ %7, %30 ], [ 0, %31 ]
  %34 = bitcast i32 %33 to float
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %34)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  %35 = bitcast i32 1065353216 to float
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %35)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  %36 = bitcast i32 0 to float
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %36)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  %37 = bitcast i32 0 to float
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %37)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  ret void
}
After
define void @main() {
  %1 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
  %2 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
  %3 = fcmp olt float %2, 0.000000e+00
  br i1 %3, label %4, label %21

; <label>:4                                       ; preds = %0
  br label %5

; <label>:5                                       ; preds = %17, %4
  %6 = phi float [ 0.000000e+00, %4 ], [ %18, %17 ]
  %7 = phi float [ -2.000000e+00, %4 ], [ %19, %17 ]
  %8 = fcmp olt float %7, %2
  %9 = xor i1 %8, true
  br i1 %9, label %10, label %11

; <label>:10                                      ; preds = %5
  br label %20

; <label>:11                                      ; preds = %5
  br label %12

; <label>:12                                      ; preds = %11
  %13 = fadd fast float %1, 5.000000e+00
  %14 = fcmp une float %7, %13
  br i1 %14, label %15, label %16

; <label>:15                                      ; preds = %12
  br label %20

; <label>:16                                      ; preds = %12
  br label %17

; <label>:17                                      ; preds = %16
  %18 = fadd fast float %6, 1.000000e+00
  %19 = fadd fast float %7, 1.000000e+00
  br label %5

; <label>:20                                      ; preds = %15, %10
  br label %22

; <label>:21                                      ; preds = %0
  br label %22

; <label>:22                                      ; preds = %21, %20
  %23 = phi float [ %6, %20 ], [ 0.000000e+00, %21 ]
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %23)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 1.000000e+00)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 0.000000e+00)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 0.000000e+00)  ; StoreOutput(outputSigId,rowIndex,colIndex,value)
  ret void
}

The first commit is just cleanup. The second commit preps the NIR by making it uglier so that the last commit can work its magic. Essentially, it takes type information from instructions that know the types of their sources and propagates that backwards. For phis/movs, this then propagates the def type to its sources. That allows instructions that could have multiple types (loads, constants) to use that as a hint for what kind of data they should return, which allows us to avoid returning one type and immediately bitcasting it to a different one.

Edited by Jesse Natalie

Merge request reports