llama.cpp/ggml/src/vulkan-shaders/upscale.comp
0cc4m a3738b2fa7 vulkan : implement Stable Diffusion operators (ggml/904)
* Fix Vulkan repeat op

* Implement Vulkan concat op

* Delete old Vulkan shader generator

* Implement Vulkan im2col op

* Implement Vulkan unary gelu_quick op

* Implement Vulkan group_norm op

* Implement Vulkan timestep_embedding op

* Implement Vulkan upscale op

* Fix Vulkan vk_context tensor extra index issue

* Fix Vulkan matmul shader parameter bug

* Properly fix Vulkan matmul shader parameter bug

* Add Vulkan ADD f16 + f32 -> f16 operator support

* Implement Vulkan tanh op

* Fix Vulkan group count too large Validation error on non-Nvidia GPUs

* Throw error when too much memory is requested

* Fix another Vulkan group count too large Validation error on non-Nvidia GPUs

* Fix matmul MMQ condition

* Implement Vulkan pad op

* Fix Vulkan crash when tensor is used multiple times in a compute graph

* Add Vulkan CONCAT f16 + f16 -> f16 op

* Add Vulkan LEAKY_RELU op
2024-08-05 08:50:57 +03:00

36 lines
1.1 KiB
Text

#version 450
layout (push_constant) uniform parameter
{
uint ne; uint d_offset;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
} p;
#include "types.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint i10 = idx % p.ne10;
const uint i11 = (idx / p.ne10) % p.ne11;
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
const uint i00 = uint(i10 / p.sf0);
const uint i01 = uint(i11 / p.sf1);
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
}