Metal emulate geometry shaders using compute shaders

Here's a very speculative possibility depending on exactly what your geometry shader needs to do.

I'm thinking you can do it sort of "backwards" with just a vertex shader and no separate compute shader, at the cost of redundant work on the GPU. You would issue a draw as if you had a buffer containing every vertex of every primitive the geometry shader would output. You would not actually have that buffer, though. Instead, you would write a vertex shader that calculates those vertices on the fly.

So, in the app code, calculate the number of output primitives and therefore the number of output vertices that would be produced for a given count of input primitives. Do a draw of the output primitive type with that many vertices.
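Here is a minimal sketch of that arithmetic, written as plain C++ host code (Metal Shading Language is itself C++-based). The function names are hypothetical, and it assumes triangle input without primitive restart and a fixed maximum number of output primitives per input primitive:

```cpp
#include <cassert>
#include <cstdint>

// Number of input triangles described by an index buffer.
// A triangle list uses 3 indices per triangle; a strip (no primitive restart) yields n - 2.
std::uint32_t input_triangle_count(std::uint32_t index_count, bool is_strip) {
    if (is_strip) {
        return index_count >= 3 ? index_count - 2 : 0;
    }
    return index_count / 3;
}

// Vertex count for the emulating draw: one vertex per corner of every
// output primitive the geometry shader could emit.
std::uint32_t emulated_vertex_count(std::uint32_t input_primitives,
                                    std::uint32_t outputs_per_input,    // max primitives emitted per input
                                    std::uint32_t vertices_per_output)  // 3 for triangles, 2 for lines
{
    assert(vertices_per_output > 0);
    return input_primitives * outputs_per_input * vertices_per_output;
}
```

For example, a strip of 6 indices contains 4 triangles; if each produces one output triangle, the draw needs 12 vertices. That number would be the vertexCount of a non-indexed draw such as MTLRenderCommandEncoder's drawPrimitives:vertexStart:vertexCount:.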

You would not provide a buffer with the output vertex data as input to this draw.

You would provide the original index buffer and original vertex buffer as inputs to the vertex shader for that draw. The shader would calculate from the vertex ID which output primitive it's for, and which vertex of that primitive (e.g. for a triangle, vid / 3 and vid % 3, respectively). From the output primitive ID, it would calculate which input primitive would have generated it in the original geometry shader.
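The vertex-ID decoding can be sketched the same way. This assumes (an assumption of the sketch, not something Metal imposes) that the output primitives of each input primitive are numbered contiguously, so input primitive i owns output primitives i*K through i*K + K - 1, where K is the maximum number of outputs per input. The struct and function names are hypothetical:

```cpp
#include <cassert>
#include <cstdint>

// Where a single vertex-shader invocation sits in the emulated output stream.
struct EmulatedVertex {
    std::uint32_t output_primitive;  // vid / vertices_per_output
    std::uint32_t corner;            // vid % vertices_per_output
    std::uint32_t input_primitive;   // input primitive that would have emitted it
    std::uint32_t output_slot;       // which of that input's outputs, 0 .. K - 1
};

EmulatedVertex decode_vertex_id(std::uint32_t vid,
                                std::uint32_t vertices_per_output,
                                std::uint32_t outputs_per_input) {
    assert(vertices_per_output > 0 && outputs_per_input > 0);
    EmulatedVertex v;
    v.output_primitive = vid / vertices_per_output;
    v.corner = vid % vertices_per_output;
    // Contiguous layout: K consecutive output primitives per input primitive.
    v.input_primitive = v.output_primitive / outputs_per_input;
    v.output_slot = v.output_primitive % outputs_per_input;
    return v;
}
```

With triangles (3 vertices per output) and K = 2, vertex ID 11 is corner 2 of output primitive 3, which is the second output (slot 1) of input primitive 1.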

The shader would look up the indices for that input primitive in the index buffer and then fetch their vertex data from the vertex buffer. (How the indices are located depends on the input topology, for example a triangle list versus a triangle strip.) It would apply any vertex shading that originally ran before the geometry shader to that data. Then it would do the part of the geometry computation that produces the identified vertex of the identified output primitive. Once it has calculated the output vertex data, it can apply any further vertex processing you would have done after the geometry shader. The result is what the vertex shader returns.
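Here is a sketch of the topology-dependent index lookup for triangle lists and strips (no primitive restart). The optional winding fix for odd strip triangles is one common convention, included as an assumption; whether you need it depends on whether your geometry computation or culling cares about winding. The function name is hypothetical:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Fetch the three vertex indices of input triangle `tri` from an index buffer.
std::array<std::uint32_t, 3> triangle_indices(const std::vector<std::uint32_t>& indices,
                                              std::uint32_t tri, bool is_strip,
                                              bool keep_strip_winding) {
    if (!is_strip) {
        // Triangle list: each triangle owns three consecutive indices.
        assert(3 * tri + 2 < indices.size());
        return {indices[3 * tri], indices[3 * tri + 1], indices[3 * tri + 2]};
    }
    // Triangle strip: triangle `tri` reuses the last two indices of the previous one.
    assert(tri + 2 < indices.size());
    std::array<std::uint32_t, 3> t = {indices[tri], indices[tri + 1], indices[tri + 2]};
    if (keep_strip_winding && tri % 2 == 1) {
        // Odd strip triangles come out with reversed winding; swapping two corners restores it.
        std::swap(t[1], t[2]);
    }
    return t;
}
```

The example shader later in this answer skips the winding fix, which is harmless there because it only uses the absolute value of the cross product.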

If the geometry shader can produce a variable number of output primitives per input primitive, well, at least you have a maximum number. So, you can draw enough vertices for the maximum potential count of output primitives. The vertex shader can do the computations necessary to figure out whether the geometry shader would, in fact, have produced that primitive. If not, the vertex shader can arrange for the whole primitive to be clipped away, either by positioning all of its vertices outside of the view frustum or by writing a negative value to an output member marked [[clip_distance]].
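Both ways of discarding an unused primitive can be sketched in plain C++. Float4, OutVertex, slot_was_emitted and finalize_vertex are hypothetical stand-ins for the shader's types and logic, and how the real shader decides whether a slot was emitted depends entirely on your geometry shader. Metal's clip volume requires -w <= x <= w, so x = 2 with w = 1 lies outside it:

```cpp
#include <cassert>
#include <cstdint>

struct Float4 { float x, y, z, w; };

// Stand-in for the shader's output vertex: clip-space position plus one clip distance.
struct OutVertex {
    Float4 position;
    float clip_distance;  // the primitive is clipped where the interpolated value is negative
};

// Hypothetical predicate: the geometry shader emitted `emitted_count` primitives
// for this input, so output slot `slot` exists only if slot < emitted_count.
bool slot_was_emitted(std::uint32_t slot, std::uint32_t emitted_count) {
    return slot < emitted_count;
}

OutVertex finalize_vertex(OutVertex v, bool emitted) {
    if (emitted) {
        v.clip_distance = 1.0f;  // non-negative: keep the primitive
        return v;
    }
    // Option 1: move the vertex outside the view volume (x > w). All corners land on
    // the same point, so the unused triangle is also degenerate.
    v.position = Float4{2.0f, 0.0f, 0.0f, 1.0f};
    // Option 2: a negative clip distance on every corner clips the whole primitive.
    v.clip_distance = -1.0f;
    return v;
}
```

In a real shader you would typically pick one option; the sketch sets both only to show them side by side. Either way, every vertex of the unused primitive must take the same path, so the whole primitive is discarded rather than distorted.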

This avoids ever storing the generated primitives in a buffer. However, it causes the GPU to repeat some of the pre-geometry-shader vertex shading and geometry shader calculations, because every output vertex redoes the work for its whole input primitive. It will be parallelized, of course, but may still be slower than what you're doing now. Also, it may defeat optimizations around fetching indices and vertex data that GPUs can apply to more conventional vertex shaders, such as reusing already-shaded vertices in indexed draws.
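To put a rough number on the redundancy, here is a hypothetical cost model for a triangle-to-triangle shader written like the example below, where each invocation reads and processes all three vertices of its input triangle. It counts work per input triangle only and ignores caching, so treat it as an illustration rather than a performance prediction:

```cpp
#include <cassert>
#include <cstdint>

// Work per input triangle when each output vertex recomputes everything itself.
struct Cost {
    std::uint32_t invocations;     // vertex-shader invocations
    std::uint32_t vertex_reads;    // input-vertex fetches (and pre-geometry-shader shading passes)
    std::uint32_t per_prim_evals;  // evaluations of the per-primitive logic (e.g. the cross product)
};

Cost emulated_cost(std::uint32_t outputs_per_input) {
    const std::uint32_t invocations = outputs_per_input * 3;  // 3 corners per output triangle
    // Each invocation fetches all 3 input vertices and evaluates the per-primitive logic once.
    return {invocations, invocations * 3, invocations};
}
```

With one output triangle per input, that is 9 vertex reads and 3 evaluations of the per-primitive logic, versus 3 reads and 1 evaluation if each input triangle were handled once.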


Here's an example conversion of your geometry shader:

#include <metal_stdlib>
using namespace metal;

struct VertexIn {
    // float3 occupies 16 bytes in MSL; use packed_float3 instead if your vertex buffer stores tightly packed 12-byte values
    // can't use [[attribute(n)]] for these because Metal isn't doing the vertex lookup for us
    float3 position;
    float3 normal;
    float2 uv;
};

struct VertexOut {
    float3 position;
    float3 normal;
    float2 uv;
    float4 new_position [[position]];
};


vertex VertexOut foo(uint vid [[vertex_id]],
                     device const uint *indexes [[buffer(0)]],
                     device const VertexIn *vertexes [[buffer(1)]])
{
    VertexOut out;

    const uint triangle_id = vid / 3;
    const uint vertex_of_triangle = vid % 3;

    // indexes describes a triangle strip even though this shader is invoked for a triangle list,
    // so input triangle n uses indexes[n], indexes[n + 1] and indexes[n + 2].
    const uint index[3] = { indexes[triangle_id], indexes[triangle_id + 1], indexes[triangle_id + 2] };
    const VertexIn v[3] = { vertexes[index[0]], vertexes[index[1]], vertexes[index[2]] };

    // Unnormalized face normal; its largest absolute component selects the projection axis.
    float3 p = abs(cross(v[1].position - v[0].position, v[2].position - v[0].position));

    out.position = v[vertex_of_triangle].position;
    out.normal = v[vertex_of_triangle].normal;
    out.uv = v[vertex_of_triangle].uv;

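    // Project onto the axis-aligned plane most nearly parallel to the triangle
    // (drop the coordinate along the dominant normal axis).
    if (p.z > p.x && p.z > p.y)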
    {
        out.new_position = float4(out.position.x, out.position.y, 0, 1);
    }
    else if (p.x > p.y && p.x > p.z)
    {
        out.new_position = float4(out.position.y, out.position.z, 0, 1);
    }
    else
    {
        out.new_position = float4(out.position.x, out.position.z, 0, 1);
    }

    return out;
}
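To check the index math and the axis selection before running anything on the GPU, the position logic of the shader above can be ported to plain C++. This is a hypothetical test harness, not part of the Metal code: Vec3, Vec2 and projected_position are illustrative names, and the index buffer is treated as a triangle strip exactly as in the shader:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

struct Vec3 { float x, y, z; };
struct Vec2 { float x, y; };

static Vec3 sub(Vec3 a, Vec3 b) { return {a.x - b.x, a.y - b.y, a.z - b.z}; }

// Component-wise absolute value of the cross product a x b.
static Vec3 abs_cross(Vec3 a, Vec3 b) {
    return {std::fabs(a.y * b.z - a.z * b.y),
            std::fabs(a.z * b.x - a.x * b.z),
            std::fabs(a.x * b.y - a.y * b.x)};
}

// CPU port of the shader's position logic for one vertex ID (positions only).
Vec2 projected_position(std::uint32_t vid,
                        const std::vector<std::uint32_t>& indexes,
                        const std::vector<Vec3>& positions) {
    const std::uint32_t triangle_id = vid / 3;
    const std::uint32_t vertex_of_triangle = vid % 3;
    assert(triangle_id + 2 < indexes.size());
    // Strip lookup, as in the shader.
    const Vec3 v[3] = {positions[indexes[triangle_id]],
                       positions[indexes[triangle_id + 1]],
                       positions[indexes[triangle_id + 2]]};
    const Vec3 p = abs_cross(sub(v[1], v[0]), sub(v[2], v[0]));
    const Vec3 q = v[vertex_of_triangle];
    if (p.z > p.x && p.z > p.y) return {q.x, q.y};  // normal mostly along z: drop z
    if (p.x > p.y && p.x > p.z) return {q.y, q.z};  // normal mostly along x: drop x
    return {q.x, q.z};                              // otherwise drop y
}
```

For a strip of n indices, the matching draw would use 3 * (n - 2) vertices, one per corner of each strip triangle.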