Skip to content

Commit

Permalink
Make plugin task first-class citizen (#268)
Browse files Browse the repository at this point in the history
* Manually pick proto change from upstream PR

Signed-off-by: Hongxin Liang <[email protected]>

* Make plugin task first-class citizen

Signed-off-by: Hongxin Liang <[email protected]>

* Fix file headers

Signed-off-by: Hongxin Liang <[email protected]>

---------

Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Dec 11, 2023
1 parent 692fe9e commit 78cd22d
Show file tree
Hide file tree
Showing 15 changed files with 608 additions and 11 deletions.
8 changes: 8 additions & 0 deletions flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ message RuntimeMetadata {

//+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.).
string flavor = 3;

//+optional It can be used to provide extra information for the plugin.
PluginMetadata plugin_metadata = 4;
}

message PluginMetadata {
//+optional It can be used to decide use sync plugin or async plugin during runtime.
bool is_sync_plugin = 1;
}

// Task Metadata
Expand Down
22 changes: 22 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

/** A task that is handled by a Flyte backend plugin instead of run as a container. */
public interface PluginTask extends Task {
boolean isSyncPlugin();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

/** A registrar that creates {@link PluginTask} instances. */
public abstract class PluginTaskRegistrar implements Registrar<TaskIdentifier, PluginTask> {}
8 changes: 8 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
/**
* A Task structure that uniquely identifies a task in the system. Tasks are registered as a first
* step in the system.
*
* <p>FIXME: consider offering TaskMetadata instead of having everything in TaskTemplate, see
* https://github.com/flyteorg/flyte/blob/ea72bbd12578d64087221592554fb71c368f8057/flyteidl/protos/flyteidl/core/tasks.proto#L90
*/
@AutoValue
public abstract class TaskTemplate {
Expand Down Expand Up @@ -64,6 +67,9 @@ public abstract class TaskTemplate {
*/
public abstract boolean cacheSerializable();

/** Indicates whether to use sync plugin or async plugin to handle this task. */
public abstract boolean isSyncPlugin();

public abstract Builder toBuilder();

public static Builder builder() {
Expand All @@ -89,6 +95,8 @@ public abstract static class Builder {

public abstract Builder cacheSerializable(boolean cacheSerializable);

public abstract Builder isSyncPlugin(boolean isSyncPlugin);

public abstract TaskTemplate build();
}
}
115 changes: 115 additions & 0 deletions flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.flyte.api.v1.PartialTaskIdentifier;

/** A task that is handled by a Flyte backend plugin instead of run as a container. */
public abstract class SdkPluginTask<InputT, OutputT> extends SdkTransform<InputT, OutputT> {

private final SdkType<InputT> inputType;
private final SdkType<OutputT> outputType;

/**
* Called by subclasses passing the {@link SdkType}s for inputs and outputs.
*
* @param inputType type for inputs.
* @param outputType type for outputs.
*/
public SdkPluginTask(SdkType<InputT> inputType, SdkType<OutputT> outputType) {
this.inputType = inputType;
this.outputType = outputType;
}

public abstract String getType();

@Override
public SdkType<InputT> getInputType() {
return inputType;
}

@Override
public SdkType<OutputT> getOutputType() {
return outputType;
}

/** Specifies custom data that can be read by the backend plugin. */
public SdkStruct getCustom() {
return SdkStruct.empty();
}

/**
* Number of retries. Retries will be consumed when the task fails with a recoverable error. The
* number of retries must be less than or equals to 10.
*
* @return number of retries
*/
public int getRetries() {
return 0;
}

/**
* Indicates whether the system should attempt to look up this task's output to avoid duplication
* of work.
*/
public boolean isCached() {
return false;
}

/** Indicates a logical version to apply to this task for the purpose of cache. */
public String getCacheVersion() {
return null;
}

/**
* Indicates whether the system should attempt to execute cached instances in serial to avoid
* duplicate work.
*/
public boolean isCacheSerializable() {
return false;
}

@Override
SdkNode<OutputT> apply(
SdkWorkflowBuilder builder,
String nodeId,
List<String> upstreamNodeIds,
@Nullable SdkNodeMetadata metadata,
Map<String, SdkBindingData<?>> inputs) {
PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build();
List<CompilerError> errors =
Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap());

if (!errors.isEmpty()) {
throw new CompilerException(errors);
}

return new SdkTaskNode<>(
builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType);
}

/**
* Signaling whether this task is supposed to be handled by a synchronous backend plugin,
* defaulting to false.
*/
public boolean isSyncPlugin() {
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import com.google.auto.service.AutoService;
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.flyte.api.v1.PluginTask;
import org.flyte.api.v1.PluginTaskRegistrar;
import org.flyte.api.v1.RetryStrategy;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TypedInterface;

/**
* Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s
* implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use
* {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the
* corresponding file.
*
* @see ServiceLoader
*/
@AutoService(PluginTaskRegistrar.class)
public class SdkPluginTaskRegistrar extends PluginTaskRegistrar {
private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName());

static {
// enable all levels for the actual handler to pick up
LOG.setLevel(Level.ALL);
}

private static class PluginTaskImpl<InputT, OutputT> implements PluginTask {
private final SdkPluginTask<InputT, OutputT> sdkTask;

private PluginTaskImpl(SdkPluginTask<InputT, OutputT> sdkTask) {
this.sdkTask = sdkTask;
}

@Override
public String getType() {
return sdkTask.getType();
}

@Override
public Struct getCustom() {
return sdkTask.getCustom().struct();
}

@Override
public TypedInterface getInterface() {
return TypedInterface.builder()
.inputs(sdkTask.getInputType().getVariableMap())
.outputs(sdkTask.getOutputType().getVariableMap())
.build();
}

@Override
public RetryStrategy getRetries() {
return RetryStrategy.builder().retries(sdkTask.getRetries()).build();
}

@Override
public boolean isCached() {
return sdkTask.isCached();
}

@Override
public String getCacheVersion() {
return sdkTask.getCacheVersion();
}

@Override
public boolean isCacheSerializable() {
return sdkTask.isCacheSerializable();
}

@Override
public String getName() {
return sdkTask.getName();
}

@Override
public boolean isSyncPlugin() {
return sdkTask.isSyncPlugin();
}
}

/**
* Load {@link SdkPluginTask}s using {@link ServiceLoader}.
*
* @param env env vars in a map that would be used to pick up the project, domain and version for
* the discovered tasks.
* @param classLoader class loader to use when discovering the task using {@link
* ServiceLoader#load(Class, ClassLoader)}
* @return a map of {@link SdkPluginTask}s by its task identifier.
*/
@Override
@SuppressWarnings("rawtypes")
public Map<TaskIdentifier, PluginTask> load(Map<String, String> env, ClassLoader classLoader) {
ServiceLoader<SdkPluginTask> loader = ServiceLoader.load(SdkPluginTask.class, classLoader);

LOG.fine("Discovering SdkPluginTask");

Map<TaskIdentifier, PluginTask> tasks = new HashMap<>();
SdkConfig sdkConfig = SdkConfig.load(env);

for (SdkPluginTask<?, ?> sdkTask : loader) {
String name = sdkTask.getName();
TaskIdentifier taskId =
TaskIdentifier.builder()
.domain(sdkConfig.domain())
.project(sdkConfig.project())
.name(name)
.version(sdkConfig.version())
.build();
LOG.fine(String.format("Discovered [%s]", name));

PluginTask task = new PluginTaskImpl<>(sdkTask);
PluginTask previous = tasks.put(taskId, task);

if (previous != null) {
throw new IllegalArgumentException(
String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous));
}
}

return tasks;
}
}
Loading

0 comments on commit 78cd22d

Please sign in to comment.