linkedin / dagli

Framework for defining machine learning models, including feature generation and transformations, as directed acyclic graphs (DAGs).
BSD 2-Clause "Simplified" License

Example for NNClassification.withMultilabelLabelsInput #8

Closed: cyberbeat closed this issue 3 years ago

cyberbeat commented 3 years ago

I have defined a Placeholder with a label like this:

public static enum Label { some, values, ... }
EnumSet<Label> labels;

This way, a record can have multiple labels.

This doesn't compile:

ExtendedNode.Placeholder p = new ExtendedNode.Placeholder();
...
NNClassification<Label> myClassification = new NNClassification<Label>()
  .withFeaturesInput(denseLayers)
  .withMultilabelLabelsInput(p.asLabels());

How should I use the NNClassification layer in this case?

jeffpasternack commented 3 years ago

Could you please let us know: (1) What's "p" in your example above? (2) What error/exception are you seeing?

One issue that's apparent from your example code is that you're not using generics to provide the type of the label on your NNClassification instance, which may result in compilation issues. I'm also not sure what the significance of your EnumSet is.

cyberbeat commented 3 years ago

1) p is the placeholder variable.
2) I get this compiler error: withMultilabelLabelsInput(Producer<Iterable<? extends NodeBase.Label>>) in the type NNClassification<NodeBase.Label> is not applicable for the arguments (ExtendedNode.Labels).

If I understand the hierarchy right, the generated ExtendedNode.Labels class implements a Producer<EnumSet<NodeBase.Label>>?

About the generics: I pasted the code without code formatting, so GitHub stripped the <> brackets. I've corrected this.

jeffpasternack commented 3 years ago

I think I understand what's going on, although I don't have complete information (e.g. whether ExtendedNode is a @Struct; I'll assume it is, from context). It looks like you've parameterized NNClassification to expect a set (or, more accurately, an Iterable) of labels of type NodeBase.Label. Assuming that ExtendedNode is a @Struct and Labels corresponds to a field on that @Struct with the type EnumSet<NodeBase.Label>, I think the issue is that the generic constraint on the type passed to withMultilabelLabelsInput(...) is too strict. I'll fix that and update the ticket so you can retry. This may take a day or so; in the meantime, you should be able to work around the typing issue via unsafe casts (e.g. cast the argument passed to withMultilabelLabelsInput(...) to (Producer<Iterable<? extends NodeBase.Label>>)).
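To illustrate why the constraint is too strict, here is a minimal, self-contained sketch. It does not use dagli: Producer, Label, strict and relaxed are hypothetical stand-ins. Because Java generics are invariant, a Producer<EnumSet<Label>> is not a Producer<Iterable<? extends Label>>, even though an EnumSet<Label> is an Iterable<Label>. A bound with an extra wildcard, Producer<? extends Iterable<? extends Label>>, does accept it. This is one plausible way to relax the bound, not necessarily how dagli's fix is written.

```java
import java.util.EnumSet;

public class VarianceDemo {
    // Hypothetical stand-in for a producer of values (not dagli's Producer interface).
    interface Producer<T> {
        T produce();
    }

    enum Label { SOME, LABELS }

    // Too strict: the type argument must be exactly Iterable<? extends Label>.
    static int strict(Producer<Iterable<? extends Label>> labels) {
        int count = 0;
        for (Label label : labels.produce()) {
            count++;
        }
        return count;
    }

    // Relaxed: accepts a Producer of any Iterable subtype, e.g. EnumSet<Label> or List<Label>.
    static int relaxed(Producer<? extends Iterable<? extends Label>> labels) {
        int count = 0;
        for (Label label : labels.produce()) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        Producer<EnumSet<Label>> labels = () -> EnumSet.of(Label.SOME, Label.LABELS);
        // strict(labels); // compile error: Producer<EnumSet<Label>> is not a Producer<Iterable<? extends Label>>
        System.out.println(relaxed(labels)); // prints 2
    }
}
```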

cyberbeat commented 3 years ago

The cast isn't allowed:

@Struct("ExtendedNode")
abstract class NodeBase implements Serializable {
    private static final long serialVersionUID = -3225605876555058517L;
    ...some fields...
    public static enum Label { some, labels }
    EnumSet<Label> labels;
}

The compile error for the cast is:

Cannot cast from ExtendedNode.Labels to Producer<Iterable<? extends NodeBase.Label>>

This is what the inheritance of Labels looks like:

public static class Labels extends AbstractPreparedTransformer1<ExtendedNode, EnumSet<NodeBase.Label>, Labels> { ...

What I want to do is simply have multiple (or no) labels per example. Is this the right way to do it?

jeffpasternack commented 3 years ago

You may need to "cheat" and cast the argument to the raw type (Producer) first. Java is sometimes odd (IMO) about which casts it allows with warnings and which it rejects with compile errors.
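A minimal sketch of the raw-cast workaround, again using hypothetical stand-ins rather than dagli's actual classes (Labels here only mimics the generated ExtendedNode.Labels by producing an EnumSet<Label>; countLabels mimics the too-strict parameter type). Casting directly to Producer<Iterable<? extends Label>> is rejected because Producer<EnumSet<Label>> and Producer<Iterable<? extends Label>> are provably distinct parameterized types. Casting to the raw Producer compiles, and passing it on only produces an unchecked warning.

```java
import java.util.EnumSet;

public class RawCastDemo {
    // Hypothetical stand-in for a producer of values (not dagli's Producer interface).
    interface Producer<T> {
        T produce();
    }

    enum Label { SOME, LABELS }

    // Hypothetical stand-in for the generated ExtendedNode.Labels: produces an EnumSet<Label>.
    static class Labels implements Producer<EnumSet<Label>> {
        @Override
        public EnumSet<Label> produce() {
            return EnumSet.of(Label.SOME, Label.LABELS);
        }
    }

    // Mirrors the too-strict parameter type from before the fix.
    static int countLabels(Producer<Iterable<? extends Label>> labels) {
        int count = 0;
        for (Label label : labels.produce()) {
            count++;
        }
        return count;
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    static int viaRawCast(Labels labels) {
        // countLabels((Producer<Iterable<? extends Label>>) labels); // compile error: inconvertible types
        // Casting to the raw type compiles; passing it on is an unchecked conversion.
        return countLabels((Producer) labels);
    }

    public static void main(String[] args) {
        System.out.println(viaRawCast(new Labels())); // prints 2
    }
}
```

This is safe here only because the produced EnumSet really is an Iterable of Label; after the raw cast, the compiler no longer checks that for you.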

cyberbeat commented 3 years ago

Thanks, that did the trick!

jeffpasternack commented 3 years ago

Closing as beta7 has been pushed, fixing the underlying issue. The corresponding JARs should be available from Maven Central within 12-24 hours. Thanks for reporting!