diff --git a/src/webgl/p5.Shader.js b/src/webgl/p5.Shader.js index a82f112361..2628e73f2b 100644 --- a/src/webgl/p5.Shader.js +++ b/src/webgl/p5.Shader.js @@ -431,6 +431,39 @@ p5.Shader = class { */ modify(hooks) { p5._validateParameters('p5.Shader.modify', arguments); + + // Internal helper to normalize shader hooks + // Automatically appends a return statement when: + // - The hook's return type matches its first parameter type (e.g. Inputs getX(Inputs x)) + // - No explicit 'return' is present in the user-provided code + const normalizeReturnIfMissing = (hookDef, impl) => { + // Example: hookDef = "Inputs getPixelInputs" + // Example: impl = "(Inputs inputs) { ... }" + const defMatch = /^(\w+)\s+(\w+)$/.exec(hookDef.trim()); + if (!defMatch) return impl; + const [, returnType] = defMatch; + + // Skip void-return hooks + if (returnType === 'void') return impl; + + // Strip // and /* */ comments before searching for 'return' + const withoutComments = impl + .replace(/\/\/.*$/gm, '') + .replace(/\/\*[\s\S]*?\*\//g, ''); + if (/\breturn\b/.test(withoutComments)) return impl; + + // Extract the first parameter type and name + const sigMatch = /^\s*\(\s*([\w\s[\]]+)\s+(\w+)\s*\)/.exec(impl); + if (!sigMatch) return impl; + const [, paramType, paramName] = sigMatch; + + // Only normalize when return type matches first param type + if (paramType.trim() !== returnType.trim()) return impl; + + // Append 'return ;' before the last closing brace + return impl.replace(/\}\s*$/, ` return ${paramName};\n}`); + }; + const newHooks = { vertex: {}, fragment: {}, @@ -446,9 +479,9 @@ p5.Shader = class { newHooks.fragment.declarations = (newHooks.fragment.declarations || '') + '\n' + hooks[key]; } else if (this.hooks.vertex[key]) { - newHooks.vertex[key] = hooks[key]; + newHooks.vertex[key] = normalizeReturnIfMissing(key, hooks[key]); } else if (this.hooks.fragment[key]) { - newHooks.fragment[key] = hooks[key]; + newHooks.fragment[key] = normalizeReturnIfMissing(key, hooks[key]); } else { newHooks.helpers[key] = hooks[key]; } diff --git a/test/unit/webgl/p5.Shader.js b/test/unit/webgl/p5.Shader.js index 044bf6ec0a..c32297c36f 100644 --- a/test/unit/webgl/p5.Shader.js +++ b/test/unit/webgl/p5.Shader.js @@ -385,6 +385,95 @@ suite('p5.Shader', function() { }); expect(modified.fragSrc()).to.match(/#define AUGMENTED_HOOK_getVertexColor/); }); + + test('auto-returns when param/return types match (no explicit return)', function() { + const modified = myShader.modify({ + 'vec4 getVertexColor': `(vec4 c) { + c.rgb = vec3(1.0, 0.0, 0.0); + }` + }); + expect(modified.fragSrc()).to.match(/#define AUGMENTED_HOOK_getVertexColor/); + expect(modified.fragSrc()).to.match(/getVertexColor[\s\S]*?\{[\s\S]*?return\s+c\s*;[\s\S]*?\}/); + }); + + test('explicit return is preserved and not duplicated', function() { + const modified = myShader.modify({ + 'vec4 getVertexColor': `(vec4 c) { + c.rgb *= 0.5; + return c; + }` + }); + expect(modified.fragSrc()).to.match(/#define AUGMENTED_HOOK_getVertexColor/); + + const body = modified.fragSrc().match(/getVertexColor[\s\S]*?\{([\s\S]*?)\}/)[1]; + const matches = (body.match(/return\s+c\s*;/g) || []).length; + expect(matches).to.equal(1); + }); + + test('commented return does not block normalization', function() { + const modified = myShader.modify({ + 'vec4 getVertexColor': `(vec4 c) { + /* return c; */ + // return c; + c.a = 1.0; + }` + }); + expect(modified.fragSrc()).to.match(/#define AUGMENTED_HOOK_getVertexColor/); + expect(modified.fragSrc()).to.match(/getVertexColor[\s\S]*?return\s+c\s*;/); + }); + + test('void hooks are not normalized', function() { + const modified = myShader.modify({ + 'void beforeFragment': `() { + // no-op + }`, + 'vec4 getVertexColor': `(vec4 c) { + return c; + }` + }); + + const src = modified.fragSrc(); + const bfBody = src.match(/HOOK_beforeFragment[\s\S]*?\{([\s\S]*?)\}/); + if (bfBody) { + expect(bfBody[1]).not.to.match(/\breturn\b/); + } + }); + + test('mismatched types are not auto-returned and cause compiler error', function() { + const bad = myp5.createShader( + ` + precision highp float; + attribute vec3 aPosition; + uniform mat4 uModelViewMatrix, uProjectionMatrix; + void main() { + gl_Position = uProjectMatrix * uModelViewMatrix * vec4(aPosition, 1.0); + } + `, + ` + precision highp float; + void main() { + gl_FragColor = HOOK_getVertexColor(vec4(0.0, 1.0, 0.0, 1.0)); + } + `, + { + fragment: { + 'vec4 getVertexColor': '(vec2 uv) { }' + } + } + ); + + let threw = false; + try { + bad.bindShader(); + } catch (e) { + threw = true; + } finally { + try { + bad.unbindShader(); + } catch (_) {} + } + expect(threw).to.equal(true); + }); }); }); });