diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index e3e4e61802c..59ec75da885 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -508,8 +508,25 @@ handle_special(struct vtn_builder *b, uint32_t opcode, return nir_cross3(nb, srcs[0], srcs[1]); case OpenCLstd_Fdim: return nir_fdim(nb, srcs[0], srcs[1]); - case OpenCLstd_Mad: - return nir_fmad(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_Mad: { + /* The spec says mad is + * + * Implemented either as a correctly rounded fma or as a multiply + * followed by an add both of which are correctly rounded + * + * So lower to fmul+fadd if we have to, but fuse to an ffma if the backend + * supports that. This can be significantly faster. + */ + bool lower = + ((nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16) || + (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) || + (nb->shader->options->lower_ffma64 && srcs[0]->bit_size == 64)); + + if (lower) + return nir_fmad(nb, srcs[0], srcs[1], srcs[2]); + else + return nir_ffma(nb, srcs[0], srcs[1], srcs[2]); + } case OpenCLstd_Maxmag: return nir_maxmag(nb, srcs[0], srcs[1]); case OpenCLstd_Minmag: