diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c index 248c63654eb40d64250597b69f12f6ba764645ec..93bde0afb98753a4804028ac55e21eea821e835c 100644 --- a/src/microsoft/clc/clc_compiler.c +++ b/src/microsoft/clc/clc_compiler.c @@ -54,7 +54,7 @@ static const struct debug_named_value clc_debug_options[] = { DEBUG_GET_ONCE_FLAGS_OPTION(debug_clc, "CLC_DEBUG", clc_debug_options, 0) static void -clc_print_kernels_info(const struct clc_object *obj) +clc_print_kernels_info(const struct clc_parsed_spirv *obj) { fprintf(stdout, "Kernels:\n"); for (unsigned i = 0; i < obj->num_kernels; i++) { @@ -451,7 +451,7 @@ clc_lower_nonnormalized_samplers(nir_shader *nir, static void -clc_context_optimize(nir_shader *s) +clc_libclc_optimize(nir_shader *s) { bool progress; do { @@ -475,12 +475,16 @@ clc_context_optimize(nir_shader *s) } while (progress); } -struct clc_context * -clc_context_new(const struct clc_logger *logger, const struct clc_context_options *options) +struct clc_libclc { + const void *libclc_nir; +}; + +struct clc_libclc * +clc_libclc_new(const struct clc_logger *logger, const struct clc_libclc_options *options) { - struct clc_context *ctx = rzalloc(NULL, struct clc_context); + struct clc_libclc *ctx = rzalloc(NULL, struct clc_libclc); if (!ctx) { - clc_error(logger, "D3D12: failed to allocate a clc_context"); + clc_error(logger, "D3D12: failed to allocate a clc_libclc"); return NULL; } @@ -513,7 +517,7 @@ clc_context_new(const struct clc_logger *logger, const struct clc_context_option } if (options && options->optimize) - clc_context_optimize(s); + clc_libclc_optimize(s); ralloc_steal(ctx, s); ctx->libclc_nir = s; @@ -522,13 +526,13 @@ clc_context_new(const struct clc_logger *logger, const struct clc_context_option } void -clc_free_context(struct clc_context *ctx) +clc_free_libclc(struct clc_libclc *ctx) { ralloc_free(ctx); glsl_type_singleton_decref(); }; -void clc_context_serialize(struct clc_context *context, +void clc_libclc_serialize(struct clc_libclc *context, void **serialized, size_t *serialized_size) { @@ -539,15 +543,15 @@ void clc_context_serialize(struct clc_context *context, blob_finish_get_buffer(&tmp, serialized, serialized_size); } -void clc_context_free_serialized(void *serialized) +void clc_libclc_free_serialized(void *serialized) { free(serialized); } -struct clc_context * - clc_context_deserialize(const void *serialized, size_t serialized_size) +struct clc_libclc * + clc_libclc_deserialize(const void *serialized, size_t serialized_size) { - struct clc_context *ctx = rzalloc(NULL, struct clc_context); + struct clc_libclc *ctx = rzalloc(NULL, struct clc_libclc); if (!ctx) { return NULL; } @@ -571,69 +575,105 @@ struct clc_context * return ctx; } -struct clc_object * -clc_compile(struct clc_context *ctx, - const struct clc_compile_args *args, - const struct clc_logger *logger) +bool +clc_compile_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir) { - struct clc_object *obj; - int ret; + return clc_c_to_spir(args, logger, out_spir) >= 0; +} - obj = calloc(1, sizeof(*obj)); - if (!obj) { - clc_error(logger, "D3D12: failed to allocate a clc_object"); - return NULL; - } +void +clc_free_spir(struct clc_binary *spir) +{ + clc_free_spir_binary(spir); +} - ret = clc_to_spirv(args, &obj->spvbin, logger); - if (ret < 0) { - free(obj); - return NULL; - } +bool +clc_compile_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + if (clc_spir_to_spirv(in_spir, logger, out_spirv) < 0) + return false; if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV) - clc_dump_spirv(&obj->spvbin, stdout); + clc_dump_spirv(out_spirv, stdout); - return obj; + return true; } -struct clc_object * -clc_link(struct clc_context *ctx, - const struct clc_linker_args *args, - const struct clc_logger *logger) +void +clc_free_spirv(struct clc_binary *spirv) { - struct clc_object *out_obj; - int ret; + clc_free_spirv_binary(spirv); +} - out_obj = malloc(sizeof(*out_obj)); - if (!out_obj) { - clc_error(logger, "failed to allocate a clc_object"); - return NULL; - } +bool +clc_compile_c_to_spirv(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + if (clc_c_to_spirv(args, logger, out_spirv) < 0) + return false; - ret = clc_link_spirv_binaries(args, &out_obj->spvbin, logger); - if (ret < 0) { - free(out_obj); - return NULL; - } + if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV) + clc_dump_spirv(out_spirv, stdout); + + return true; +} + +bool +clc_link_spirv(const struct clc_linker_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + if (clc_link_spirv_binaries(args, logger, out_spirv) < 0) + return false; if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV) - clc_dump_spirv(&out_obj->spvbin, stdout); + clc_dump_spirv(out_spirv, stdout); + + return true; +} - out_obj->kernels = clc_spirv_get_kernels_info(&out_obj->spvbin, - &out_obj->num_kernels); +bool +clc_parse_spirv(const struct clc_binary *in_spirv, + const struct clc_logger *logger, + struct clc_parsed_spirv *out_data) +{ + if (!clc_spirv_get_kernels_info(in_spirv, + &out_data->kernels, + &out_data->num_kernels, + &out_data->spec_constants, + &out_data->num_spec_constants, + logger)) + return false; if (debug_get_option_debug_clc() & CLC_DEBUG_VERBOSE) - clc_print_kernels_info(out_obj); + clc_print_kernels_info(out_data); - return out_obj; + return true; } -void clc_free_object(struct clc_object *obj) +void clc_free_parsed_spirv(struct clc_parsed_spirv *data) { - clc_free_kernels_info(obj->kernels, obj->num_kernels); - clc_free_spirv_binary(&obj->spvbin); - free(obj); + clc_free_kernels_info(data->kernels, data->num_kernels); +} + +bool +clc_specialize_spirv(const struct clc_binary *in_spirv, + const struct clc_parsed_spirv *parsed_data, + const struct clc_spirv_specialization_consts *consts, + struct clc_binary *out_spirv) +{ + if (!clc_spirv_specialize(in_spirv, parsed_data, consts, out_spirv)) + return false; + + if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV) + clc_dump_spirv(out_spirv, stdout); + + return true; } static nir_variable * @@ -1009,37 +1049,33 @@ scale_fdiv(nir_shader *nir) return progress; } -struct clc_dxil_object * -clc_to_dxil(struct clc_context *ctx, - const struct clc_object *obj, - const char *entrypoint, - const struct clc_runtime_kernel_conf *conf, - const struct clc_logger *logger) +bool +clc_spirv_to_dxil(struct clc_libclc *lib, + const struct clc_binary *linked_spirv, + const struct clc_parsed_spirv *parsed_data, + const char *entrypoint, + const struct clc_runtime_kernel_conf *conf, + const struct clc_spirv_specialization_consts *consts, + const struct clc_logger *logger, + struct clc_dxil_object *out_dxil) { - struct clc_dxil_object *dxil; struct nir_shader *nir; - dxil = calloc(1, sizeof(*dxil)); - if (!dxil) { - clc_error(logger, "failed to allocate the dxil object"); - return NULL; - } - - for (unsigned i = 0; i < obj->num_kernels; i++) { - if (!strcmp(obj->kernels[i].name, entrypoint)) { - dxil->kernel = &obj->kernels[i]; + for (unsigned i = 0; i < parsed_data->num_kernels; i++) { + if (!strcmp(parsed_data->kernels[i].name, entrypoint)) { + out_dxil->kernel = &parsed_data->kernels[i]; break; } } - if (!dxil->kernel) { + if (!out_dxil->kernel) { clc_error(logger, "no '%s' kernel found", entrypoint); - goto err_free_dxil; + return false; } const struct spirv_to_nir_options spirv_options = { .environment = NIR_SPIRV_OPENCL, - .clc_shader = ctx->libclc_nir, + .clc_shader = lib->libclc_nir, .constant_addr_format = nir_address_format_32bit_index_offset_pack64, .global_addr_format = nir_address_format_32bit_index_offset_pack64, .shared_addr_format = nir_address_format_32bit_offset_as_64bit, @@ -1072,8 +1108,9 @@ clc_to_dxil(struct clc_context *ctx, glsl_type_singleton_init_or_ref(); - nir = spirv_to_nir(obj->spvbin.data, obj->spvbin.size / 4, - NULL, 0, + nir = spirv_to_nir(linked_spirv->data, linked_spirv->size / 4, + consts ? (struct nir_spirv_specialization *)consts->specializations : NULL, + consts ? consts->num_specializations : 0, MESA_SHADER_KERNEL, entrypoint, &spirv_options, &nir_options); @@ -1086,9 +1123,9 @@ clc_to_dxil(struct clc_context *ctx, NIR_PASS_V(nir, nir_lower_goto_ifs); NIR_PASS_V(nir, nir_opt_dead_cf); - struct clc_dxil_metadata *metadata = &dxil->metadata; + struct clc_dxil_metadata *metadata = &out_dxil->metadata; - metadata->args = calloc(dxil->kernel->num_args, + metadata->args = calloc(out_dxil->kernel->num_args, sizeof(*metadata->args)); if (!metadata->args) { clc_error(logger, "failed to allocate arg positions"); @@ -1116,7 +1153,7 @@ clc_to_dxil(struct clc_context *ctx, // according to the comment on nir_inline_functions NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp); NIR_PASS_V(nir, nir_lower_returns); - NIR_PASS_V(nir, nir_lower_libclc, ctx->libclc_nir); + NIR_PASS_V(nir, nir_lower_libclc, lib->libclc_nir); NIR_PASS_V(nir, nir_inline_functions); // Pick off the single entrypoint that we want. @@ -1218,8 +1255,8 @@ clc_to_dxil(struct clc_context *ctx, metadata->args[i].size = size; metadata->kernel_inputs_buf_size = MAX2(metadata->kernel_inputs_buf_size, var->data.driver_location + size); - if ((dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL || - dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) && + if ((out_dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_GLOBAL || + out_dxil->kernel->args[i].address_qualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT) && // Ignore images during this pass - global memory buffers need to have contiguous bindings !glsl_type_is_image(var->type)) { metadata->args[i].globconstptr.buf_id = uav_id++; @@ -1301,7 +1338,7 @@ clc_to_dxil(struct clc_context *ctx, glsl_get_cl_type_size_align); NIR_PASS_V(nir, dxil_nir_lower_ubo_to_temp); - NIR_PASS_V(nir, clc_lower_constant_to_ssbo, dxil->kernel, &uav_id); + NIR_PASS_V(nir, clc_lower_constant_to_ssbo, out_dxil->kernel, &uav_id); NIR_PASS_V(nir, clc_lower_global_to_ssbo); bool has_printf = false; @@ -1335,9 +1372,9 @@ clc_to_dxil(struct clc_context *ctx, unsigned cbv_id = 0; nir_variable *inputs_var = - add_kernel_inputs_var(dxil, nir, &cbv_id); + add_kernel_inputs_var(out_dxil, nir, &cbv_id); nir_variable *work_properties_var = - add_work_properties_var(dxil, nir, &cbv_id); + add_work_properties_var(out_dxil, nir, &cbv_id); memcpy(metadata->local_size, nir->info.workgroup_size, sizeof(metadata->local_size)); @@ -1396,12 +1433,12 @@ clc_to_dxil(struct clc_context *ctx, .num_kernel_globals = num_global_inputs, }; - for (unsigned i = 0; i < dxil->kernel->num_args; i++) { - if (dxil->kernel->args[i].address_qualifier != CLC_KERNEL_ARG_ADDRESS_LOCAL) + for (unsigned i = 0; i < out_dxil->kernel->num_args; i++) { + if (out_dxil->kernel->args[i].address_qualifier != CLC_KERNEL_ARG_ADDRESS_LOCAL) continue; /* If we don't have the runtime conf yet, we just create a dummy variable. - * This will be adjusted when clc_to_dxil() is called with a conf + * This will be adjusted when clc_spirv_to_dxil() is called with a conf * argument. */ unsigned size = 4; @@ -1467,13 +1504,13 @@ clc_to_dxil(struct clc_context *ctx, ralloc_free(nir); glsl_type_singleton_decref(); - blob_finish_get_buffer(&tmp, &dxil->binary.data, - &dxil->binary.size); - return dxil; + blob_finish_get_buffer(&tmp, &out_dxil->binary.data, + &out_dxil->binary.size); + return true; err_free_dxil: - clc_free_dxil_object(dxil); - return NULL; + clc_free_dxil_object(out_dxil); + return false; } void clc_free_dxil_object(struct clc_dxil_object *dxil) @@ -1488,7 +1525,6 @@ void clc_free_dxil_object(struct clc_dxil_object *dxil) free(dxil->metadata.printf.infos); free(dxil->binary.data); - free(dxil); } uint64_t clc_compiler_get_version() diff --git a/src/microsoft/clc/clc_compiler.h b/src/microsoft/clc/clc_compiler.h index fc772605b49ed40a56bf7bec6ca03e61b08cd40f..9029e6e640dcd385b4be21d1c1717032fdcc15cd 100644 --- a/src/microsoft/clc/clc_compiler.h +++ b/src/microsoft/clc/clc_compiler.h @@ -45,7 +45,7 @@ struct clc_compile_args { }; struct clc_linker_args { - const struct clc_object * const *in_objs; + const struct clc_binary * const *in_objs; unsigned num_in_objs; unsigned create_library; }; @@ -58,8 +58,8 @@ struct clc_logger { clc_msg_callback warning; }; -struct spirv_binary { - uint32_t *data; +struct clc_binary { + void *data; size_t size; }; @@ -108,10 +108,32 @@ struct clc_kernel_info { enum clc_vec_hint_type vec_hint_type; }; -struct clc_object { - struct spirv_binary spvbin; +enum clc_spec_constant_type { + CLC_SPEC_CONSTANT_UNKNOWN, + CLC_SPEC_CONSTANT_BOOL, + CLC_SPEC_CONSTANT_FLOAT, + CLC_SPEC_CONSTANT_DOUBLE, + CLC_SPEC_CONSTANT_INT8, + CLC_SPEC_CONSTANT_UINT8, + CLC_SPEC_CONSTANT_INT16, + CLC_SPEC_CONSTANT_UINT16, + CLC_SPEC_CONSTANT_INT32, + CLC_SPEC_CONSTANT_UINT32, + CLC_SPEC_CONSTANT_INT64, + CLC_SPEC_CONSTANT_UINT64, +}; + +struct clc_parsed_spec_constant { + uint32_t id; + enum clc_spec_constant_type type; +}; + +struct clc_parsed_spirv { const struct clc_kernel_info *kernels; unsigned num_kernels; + + const struct clc_parsed_spec_constant *spec_constants; + unsigned num_spec_constants; }; #define CLC_MAX_CONSTS 32 @@ -187,33 +209,53 @@ struct clc_dxil_object { } binary; }; -struct clc_context { - const void *libclc_nir; -}; +struct clc_libclc; -struct clc_context_options { +struct clc_libclc_options { unsigned optimize; }; -struct clc_context *clc_context_new(const struct clc_logger *logger, const struct clc_context_options *options); +struct clc_libclc *clc_libclc_new(const struct clc_logger *logger, const struct clc_libclc_options *options); + +void clc_free_libclc(struct clc_libclc *lib); + +void clc_libclc_serialize(struct clc_libclc *lib, void **serialized, size_t *size); +void clc_libclc_free_serialized(void *serialized); +struct clc_libclc *clc_libclc_deserialize(void *serialized, size_t size); -void clc_free_context(struct clc_context *ctx); +bool +clc_compile_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir); -void clc_context_serialize(struct clc_context *ctx, void **serialized, size_t *size); -void clc_context_free_serialized(void *serialized); -struct clc_context *clc_context_deserialize(void *serialized, size_t size); +void +clc_free_spir(struct clc_binary *spir); -struct clc_object * -clc_compile(struct clc_context *ctx, - const struct clc_compile_args *args, - const struct clc_logger *logger); +bool +clc_compile_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv); -struct clc_object * -clc_link(struct clc_context *ctx, - const struct clc_linker_args *args, - const struct clc_logger *logger); +void +clc_free_spirv(struct clc_binary *spirv); -void clc_free_object(struct clc_object *obj); +bool +clc_compile_c_to_spirv(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv); + +bool +clc_link_spirv(const struct clc_linker_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv); + +bool +clc_parse_spirv(const struct clc_binary *in_spirv, + const struct clc_logger *logger, + struct clc_parsed_spirv *out_data); + +void +clc_free_parsed_spirv(struct clc_parsed_spirv *data); struct clc_runtime_arg_info { union { @@ -236,12 +278,46 @@ struct clc_runtime_kernel_conf { unsigned support_workgroup_id_offsets; }; -struct clc_dxil_object * -clc_to_dxil(struct clc_context *ctx, - const struct clc_object *obj, - const char *entrypoint, - const struct clc_runtime_kernel_conf *conf, - const struct clc_logger *logger); +typedef union { + bool b; + float f32; + double f64; + int8_t i8; + uint8_t u8; + int16_t i16; + uint16_t u16; + int32_t i32; + uint32_t u32; + int64_t i64; + uint64_t u64; +} clc_spirv_const_value; + +struct clc_spirv_specialization { + uint32_t id; + clc_spirv_const_value value; + bool defined_on_module; +}; + +struct clc_spirv_specialization_consts { + const struct clc_spirv_specialization *specializations; + unsigned num_specializations; +}; + +bool +clc_specialize_spirv(const struct clc_binary *in_spirv, + const struct clc_parsed_spirv *parsed_data, + const struct clc_spirv_specialization_consts *consts, + struct clc_binary *out_spirv); + +bool +clc_spirv_to_dxil(struct clc_libclc *lib, + const struct clc_binary *linked_spirv, + const struct clc_parsed_spirv *parsed_data, + const char *entrypoint, + const struct clc_runtime_kernel_conf *conf, + const struct clc_spirv_specialization_consts *consts, + const struct clc_logger *logger, + struct clc_dxil_object *out_dxil); void clc_free_dxil_object(struct clc_dxil_object *dxil); diff --git a/src/microsoft/clc/clc_compiler_test.cpp b/src/microsoft/clc/clc_compiler_test.cpp index 4d3182f17c31d3db6ca4790faf30be5de6c83979..f5155c96c49333f1ef9184eab8cc39ddb14e836c 100644 --- a/src/microsoft/clc/clc_compiler_test.cpp +++ b/src/microsoft/clc/clc_compiler_test.cpp @@ -1688,8 +1688,8 @@ TEST_F(ComputeTest, vec_hint_float4) inout[get_global_id(0)] *= inout[get_global_id(1)];\n\ }"; Shader shader = compile({ kernel_source }); - EXPECT_EQ(shader.obj->kernels[0].vec_hint_size, 4); - EXPECT_EQ(shader.obj->kernels[0].vec_hint_type, CLC_VEC_HINT_TYPE_FLOAT); + EXPECT_EQ(shader.metadata->kernels[0].vec_hint_size, 4); + EXPECT_EQ(shader.metadata->kernels[0].vec_hint_type, CLC_VEC_HINT_TYPE_FLOAT); } TEST_F(ComputeTest, vec_hint_uchar2) @@ -1700,8 +1700,8 @@ TEST_F(ComputeTest, vec_hint_uchar2) inout[get_global_id(0)] *= inout[get_global_id(1)];\n\ }"; Shader shader = compile({ kernel_source }); - EXPECT_EQ(shader.obj->kernels[0].vec_hint_size, 2); - EXPECT_EQ(shader.obj->kernels[0].vec_hint_type, CLC_VEC_HINT_TYPE_CHAR); + EXPECT_EQ(shader.metadata->kernels[0].vec_hint_size, 2); + EXPECT_EQ(shader.metadata->kernels[0].vec_hint_type, CLC_VEC_HINT_TYPE_CHAR); } TEST_F(ComputeTest, vec_hint_none) @@ -1712,7 +1712,7 @@ TEST_F(ComputeTest, vec_hint_none) inout[get_global_id(0)] *= inout[get_global_id(1)];\n\ }"; Shader shader = compile({ kernel_source }); - EXPECT_EQ(shader.obj->kernels[0].vec_hint_size, 0); + EXPECT_EQ(shader.metadata->kernels[0].vec_hint_size, 0); } TEST_F(ComputeTest, DISABLED_debug_layer_failure) @@ -2232,3 +2232,82 @@ TEST_F(ComputeTest, unused_arg) for (int i = 0; i < 4; ++i) EXPECT_EQ(dest[i], i + 1); } + +TEST_F(ComputeTest, spec_constant) +{ + const char *spirv_asm = R"( + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %2 "main_test" %__spirv_BuiltInGlobalInvocationId + %4 = OpString "kernel_arg_type.main_test.uint*," + OpSource OpenCL_C 102000 + OpName %__spirv_BuiltInGlobalInvocationId "__spirv_BuiltInGlobalInvocationId" + OpName %output "output" + OpName %entry "entry" + OpName %output_addr "output.addr" + OpName %id "id" + OpName %call "call" + OpName %conv "conv" + OpName %idxprom "idxprom" + OpName %arrayidx "arrayidx" + OpName %add "add" + OpName %mul "mul" + OpName %idxprom1 "idxprom1" + OpName %arrayidx2 "arrayidx2" + OpDecorate %__spirv_BuiltInGlobalInvocationId BuiltIn GlobalInvocationId + OpDecorate %__spirv_BuiltInGlobalInvocationId Constant + OpDecorate %id Alignment 4 + OpDecorate %output_addr Alignment 8 + OpDecorate %uint_1 SpecId 1 + %ulong = OpTypeInt 64 0 + %uint = OpTypeInt 32 0 + %uint_1 = OpSpecConstant %uint 1 + %v3ulong = OpTypeVector %ulong 3 +%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong + %void = OpTypeVoid +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %24 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint +%_ptr_Function__ptr_CrossWorkgroup_uint = OpTypePointer Function %_ptr_CrossWorkgroup_uint +%_ptr_Function_uint = OpTypePointer Function %uint +%__spirv_BuiltInGlobalInvocationId = OpVariable %_ptr_Input_v3ulong Input + %2 = OpFunction %void DontInline %24 + %output = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %entry = OpLabel +%output_addr = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uint Function + %id = OpVariable %_ptr_Function_uint Function + OpStore %output_addr %output Aligned 8 + %27 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32 + %call = OpCompositeExtract %ulong %27 0 + %conv = OpUConvert %uint %call + OpStore %id %conv Aligned 4 + %28 = OpLoad %_ptr_CrossWorkgroup_uint %output_addr Aligned 8 + %29 = OpLoad %uint %id Aligned 4 + %idxprom = OpUConvert %ulong %29 + %arrayidx = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uint %28 %idxprom + %30 = OpLoad %uint %arrayidx Aligned 4 + %31 = OpLoad %uint %id Aligned 4 + %add = OpIAdd %uint %31 %uint_1 + %mul = OpIMul %uint %30 %add + %32 = OpLoad %_ptr_CrossWorkgroup_uint %output_addr Aligned 8 + %33 = OpLoad %uint %id Aligned 4 + %idxprom1 = OpUConvert %ulong %33 + %arrayidx2 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uint %32 %idxprom1 + OpStore %arrayidx2 %mul Aligned 4 + OpReturn + OpFunctionEnd)"; + Shader shader = assemble(spirv_asm); + Shader spec_shader = specialize(shader, 1, 5); + + auto inout = ShaderArg({ 0x00000001, 0x10000001, 0x00020002, 0x04010203 }, + SHADER_ARG_INOUT); + const uint32_t expected[] = { + 0x00000005, 0x60000006, 0x000e000e, 0x20081018 + }; + CompileArgs args = { inout.size(), 1, 1 }; + run_shader(spec_shader, args, inout); + for (int i = 0; i < inout.size(); ++i) + EXPECT_EQ(inout[i], expected[i]); +} diff --git a/src/microsoft/clc/clc_helpers.cpp b/src/microsoft/clc/clc_helpers.cpp index 20bad26112b50e0c09ddf7e7b373b9e94c6a6779..8ea2ef10a869729d2be66c02f574d32a4d187541 100644 --- a/src/microsoft/clc/clc_helpers.cpp +++ b/src/microsoft/clc/clc_helpers.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include #include #include #include @@ -44,6 +46,7 @@ #include #include +#include #include "util/macros.h" #include "glsl_types.h" @@ -56,6 +59,8 @@ #include "opencl-c.h.h" #include "opencl-c-base.h.h" +constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_0; + using ::llvm::Function; using ::llvm::LLVMContext; using ::llvm::Module; @@ -291,6 +296,18 @@ public: assert(op->type == SPV_OPERAND_TYPE_DECORATION); decoration = ins->words[op->offset]; + if (decoration == SpvDecorationSpecId) { + uint32_t spec_id = ins->words[ins->operands[2].offset]; + for (auto &c : specConstants) { + if (c.second.id == spec_id) { + assert(c.first == id); + return; + } + } + specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id }); + return; + } + for (auto &kernel : kernels) { for (auto &arg : kernel.args) { if (arg.id == id) { @@ -412,6 +429,104 @@ public: } } + void parseLiteralType(const spv_parsed_instruction_t *ins) + { + uint32_t typeId = ins->words[ins->operands[0].offset]; + auto& literalType = literalTypes[typeId]; + switch (ins->opcode) { + case SpvOpTypeBool: + literalType = CLC_SPEC_CONSTANT_BOOL; + break; + case SpvOpTypeFloat: { + uint32_t sizeInBits = ins->words[ins->operands[1].offset]; + switch (sizeInBits) { + case 32: + literalType = CLC_SPEC_CONSTANT_FLOAT; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_DOUBLE; + break; + case 16: + /* Can't be used for a spec constant */ + break; + default: + unreachable("Unexpected float bit size"); + } + break; + } + case SpvOpTypeInt: { + uint32_t sizeInBits = ins->words[ins->operands[1].offset]; + bool isSigned = ins->words[ins->operands[2].offset]; + if (isSigned) { + switch (sizeInBits) { + case 8: + literalType = CLC_SPEC_CONSTANT_INT8; + break; + case 16: + literalType = CLC_SPEC_CONSTANT_INT16; + break; + case 32: + literalType = CLC_SPEC_CONSTANT_INT32; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_INT64; + break; + default: + unreachable("Unexpected int bit size"); + } + } else { + switch (sizeInBits) { + case 8: + literalType = CLC_SPEC_CONSTANT_UINT8; + break; + case 16: + literalType = CLC_SPEC_CONSTANT_UINT16; + break; + case 32: + literalType = CLC_SPEC_CONSTANT_UINT32; + break; + case 64: + literalType = CLC_SPEC_CONSTANT_UINT64; + break; + default: + unreachable("Unexpected uint bit size"); + } + } + break; + } + default: + unreachable("Unexpected type opcode"); + } + } + + void parseSpecConstant(const spv_parsed_instruction_t *ins) + { + uint32_t id = ins->result_id; + for (auto& c : specConstants) { + if (c.first == id) { + auto& data = c.second; + switch (ins->opcode) { + case SpvOpSpecConstant: { + uint32_t typeId = ins->words[ins->operands[0].offset]; + + // This better be an integer or float type + auto typeIter = literalTypes.find(typeId); + assert(typeIter != literalTypes.end()); + + data.type = typeIter->second; + break; + } + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + data.type = CLC_SPEC_CONSTANT_BOOL; + break; + default: + unreachable("Composites and Ops are not directly specializable."); + } + } + } + } + static spv_result_t parseInstruction(void *data, const spv_parsed_instruction_t *ins) { @@ -452,6 +567,16 @@ public: case SpvOpExecutionMode: parser->parseExecutionMode(ins); break; + case SpvOpTypeBool: + case SpvOpTypeInt: + case SpvOpTypeFloat: + parser->parseLiteralType(ins); + break; + case SpvOpSpecConstant: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + parser->parseSpecConstant(ins); + break; default: break; } @@ -474,7 +599,7 @@ public: return true; } - void parseBinary(const struct spirv_binary &spvbin) + bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger) { /* 3 passes should be enough to retrieve all kernel information: * 1st pass: all entry point name and number of args @@ -482,35 +607,53 @@ public: * 3rd pass: pointer type names */ for (unsigned pass = 0; pass < 3; pass++) { - spvBinaryParse(ctx, reinterpret_cast(this), - spvbin.data, spvbin.size / 4, - NULL, parseInstruction, NULL); + spv_diagnostic diagnostic = NULL; + auto result = spvBinaryParse(ctx, reinterpret_cast(this), + static_cast(spvbin.data), spvbin.size / 4, + NULL, parseInstruction, &diagnostic); + + if (result != SPV_SUCCESS) { + if (diagnostic && logger) + logger->error(logger->priv, diagnostic->error); + return false; + } if (parsingComplete()) - return; + return true; } assert(0); + return false; } std::vector kernels; + std::vector> specConstants; + std::map literalTypes; std::map> decorationGroups; SPIRVKernelInfo *curKernel; spv_context ctx; }; -const struct clc_kernel_info * -clc_spirv_get_kernels_info(const struct spirv_binary *spvbin, - unsigned *num_kernels) +bool +clc_spirv_get_kernels_info(const struct clc_binary *spvbin, + const struct clc_kernel_info **out_kernels, + unsigned *num_kernels, + const struct clc_parsed_spec_constant **out_spec_constants, + unsigned *num_spec_constants, + const struct clc_logger *logger) { struct clc_kernel_info *kernels; + struct clc_parsed_spec_constant *spec_constants; SPIRVKernelParser parser; - parser.parseBinary(*spvbin); + if (!parser.parseBinary(*spvbin, logger)) + return false; + *num_kernels = parser.kernels.size(); + *num_spec_constants = parser.specConstants.size(); if (!*num_kernels) - return NULL; + return false; kernels = reinterpret_cast(calloc(*num_kernels, sizeof(*kernels))); @@ -539,7 +682,20 @@ clc_spirv_get_kernels_info(const struct spirv_binary *spvbin, } } - return kernels; + if (*num_spec_constants) { + spec_constants = reinterpret_cast(calloc(*num_spec_constants, + sizeof(*spec_constants))); + assert(spec_constants); + + for (unsigned i = 0; i < parser.specConstants.size(); ++i) { + spec_constants[i] = parser.specConstants[i].second; + } + } + + *out_kernels = kernels; + *out_spec_constants = spec_constants; + + return true; } void @@ -562,10 +718,9 @@ clc_free_kernels_info(const struct clc_kernel_info *kernels, free((void *)kernels); } -int -clc_to_spirv(const struct clc_compile_args *args, - struct spirv_binary *spvbin, - const struct clc_logger *logger) +static std::pair, std::unique_ptr> +clc_compile_to_llvm_module(const struct clc_compile_args *args, + const struct clc_logger *logger) { LLVMInitializeAllTargets(); LLVMInitializeAllTargetInfos(); @@ -612,13 +767,13 @@ clc_to_spirv(const struct clc_compile_args *args, diag)) { log += "Couldn't create Clang invocation.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } if (diag.hasErrorOccurred()) { log += "Errors occurred during Clang invocation.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } // This is a workaround for a Clang bug which causes the number @@ -682,10 +837,19 @@ clc_to_spirv(const struct clc_compile_args *args, if (!c->ExecuteAction(act)) { log += "Error executing LLVM compilation action.\n"; clc_error(logger, log.c_str()); - return -1; + return {}; } - auto mod = act.takeModule(); + return { act.takeModule(), std::move(llvm_ctx) }; +} + +static int +llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod, + std::unique_ptr context, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + std::string log; std::ostringstream spv_stream; if (!::llvm::writeSpirv(mod.get(), spv_stream, log)) { log += "Translation from LLVM IR to SPIR-V failed.\n"; @@ -694,41 +858,61 @@ clc_to_spirv(const struct clc_compile_args *args, } const std::string spv_out = spv_stream.str(); - spvbin->size = spv_out.size(); - spvbin->data = static_cast(malloc(spvbin->size)); - memcpy(spvbin->data, spv_out.data(), spvbin->size); + out_spirv->size = spv_out.size(); + out_spirv->data = malloc(out_spirv->size); + memcpy(out_spirv->data, spv_out.data(), out_spirv->size); return 0; } -static const char * -spv_result_to_str(spv_result_t res) +int +clc_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir) { - switch (res) { - case SPV_SUCCESS: return "success"; - case SPV_UNSUPPORTED: return "unsupported"; - case SPV_END_OF_STREAM: return "end of stream"; - case SPV_WARNING: return "warning"; - case SPV_FAILED_MATCH: return "failed match"; - case SPV_REQUESTED_TERMINATION: return "requested termination"; - case SPV_ERROR_INTERNAL: return "internal error"; - case SPV_ERROR_OUT_OF_MEMORY: return "out of memory"; - case SPV_ERROR_INVALID_POINTER: return "invalid pointer"; - case SPV_ERROR_INVALID_BINARY: return "invalid binary"; - case SPV_ERROR_INVALID_TEXT: return "invalid text"; - case SPV_ERROR_INVALID_TABLE: return "invalid table"; - case SPV_ERROR_INVALID_VALUE: return "invalid value"; - case SPV_ERROR_INVALID_DIAGNOSTIC: return "invalid diagnostic"; - case SPV_ERROR_INVALID_LOOKUP: return "invalid lookup"; - case SPV_ERROR_INVALID_ID: return "invalid id"; - case SPV_ERROR_INVALID_CFG: return "invalid config"; - case SPV_ERROR_INVALID_LAYOUT: return "invalid layout"; - case SPV_ERROR_INVALID_CAPABILITY: return "invalid capability"; - case SPV_ERROR_INVALID_DATA: return "invalid data"; - case SPV_ERROR_MISSING_EXTENSION: return "missing extension"; - case SPV_ERROR_WRONG_VERSION: return "wrong version"; - default: return "unknown error"; - } + auto pair = clc_compile_to_llvm_module(args, logger); + if (!pair.first) + return -1; + + ::llvm::SmallVector buffer; + ::llvm::BitcodeWriter writer(buffer); + writer.writeModule(*pair.first); + + out_spir->size = buffer.size_in_bytes(); + out_spir->data = malloc(out_spir->size); + memcpy(out_spir->data, buffer.data(), out_spir->size); + + return 0; +} + +int +clc_c_to_spirv(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + auto pair = clc_compile_to_llvm_module(args, logger); + if (!pair.first) + return -1; + return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), logger, out_spirv); +} + +int +clc_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv) +{ + LLVMInitializeAllTargets(); + LLVMInitializeAllTargetInfos(); + LLVMInitializeAllTargetMCs(); + LLVMInitializeAllAsmPrinters(); + + std::unique_ptr llvm_ctx{ new LLVMContext }; + ::llvm::StringRef spir_ref(static_cast(in_spir->data), in_spir->size); + auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, ""), *llvm_ctx); + if (!mod) + return -1; + + return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), logger, out_spirv); } class SPIRVMessageConsumer { @@ -762,20 +946,19 @@ private: int clc_link_spirv_binaries(const struct clc_linker_args *args, - struct spirv_binary *dst_bin, - const struct clc_logger *logger) + const struct clc_logger *logger, + struct clc_binary *out_spirv) { std::vector> binaries; for (unsigned i = 0; i < args->num_in_objs; i++) { - std::vector bin(args->in_objs[i]->spvbin.data, - args->in_objs[i]->spvbin.data + - (args->in_objs[i]->spvbin.size / 4)); + const uint32_t *data = static_cast(args->in_objs[i]->data); + std::vector bin(data, data + (args->in_objs[i]->size / 4)); binaries.push_back(bin); } SPIRVMessageConsumer msgconsumer(logger); - spvtools::Context context(SPV_ENV_UNIVERSAL_1_0); + spvtools::Context context(spirv_target); context.SetMessageConsumer(msgconsumer); spvtools::LinkerOptions options; options.SetAllowPartialLinkage(args->create_library); @@ -786,18 +969,80 @@ clc_link_spirv_binaries(const struct clc_linker_args *args, return -1; } - dst_bin->size = linkingResult.size() * 4; - dst_bin->data = static_cast(malloc(dst_bin->size)); - memcpy(dst_bin->data, linkingResult.data(), dst_bin->size); + out_spirv->size = linkingResult.size() * 4; + out_spirv->data = static_cast(malloc(out_spirv->size)); + memcpy(out_spirv->data, linkingResult.data(), out_spirv->size); return 0; } +int +clc_spirv_specialize(const struct clc_binary *in_spirv, + const struct clc_parsed_spirv *parsed_data, + const struct clc_spirv_specialization_consts *consts, + struct clc_binary *out_spirv) +{ + std::unordered_map> spec_const_map; + for (unsigned i = 0; i < consts->num_specializations; ++i) { + unsigned id = consts->specializations[i].id; + auto parsed_spec_const = std::find_if(parsed_data->spec_constants, + parsed_data->spec_constants + parsed_data->num_spec_constants, + [id](const clc_parsed_spec_constant &c) { return c.id == id; }); + assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants); + + std::vector words; + switch (parsed_spec_const->type) { + case CLC_SPEC_CONSTANT_BOOL: + words.push_back(consts->specializations[i].value.b); + break; + case CLC_SPEC_CONSTANT_INT32: + case CLC_SPEC_CONSTANT_UINT32: + case CLC_SPEC_CONSTANT_FLOAT: + words.push_back(consts->specializations[i].value.u32); + break; + case CLC_SPEC_CONSTANT_INT16: + words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16); + break; + case CLC_SPEC_CONSTANT_INT8: + words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8); + break; + case CLC_SPEC_CONSTANT_UINT16: + words.push_back((uint32_t)consts->specializations[i].value.u16); + break; + case CLC_SPEC_CONSTANT_UINT8: + words.push_back((uint32_t)consts->specializations[i].value.u8); + break; + case CLC_SPEC_CONSTANT_DOUBLE: + case CLC_SPEC_CONSTANT_INT64: + case CLC_SPEC_CONSTANT_UINT64: + words.resize(2); + memcpy(words.data(), &consts->specializations[i].value.u64, 8); + break; + } + + ASSERTED auto ret = spec_const_map.emplace(id, std::move(words)); + assert(ret.second); + } + + spvtools::Optimizer opt(spirv_target); + opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map))); + + std::vector result; + if (!opt.Run(static_cast(in_spirv->data), in_spirv->size / 4, &result)) + return false; + + out_spirv->size = result.size() * 4; + out_spirv->data = malloc(out_spirv->size); + memcpy(out_spirv->data, result.data(), out_spirv->size); + return true; +} + void -clc_dump_spirv(const struct spirv_binary *spvbin, FILE *f) +clc_dump_spirv(const struct clc_binary *spvbin, FILE *f) { - spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); - std::vector bin(spvbin->data, spvbin->data + (spvbin->size / 4)); + spvtools::SpirvTools tools(spirv_target); + const uint32_t *data = static_cast(spvbin->data); + std::vector bin(data, data + (spvbin->size / 4)); std::string out; tools.Disassemble(bin, &out, SPV_BINARY_TO_TEXT_OPTION_INDENT | @@ -806,7 +1051,13 @@ clc_dump_spirv(const struct spirv_binary *spvbin, FILE *f) } void -clc_free_spirv_binary(struct spirv_binary *spvbin) +clc_free_spir_binary(struct clc_binary *spir) +{ + free(spir->data); +} + +void +clc_free_spirv_binary(struct clc_binary *spvbin) { free(spvbin->data); } diff --git a/src/microsoft/clc/clc_helpers.h b/src/microsoft/clc/clc_helpers.h index 653e99a8a27adae174c7e42c6e9bd94285675cef..6edf261c3bc06b41865debb711379a7014a8be6a 100644 --- a/src/microsoft/clc/clc_helpers.h +++ b/src/microsoft/clc/clc_helpers.h @@ -38,29 +38,52 @@ extern "C" { #include #include -const struct clc_kernel_info * -clc_spirv_get_kernels_info(const struct spirv_binary *spvbin, - unsigned *num_kernels); +bool +clc_spirv_get_kernels_info(const struct clc_binary *spvbin, + const struct clc_kernel_info **kernels, + unsigned *num_kernels, + const struct clc_parsed_spec_constant **spec_constants, + unsigned *num_spec_constants, + const struct clc_logger *logger); void clc_free_kernels_info(const struct clc_kernel_info *kernels, unsigned num_kernels); int -clc_to_spirv(const struct clc_compile_args *args, - struct spirv_binary *spvbin, - const struct clc_logger *logger); +clc_c_to_spir(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spir); + +int +clc_spir_to_spirv(const struct clc_binary *in_spir, + const struct clc_logger *logger, + struct clc_binary *out_spirv); + +int +clc_c_to_spirv(const struct clc_compile_args *args, + const struct clc_logger *logger, + struct clc_binary *out_spirv); int clc_link_spirv_binaries(const struct clc_linker_args *args, - struct spirv_binary *dst_bin, - const struct clc_logger *logger); + const struct clc_logger *logger, + struct clc_binary *out_spirv); + +int +clc_spirv_specialize(const struct clc_binary *in_spirv, + const struct clc_parsed_spirv *parsed_data, + const struct clc_spirv_specialization_consts *consts, + struct clc_binary *out_spirv); + +void +clc_dump_spirv(const struct clc_binary *spvbin, FILE *f); void -clc_dump_spirv(const struct spirv_binary *spvbin, FILE *f); +clc_free_spir_binary(struct clc_binary *spir); void -clc_free_spirv_binary(struct spirv_binary *spvbin); +clc_free_spirv_binary(struct clc_binary *spvbin); #define clc_log(logger, level, fmt, ...) do { \ if (!logger || !logger->level) break; \ diff --git a/src/microsoft/clc/clglon12compiler.def b/src/microsoft/clc/clglon12compiler.def deleted file mode 100644 index 924f7aa67232233ccae30af3b1d71ee122e21c47..0000000000000000000000000000000000000000 --- a/src/microsoft/clc/clglon12compiler.def +++ /dev/null @@ -1,12 +0,0 @@ -EXPORTS - clc_context_new - clc_free_context - clc_context_serialize - clc_context_free_serialized - clc_context_deserialize - clc_compile - clc_link - clc_free_object - clc_to_dxil - clc_free_dxil_object - clc_compiler_get_version diff --git a/src/microsoft/clc/clon12compiler.def b/src/microsoft/clc/clon12compiler.def new file mode 100644 index 0000000000000000000000000000000000000000..bb93d4b8c9e7b7af021266a36bd6083773a15933 --- /dev/null +++ b/src/microsoft/clc/clon12compiler.def @@ -0,0 +1,18 @@ +EXPORTS + clc_libclc_new + clc_free_libclc + clc_libclc_serialize + clc_libclc_free_serialized + clc_libclc_deserialize + clc_compile_c_to_spir + clc_free_spir + clc_compile_spir_to_spirv + clc_free_spirv + clc_compile_c_to_spirv + clc_link_spirv + clc_parse_spirv + clc_free_parsed_spirv + clc_specialize_spirv + clc_spirv_to_dxil + clc_free_dxil_object + clc_compiler_get_version diff --git a/src/microsoft/clc/compute_test.cpp b/src/microsoft/clc/compute_test.cpp index d19ff21095bc40229993bb72e9c840804e029541..e5faa4f5dde8e3b279ce3af197566ef27a77aa94 100644 --- a/src/microsoft/clc/compute_test.cpp +++ b/src/microsoft/clc/compute_test.cpp @@ -35,6 +35,8 @@ #include "compute_test.h" #include "dxcapi.h" +#include + using std::runtime_error; using Microsoft::WRL::ComPtr; @@ -48,8 +50,8 @@ enum compute_test_debug_flags { static const struct debug_named_value compute_debug_options[] = { { "experimental_shaders", COMPUTE_DEBUG_EXPERIMENTAL_SHADERS, "Enable experimental shaders" }, { "use_hw_d3d", COMPUTE_DEBUG_USE_HW_D3D, "Use a hardware D3D device" }, - { "optimize_libclc", COMPUTE_DEBUG_OPTIMIZE_LIBCLC, "Optimize the clc_context before using it" }, - { "serialize_libclc", COMPUTE_DEBUG_SERIALIZE_LIBCLC, "Serialize and deserialize the clc_context" }, + { "optimize_libclc", COMPUTE_DEBUG_OPTIMIZE_LIBCLC, "Optimize the clc_libclc before using it" }, + { "serialize_libclc", COMPUTE_DEBUG_SERIALIZE_LIBCLC, "Serialize and deserialize the clc_libclc" }, DEBUG_NAMED_VALUE_END }; @@ -617,31 +619,31 @@ ComputeTest::run_shader_with_raw_args(Shader shader, void ComputeTest::SetUp() { - static struct clc_context *compiler_ctx_g = nullptr; + static struct clc_libclc *compiler_ctx_g = nullptr; if (!compiler_ctx_g) { - clc_context_options options = { }; + clc_libclc_options options = { }; options.optimize = (debug_get_option_debug_compute() & COMPUTE_DEBUG_OPTIMIZE_LIBCLC) != 0; - compiler_ctx_g = clc_context_new(&logger, &options); + compiler_ctx_g = clc_libclc_new(&logger, &options); if (!compiler_ctx_g) throw runtime_error("failed to create CLC compiler context"); if (debug_get_option_debug_compute() & COMPUTE_DEBUG_SERIALIZE_LIBCLC) { void *serialized = nullptr; size_t serialized_size = 0; - clc_context_serialize(compiler_ctx_g, &serialized, &serialized_size); + clc_libclc_serialize(compiler_ctx_g, &serialized, &serialized_size); if (!serialized) throw runtime_error("failed to serialize CLC compiler context"); - clc_free_context(compiler_ctx_g); + clc_free_libclc(compiler_ctx_g); compiler_ctx_g = nullptr; - compiler_ctx_g = clc_context_deserialize(serialized, serialized_size); + compiler_ctx_g = clc_libclc_deserialize(serialized, serialized_size); if (!compiler_ctx_g) throw runtime_error("failed to deserialize CLC compiler context"); - clc_context_free_serialized(serialized); + clc_libclc_free_serialized(serialized); } } compiler_ctx = compiler_ctx_g; @@ -803,12 +805,16 @@ ComputeTest::compile(const std::vector &sources, for (unsigned i = 0; i < sources.size(); i++) { args.source.value = sources[i]; - auto obj = clc_compile(compiler_ctx, &args, &logger); - if (!obj) + clc_binary spirv{}; + if (!clc_compile_c_to_spirv(&args, &logger, &spirv)) throw runtime_error("failed to compile object!"); Shader shader; - shader.obj = std::shared_ptr(obj, clc_free_object); + shader.obj = std::shared_ptr(new clc_binary(spirv), [](clc_binary *spirv) + { + clc_free_spirv(spirv); + delete spirv; + }); shaders.push_back(shader); } @@ -822,7 +828,7 @@ ComputeTest::Shader ComputeTest::link(const std::vector &sources, bool create_library) { - std::vector objs; + std::vector objs; for (auto& source : sources) objs.push_back(&*source.obj); @@ -830,31 +836,66 @@ ComputeTest::link(const std::vector &sources, link_args.in_objs = objs.data(); link_args.num_in_objs = (unsigned)objs.size(); link_args.create_library = create_library; - struct clc_object *obj = clc_link(compiler_ctx, - &link_args, - &logger); - if (!obj) + clc_binary spirv{}; + if (!clc_link_spirv(&link_args, &logger, &spirv)) throw runtime_error("failed to link objects!"); ComputeTest::Shader shader; - shader.obj = std::shared_ptr(obj, clc_free_object); + shader.obj = std::shared_ptr(new clc_binary(spirv), [](clc_binary *spirv) + { + clc_free_spirv(spirv); + delete spirv; + }); if (!link_args.create_library) configure(shader, NULL); return shader; } +ComputeTest::Shader +ComputeTest::assemble(const char *source) +{ + spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0); + std::vector binary; + if (!tools.Assemble(source, strlen(source), &binary)) + throw runtime_error("failed to assemble"); + + ComputeTest::Shader shader; + shader.obj = std::shared_ptr(new clc_binary{}, [](clc_binary *spirv) + { + free(spirv->data); + delete spirv; + }); + shader.obj->size = binary.size() * 4; + shader.obj->data = malloc(shader.obj->size); + memcpy(shader.obj->data, binary.data(), shader.obj->size); + + configure(shader, NULL); + + return shader; +} + void ComputeTest::configure(Shader &shader, const struct clc_runtime_kernel_conf *conf) { - struct clc_dxil_object *dxil; + if (!shader.metadata) { + shader.metadata = std::shared_ptr(new clc_parsed_spirv{}, [](clc_parsed_spirv *metadata) + { + clc_free_parsed_spirv(metadata); + delete metadata; + }); + if (!clc_parse_spirv(shader.obj.get(), NULL, shader.metadata.get())) + throw runtime_error("failed to parse spirv!"); + } - dxil = clc_to_dxil(compiler_ctx, shader.obj.get(), "main_test", conf, &logger); - if (!dxil) + shader.dxil = std::shared_ptr(new clc_dxil_object{}, [](clc_dxil_object *dxil) + { + clc_free_dxil_object(dxil); + delete dxil; + }); + if (!clc_spirv_to_dxil(compiler_ctx, shader.obj.get(), shader.metadata.get(), "main_test", conf, nullptr, &logger, shader.dxil.get())) throw runtime_error("failed to compile kernel!"); - - shader.dxil = std::shared_ptr(dxil, clc_free_dxil_object); } void diff --git a/src/microsoft/clc/compute_test.h b/src/microsoft/clc/compute_test.h index 11e7d1cc4d738489c8badd1f3f2d053b6a5a1981..9ac741fc16fd779adad05295000797761bafc40f 100644 --- a/src/microsoft/clc/compute_test.h +++ b/src/microsoft/clc/compute_test.h @@ -52,7 +52,8 @@ align(size_t value, unsigned alignment) class ComputeTest : public ::testing::Test { protected: struct Shader { - std::shared_ptr obj; + std::shared_ptr obj; + std::shared_ptr metadata; std::shared_ptr dxil; }; @@ -166,6 +167,9 @@ protected: link(const std::vector &sources, bool create_library = false); + Shader + assemble(const char *source); + void configure(Shader &shader, const struct clc_runtime_kernel_conf *conf); @@ -173,6 +177,33 @@ protected: void validate(Shader &shader); + template + Shader + specialize(Shader &shader, uint32_t id, T const& val) + { + Shader new_shader; + new_shader.obj = std::shared_ptr(new clc_binary{}, [](clc_binary *spirv) + { + clc_free_spirv(spirv); + delete spirv; + }); + if (!shader.metadata) + configure(shader, NULL); + + clc_spirv_specialization spec; + spec.id = id; + memcpy(&spec.value, &val, sizeof(val)); + clc_spirv_specialization_consts consts; + consts.specializations = &spec; + consts.num_specializations = 1; + if (!clc_specialize_spirv(shader.obj.get(), shader.metadata.get(), &consts, new_shader.obj.get())) + throw runtime_error("failed to specialize"); + + configure(new_shader, NULL); + + return new_shader; + } + enum ShaderArgDirection { SHADER_ARG_INPUT = 1, SHADER_ARG_OUTPUT = 2, @@ -314,7 +345,7 @@ protected: ID3D12GraphicsCommandList *cmdlist; ID3D12DescriptorHeap *uav_heap; - struct clc_context *compiler_ctx; + struct clc_libclc *compiler_ctx; UINT uav_heap_incr; int fence_value; diff --git a/src/microsoft/clc/meson.build b/src/microsoft/clc/meson.build index 6cf54f1cce0fd33ec097105ee43fb8ed738aa7f8..05b48b6a813dc0302c62a4d67fa1a2b6427256b3 100644 --- a/src/microsoft/clc/meson.build +++ b/src/microsoft/clc/meson.build @@ -49,11 +49,11 @@ files_libclc_compiler = files( ) libclc_compiler = shared_library( - 'clglon12compiler', + 'clon12compiler', [files_libclc_compiler, sha1_h], opencl_c_h, opencl_c_base_h, - vs_module_defs : 'clglon12compiler.def', + vs_module_defs : 'clon12compiler.def', include_directories : [inc_include, inc_src, inc_mapi, inc_mesa, inc_compiler, inc_gallium, inc_spirv], dependencies: [idep_nir_headers, dep_clang, dep_llvm, cc.find_library('version'), dep_llvmspirvlib, idep_mesautil, idep_libdxil_compiler, idep_nir, dep_spirv_tools] @@ -63,8 +63,8 @@ if dep_dxheaders.found() clc_compiler_test = executable('clc_compiler_test', ['clc_compiler_test.cpp', 'compute_test.cpp'], link_with : [libclc_compiler], - dependencies : [idep_gtest, idep_mesautil, idep_libdxil_compiler, dep_dxheaders], - include_directories : [inc_include, inc_src]) + dependencies : [idep_gtest, idep_mesautil, idep_libdxil_compiler, dep_dxheaders, dep_spirv_tools], + include_directories : [inc_include, inc_src, inc_spirv]) test('clc_compiler_test', clc_compiler_test, timeout: 180) endif