
[js/webgpu] ConvTranspose1D slower on Webgpu than Wasm #23273

Open
gianlourbano opened this issue Jan 7, 2025 · 3 comments
Labels
ep:WebGPU ort-web webgpu provider platform:web issues related to ONNX Runtime web; typically submitted using template


@gianlourbano

Describe the issue

ConvTranspose1D with input shape [8, 4098, 435], weight shape [4098, 1, 4096] (kernel size 4096), stride 1024 and padding 0 is slower on WebGPU than on Wasm, with the following timings:
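For scale, the output length follows the usual ConvTranspose1d formula from the PyTorch documentation. A quick check of the output shape for the parameters in the export script below (dilation 1 and output_padding 0 assumed; the helper name is hypothetical):

```python
def conv_transpose1d_out_len(l_in, kernel_size, stride=1, padding=0,
                             dilation=1, output_padding=0):
    """Output length of ConvTranspose1d (formula from the PyTorch docs)."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# Shapes from the export script: input [N, C_in, L_in], weight [C_in, C_out, K]
n, c_in, l_in = 8, 4098, 435
w_c_in, c_out, k = 4098, 1, 4096
assert c_in == w_c_in  # weight dim 0 must match the input channels

l_out = conv_transpose1d_out_len(l_in, k, stride=1024)
print([n, c_out, l_out])  # [8, 1, 448512]
```

So a single run writes roughly 3.6 million output floats, each built from overlapping kernel windows.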

| EP | Timing (M1 MacBook Pro) |
|---|---|
| wasm | 6 s |
| webgpu (latest Chrome) | 30 s |
| webgpu (Chrome Canary) | 18 s |

Canary is faster due to this bug.
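A back-of-the-envelope count of multiply-accumulates (MACs) for these shapes, and the effective throughput implied by each timing in the table. This is a rough sketch: it counts one MAC per (input element, output channel, kernel tap), ignores memory traffic, and the helper name is hypothetical.

```python
def conv_transpose1d_macs(n, c_in, l_in, c_out, kernel_size):
    # Every input element is multiplied by kernel_size weights per output channel
    return n * c_in * l_in * c_out * kernel_size

macs = conv_transpose1d_macs(8, 4098, 435, 1, 4096)
print(f"{macs / 1e9:.1f} GMAC")  # 58.4 GMAC

# Effective throughput for the timings reported above (seconds)
for ep, seconds in [("wasm", 6), ("webgpu (Chrome)", 30), ("webgpu (Canary)", 18)]:
    print(f"{ep}: {macs / seconds / 1e9:.2f} GMAC/s")
```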

To reproduce

A simple PyTorch script that builds the transposed convolution and exports it to ONNX:

```python
import torch

class ConvTest(torch.nn.Module):
    """Wraps a single conv_transpose1d with a fixed weight so it can be exported."""

    def __init__(self, weight, stride, padding=0):
        super().__init__()
        # conv_transpose1d weight layout: [in_channels, out_channels / groups, kernel_size]
        self.weight = weight
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        return torch.nn.functional.conv_transpose1d(
            x, self.weight, stride=self.stride, padding=self.padding
        )

# 4098 input channels -> 1 output channel, kernel size 4096, stride 1024
convtest = ConvTest(weight=torch.randn(4098, 1, 4096), stride=1024)

# Input layout: [batch, channels, length]
x = torch.randn(8, 4098, 435)

torch.onnx.export(
    convtest,
    (x,),
    "convtest.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=20,
    dynamo=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    # report=True,
    external_data=None,
    # verify=True
)
```
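To make the semantics concrete, here is a naive pure-Python reference of what `conv_transpose1d` computes, assuming groups=1 and dilation=1 (the helper name is hypothetical, and it works on nested lists rather than tensors). Each input element `x[n][c][i]` is scattered into the K outputs starting at `i * stride`, then `padding` positions are cropped from each end. With stride 1024 and K = 4096, every interior output position receives contributions from 4 input positions.

```python
def conv_transpose1d_ref(x, w, stride=1, padding=0):
    """Scatter-style ConvTranspose1d (hypothetical helper).

    x: [N][C_in][L_in], w: [C_in][C_out][K], both nested lists of floats.
    """
    n_batch, c_in, l_in = len(x), len(x[0]), len(x[0][0])
    c_out, k = len(w[0]), len(w[0][0])
    l_out = (l_in - 1) * stride - 2 * padding + k
    y = [[[0.0] * l_out for _ in range(c_out)] for _ in range(n_batch)]
    for n in range(n_batch):
        for c in range(c_in):
            for i in range(l_in):
                xv = x[n][c][i]
                for o in range(c_out):
                    # input position i lands on outputs [i*stride, i*stride + k)
                    for kk in range(k):
                        t = i * stride + kk - padding  # shift by the cropped padding
                        if 0 <= t < l_out:
                            y[n][o][t] += xv * w[c][o][kk]
    return y


# One batch, one channel, two input positions, kernel [1, 1, 1], stride 2
print(conv_transpose1d_ref([[[1.0, 2.0]]], [[[1.0, 1.0, 1.0]]], stride=2))
# [[[1.0, 1.0, 3.0, 2.0, 2.0]]]
```

The middle output (3.0) is where the two kernel windows overlap.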

To test in the browser:

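```javascript
// Assumes onnxruntime-web is loaded (e.g. via a <script> tag exposing `ort`)
// and that this runs in an ES module or async function (top-level await).

// Collect per-kernel GPU timings; configured up front, before the session is created.
const wgpu_profile = [];
ort.env.webgpu.profiling = {
    mode: "default",
    ondata: (data) => {
        wgpu_profile.push(data);
    },
};

const session = await ort.InferenceSession.create("/convtest.onnx", {
    executionProviders: ["webgpu"],
    // logSeverityLevel: 0
});

const input_dims = [8, 4098, 435];
const size = input_dims.reduce((a, b) => a * b, 1); // 8 * 4098 * 435

// Number of timed runs; each gets its own zero-filled input buffer.
const no_chunks = 1;
const chunks = [];
for (let i = 0; i < no_chunks; i++) {
    chunks.push(new Float32Array(size));
}

for (let i = 0; i < no_chunks; i++) {
    console.time("onnx step " + i);
    const input = new ort.Tensor("float32", chunks[i], input_dims);
    // The feed name must match the exported input name ("input")
    const output = await session.run({ input });
    console.timeEnd("onnx step " + i);
}

await session.release();

// Sort kernels by duration (ascending); timestamps are converted from ns to ms below.
wgpu_profile.sort((a, b) => (a.endTime - a.startTime) - (b.endTime - b.startTime));

wgpu_profile.forEach((kernel) => {
    console.log(`${kernel.kernelType} (${kernel.kernelName}) took ${(kernel.endTime - kernel.startTime) / 1000 / 1000} ms`);
});
```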
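When a profile contains many dispatches, summing durations per kernel type can be easier to read than the sorted list. A small Python sketch, assuming the collected `wgpu_profile` entries have been exported as plain objects (for example with `JSON.stringify`) carrying the same `kernelType`, `startTime` and `endTime` fields used above, with timestamps in nanoseconds as implied by the `/ 1000 / 1000` conversion to ms. The function name and sample values are hypothetical.

```python
from collections import defaultdict

def total_ms_by_kernel_type(profile):
    """Sum kernel durations (ns -> ms) per kernelType, largest total first."""
    totals = defaultdict(float)
    for entry in profile:
        totals[entry["kernelType"]] += (entry["endTime"] - entry["startTime"]) / 1e6
    return dict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))


# Hypothetical sample entries shaped like the ondata records above
sample = [
    {"kernelType": "ConvTranspose", "startTime": 0, "endTime": 2_000_000},
    {"kernelType": "Transpose", "startTime": 2_000_000, "endTime": 2_500_000},
    {"kernelType": "ConvTranspose", "startTime": 2_500_000, "endTime": 5_500_000},
]
print(total_ms_by_kernel_type(sample))  # {'ConvTranspose': 5.0, 'Transpose': 0.5}
```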

Urgency

Urgent

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0-dev.20241224-2d05c4bcd9

Execution Provider

'webgpu' (WebGPU), 'wasm'/'cpu' (WebAssembly CPU)

@gianlourbano
Author

@qjia7 @gyagp could you please take a look? It may be related to this PR.

@qjia7
Contributor

qjia7 commented Jan 8, 2025

@gianlourbano I can reproduce it. Will take a look, thanks.

guschmue pushed a commit that referenced this issue Jan 9, 2025
### Description
BUG #23273

With this change, the ConvTranspose time for the model in that bug drops from ~90 s to ~7 s on my Meteor Lake machine.

This PR does the following:
1. Uses the stride as the loop increment. In the bug the stride is 1024, which greatly reduces the number of loop iterations.
2. Supports multiple components (vectorized access) for A to reduce the number of memory accesses.
3. When the number of output channels is 1, B can use the same number of components as A, further reducing memory accesses.
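The idea behind point 1 can be illustrated with a pure-Python, gather-style ConvTranspose1d, where each output element is computed independently as one GPU invocation would. This is a sketch of the general technique, not the WGSL shader from the PR, and the helper names are hypothetical. For output position `t`, only kernel taps `kk` with `(t + padding - kk)` divisible by `stride` line up with an input position, so the tap loop can start at `(t + padding) % stride` and advance by `stride`, visiting at most `ceil(K / stride)` taps (4 for K = 4096, stride = 1024) instead of all K.

```python
def conv_transpose1d_gather(x, w, stride=1, padding=0):
    """Gather-style ConvTranspose1d (hypothetical helper, groups=1, dilation=1).

    x: [N][C_in][L_in], w: [C_in][C_out][K], both nested lists of floats.
    """
    n_batch, c_in, l_in = len(x), len(x[0]), len(x[0][0])
    c_out, k = len(w[0]), len(w[0][0])
    l_out = (l_in - 1) * stride - 2 * padding + k
    y = [[[0.0] * l_out for _ in range(c_out)] for _ in range(n_batch)]
    for n in range(n_batch):
        for o in range(c_out):
            for t in range(l_out):
                p = t + padding        # position in the uncropped output
                acc = 0.0
                kk = p % stride        # first tap with (p - kk) divisible by stride
                while kk < k:          # step by stride instead of by 1
                    i = (p - kk) // stride
                    if 0 <= i < l_in:  # skip taps that fall outside the input
                        for c in range(c_in):
                            acc += x[n][c][i] * w[c][o][kk]
                    kk += stride
                y[n][o][t] = acc
    return y


def max_taps_per_output(kernel_size, stride):
    """Upper bound on kernel taps visited per output: ceil(K / stride)."""
    return -(-kernel_size // stride)


print(max_taps_per_output(4096, 1024))  # 4
```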
@gianlourbano
Author

gianlourbano commented Jan 10, 2025

Thanks for the help, @qjia7! Do you think there is more room for improvement? The same op takes about 400-600 ms on CPU in PyTorch / ONNX Runtime (Python).
