Commit 8a159502 authored by Bas Nieuwenhuizen's avatar Bas Nieuwenhuizen
Browse files

amd/common: Implement global memory accesses.



Needed for VK_EXT_buffer_device_address.

The pointers are implmemented as i8*, since I could not figure
out how to emulate setting struct offsets in LLVM based on the
SPIR-V offsets (and more weird stuff like row major matrices).
Acked-by: Samuel Pitoiset's avatarSamuel Pitoiset <samuel.pitoiset@gmail.com>
parent 5703ecf6
......@@ -1878,6 +1878,14 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
}
static unsigned
type_scalar_size_bytes(const struct glsl_type *type)
{
assert(glsl_type_is_vector_or_scalar(type) ||
glsl_type_is_matrix(type));
return glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8;
}
static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
nir_intrinsic_instr *instr)
{
......@@ -1892,7 +1900,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
LLVMValueRef ret;
unsigned const_index;
unsigned stride = 4;
int mode = nir_var_mem_shared;
int mode = deref->mode;
if (var) {
bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
......@@ -1999,6 +2007,32 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
}
}
break;
case nir_var_mem_global: {
LLVMValueRef address = get_src(ctx, instr->src[0]);
unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
unsigned natural_stride = type_scalar_size_bytes(deref->type);
unsigned stride = explicit_stride ? explicit_stride : natural_stride;
LLVMTypeRef result_type = get_def_type(ctx, &instr->dest.ssa);
if (stride != natural_stride) {
LLVMTypeRef ptr_type = LLVMPointerType(LLVMGetElementType(result_type),
LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
for (unsigned i = 0; i < instr->dest.ssa.num_components; ++i) {
LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, i * stride / natural_stride, 0);
values[i] = LLVMBuildLoad(ctx->ac.builder,
ac_build_gep_ptr(&ctx->ac, address, offset), "");
}
return ac_build_gather_values(&ctx->ac, values, instr->dest.ssa.num_components);
} else {
LLVMTypeRef ptr_type = LLVMPointerType(result_type,
LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
LLVMValueRef val = LLVMBuildLoad(ctx->ac.builder, address, "");
return val;
}
}
default:
unreachable("unhandle variable mode");
}
......@@ -2114,33 +2148,52 @@ visit_store_var(struct ac_nir_context *ctx,
}
}
break;
case nir_var_mem_global:
case nir_var_mem_shared: {
int writemask = instr->const_index[0];
LLVMValueRef address = get_src(ctx, instr->src[0]);
LLVMValueRef val = get_src(ctx, instr->src[1]);
if (writemask == (1u << ac_get_llvm_num_components(val)) - 1) {
val = LLVMBuildBitCast(
ctx->ac.builder, val,
LLVMGetElementType(LLVMTypeOf(address)), "");
unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
unsigned natural_stride = type_scalar_size_bytes(deref->type);
unsigned stride = explicit_stride ? explicit_stride : natural_stride;
LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(val),
LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
if (writemask == (1u << ac_get_llvm_num_components(val)) - 1 &&
stride == natural_stride) {
LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(val),
LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
val = LLVMBuildBitCast(ctx->ac.builder, val,
LLVMGetElementType(LLVMTypeOf(address)), "");
LLVMBuildStore(ctx->ac.builder, val, address);
} else {
LLVMTypeRef ptr_type = LLVMPointerType(LLVMGetElementType(LLVMTypeOf(val)),
LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
for (unsigned chan = 0; chan < 4; chan++) {
if (!(writemask & (1 << chan)))
continue;
LLVMValueRef ptr =
LLVMBuildStructGEP(ctx->ac.builder,
address, chan, "");
LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, chan * stride / natural_stride, 0);
LLVMValueRef ptr = ac_build_gep_ptr(&ctx->ac, address, offset);
LLVMValueRef src = ac_llvm_extract_elem(&ctx->ac, val,
chan);
src = LLVMBuildBitCast(
ctx->ac.builder, src,
LLVMGetElementType(LLVMTypeOf(ptr)), "");
src = LLVMBuildBitCast(ctx->ac.builder, src,
LLVMGetElementType(LLVMTypeOf(ptr)), "");
LLVMBuildStore(ctx->ac.builder, src, ptr);
}
}
break;
}
default:
abort();
break;
}
}
......@@ -3899,7 +3952,8 @@ glsl_to_llvm_type(struct ac_llvm_context *ac,
static void visit_deref(struct ac_nir_context *ctx,
nir_deref_instr *instr)
{
if (instr->mode != nir_var_mem_shared)
if (instr->mode != nir_var_mem_shared &&
instr->mode != nir_var_mem_global)
return;
LLVMValueRef result = NULL;
......@@ -3910,22 +3964,79 @@ static void visit_deref(struct ac_nir_context *ctx,
break;
}
case nir_deref_type_struct:
result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
LLVMConstInt(ctx->ac.i32, instr->strct.index, 0));
if (instr->mode == nir_var_mem_global) {
nir_deref_instr *parent = nir_deref_instr_parent(instr);
uint64_t offset = glsl_get_struct_field_offset(parent->type,
instr->strct.index);
result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
LLVMConstInt(ctx->ac.i32, offset, 0));
} else {
result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
LLVMConstInt(ctx->ac.i32, instr->strct.index, 0));
}
break;
case nir_deref_type_array:
result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
get_src(ctx, instr->arr.index));
if (instr->mode == nir_var_mem_global) {
nir_deref_instr *parent = nir_deref_instr_parent(instr);
unsigned stride = glsl_get_explicit_stride(parent->type);
if ((glsl_type_is_matrix(parent->type) &&
glsl_matrix_type_is_row_major(parent->type)) ||
(glsl_type_is_vector(parent->type) && stride == 0))
stride = type_scalar_size_bytes(parent->type);
assert(stride > 0);
LLVMValueRef index = get_src(ctx, instr->arr.index);
if (LLVMTypeOf(index) != ctx->ac.i64)
index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, "");
LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), "");
result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset);
} else {
result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
get_src(ctx, instr->arr.index));
}
break;
case nir_deref_type_ptr_as_array:
result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
get_src(ctx, instr->arr.index));
if (instr->mode == nir_var_mem_global) {
unsigned stride = nir_deref_instr_ptr_as_array_stride(instr);
LLVMValueRef index = get_src(ctx, instr->arr.index);
if (LLVMTypeOf(index) != ctx->ac.i64)
index = LLVMBuildZExt(ctx->ac.builder, index, ctx->ac.i64, "");
LLVMValueRef offset = LLVMBuildMul(ctx->ac.builder, index, LLVMConstInt(ctx->ac.i64, stride, 0), "");
result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent), offset);
} else {
result = ac_build_gep_ptr(&ctx->ac, get_src(ctx, instr->parent),
get_src(ctx, instr->arr.index));
}
break;
case nir_deref_type_cast: {
result = get_src(ctx, instr->parent);
LLVMTypeRef pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type);
LLVMTypeRef type = LLVMPointerType(pointee_type, AC_ADDR_SPACE_LDS);
/* We can't use the structs from LLVM because the shader
* specifies its own offsets. */
LLVMTypeRef pointee_type = ctx->ac.i8;
if (instr->mode == nir_var_mem_shared)
pointee_type = glsl_to_llvm_type(&ctx->ac, instr->type);
unsigned address_space;
switch(instr->mode) {
case nir_var_mem_shared:
address_space = AC_ADDR_SPACE_LDS;
break;
case nir_var_mem_global:
address_space = AC_ADDR_SPACE_GLOBAL;
break;
default:
unreachable("Unhandled address space");
}
LLVMTypeRef type = LLVMPointerType(pointee_type, address_space);
if (LLVMTypeOf(result) != type) {
if (LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment