[XLA] Add a new XLA mode: XLALite #34655
Conversation
…LA and the tf_xla_supported_nodes=FUSIBLE flag. This makes using that mode easier.
tensorflow/compiler/jit/flags.cc (Outdated):
```diff
-    "things very likely to be improved; 2 = on for everything. "
+    "things very likely to be improved; 2 = on for everything; "
+    "fusible = only for Tensorflow operations that XLA knows how to "
+    "fuse. "
```
Please add something like (experimental) to indicate the feature may change in backward-incompatible ways going forward.
done
Not done? Also should this be under a different flag if this can change in unpredictable ways?
Right. I added the experimental mark to the new flag's documentation. Added it here too now.
tensorflow/compiler/jit/flags.cc (Outdated):
```
"If multiple, separate them with commas. Shortcuts: "
" PW: All point-wise operations."
" RED: All reduction operations."
" SMALL: Mixed small operations."
```
Can you explain what small means in this context? How did you choose the ops that were a good fit for the SMALL category?
SMALL is ops like Shape, Rank, Range, and a few others that I didn't know which category to put in. I could split that section into 'SMALL' and 'MIXED'. Do you have a better idea of how to split it?
Since you said you rarely used the subcategories of FUSIBLE anyway, I wouldn't split them further. I was just wondering in what way the TF ops are SMALL. If this is a category for ops that just wouldn't fit anywhere else, I'd call it MISC to make this obvious.
Done
tensorflow/compiler/jit/flags.h (Outdated):
```
int32 tf_xla_max_cluster_size;

// If non-empty, limit XLA clustering to the following TF operations.
string tf_xla_supported_ops;
```
Nit: tf_xla_ops_to_cluster is a little clearer than 'tf_xla_supported_ops'.
done
tensorflow/compiler/jit/flags.cc (Outdated):
```
"(LRN, LRNGrad)."
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name."),
```
Are TF op names expected to be fully-qualified? Please provide an example.
I added an example. I'm not sure what you mean by a fully-qualified name? I think TF operations have just one unique name, like Add, MatMul, Sum, ... It is those that should be used in this version.
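For illustration, here is the kind of invocation this enables, mixing a group shortcut with individual op names; the chosen ops and the script name are just placeholders, not part of the PR:

```
TF_XLA_FLAGS="--tf_xla_ops_to_cluster=FUSIBLE,MatMul,AvgPool --tf_xla_auto_jit=1" python train.py
```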
```
  return true;
}

const absl::flat_hash_map<string, std::vector<string>> whitelist_table = {
```
What invariants must hold for this list? Or phrased differently, at what point does this list need to be updated? Is it meant to be complete, i.e. are all point-wise operations supported by XLA:GPU listed here? In your experiments, did you find subsets of FUSIBLE useful?
Every time the TF-XLA bridge supports a new TF operation, we need to check whether we want to support it by default or not. So we should document that somewhere to help this not be forgotten. Any idea where we should document it?
I didn't use the shortcuts much. What I used was FUSIBLE plus other operations that I wasn't sure we wanted to include or not. But if someone wants to start playing more with this, I think the shortcuts would be useful.
To make this list, I went over the TF operations that the bridge knows and selected those where I was sure what they do and sure that XLA could fuse them. I could have missed some. In 2 benchmarks, I found some ops that could maybe be fused that I didn't include. I timed runs that included them, and it slowed down XLALite compared to not having them. So I didn't include them. They were:
```
"ReadVariableOp", "VarIsInitializedOp", "VariableShape",
"ResourceApplyCenteredRMSProp", "ResourceApplyRMSProp",
"ResourceScatterAdd", "ResourceApplyAdam"
```
@sanjoy Can you recommend a good spot for this documentation? Do you have an opinion regarding the TF Ops Frederic did not include?
As for the documentation, I would vote for creating a unit test that checks that tf2xla-supported ops are either whitelisted or explicitly blacklisted. See resource_operation_table_test.cc.

> Do you have an opinion regarding the TF Ops Frederic did not include?

Right, it isn't totally obvious which ops were included. I think we need some comments to describe the "format" of whitelist_table.
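To make the request concrete, a minimal sketch of the table format under discussion, with illustrative entries only (and ignoring, for the moment, the global-destructor issue raised later in this thread):

```
// Table format: category shortcut -> TF operations in that category.
// The entries here are examples, not the PR's actual lists.
const absl::flat_hash_map<string, std::vector<string>> whitelist_table = {
    {"PW", {"Add", "Sub", "Mul"}},   // point-wise operations
    {"RED", {"Sum", "Max", "Min"}},  // reduction operations
};
```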
Good idea to have a test that forces all registered XLA/TF operations to be either whitelisted or blacklisted. It will guarantee that we make a decision when new operations are supported.
I added this test.
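A minimal sketch of what such a test could look like; GetKnownCategorizedOps is a hypothetical helper standing in for the union of the whitelist and the explicit blacklist (see resource_operation_table_test.cc for the real precedent):

```
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

TEST(OpsToClusterTest, EveryRegisteredOpIsCategorized) {
  // Hypothetical helper: the union of whitelisted and explicitly
  // blacklisted op names.
  absl::flat_hash_set<string> categorized = GetKnownCategorizedOps();
  for (const string& op : XlaOpRegistry::GetAllRegisteredOps()) {
    EXPECT_TRUE(categorized.contains(op))
        << "Op " << op << " is neither whitelisted nor blacklisted; "
        << "please make an explicit decision when adding new ops.";
  }
}

}  // namespace
}  // namespace tensorflow
```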
```
std::unique_ptr<absl::flat_hash_set<string>> GetWhitelist() {
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  auto whitelist = absl::WrapUnique(new absl::flat_hash_set<string>());
```
make_unique?
As you suggested elsewhere, I now return by copy instead of by reference. As the set is small and the compiler probably optimizes the copy away, this makes the code simpler.
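A minimal sketch of the by-value version (group-shortcut expansion omitted; flag field name as in this PR):

```
// Return the whitelist by value; the set is small, and NRVO / move
// semantics make this cheap while avoiding the unique_ptr ceremony.
absl::flat_hash_set<string> GetWhitelist() {
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  absl::flat_hash_set<string> whitelist;
  // Populate from the comma-separated flag value.
  for (absl::string_view op :
       absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
    if (!op.empty()) whitelist.insert(string(op));
  }
  return whitelist;
}
```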
```
if (VLOG_IS_ON(2) && whitelist->size() > 0) {
  std::vector<string> vwhitelist(whitelist->begin(), whitelist->end());
  std::sort(vwhitelist.begin(), vwhitelist.end());
```
absl::c_sort
done.
```
  }
}

if (VLOG_IS_ON(2) && whitelist->size() > 0) {
```
!whitelist->empty()
done
```
} else if (whitelist_table.contains(s)) {
  auto v = whitelist_table.at(s);
  whitelist->insert(v.begin(), v.end());
} else if (s.size() > 0) {
```
!s.empty()
done
```
  whitelist->insert(v.begin(), v.end());
} else if (s.size() > 0) {
  // Should be a user provided TF operation.
  whitelist->insert(string(s));
```
should we VLOG(5) here or something to avoid misspellings?
I already do a misspelling check on line 1195. It is there to help with error reporting.
```
  }
}

if (VLOG_IS_ON(2) && whitelist->size() > 0) {
```
Why not just sort on all code paths? Then this entire branch becomes VLOG(2) << ...
To avoid paying for the sort when we do not print.
I tried changing the container to the sorted container absl::btree_set to remove the need for a sort, but btree_set isn't available in my version of absl. I'm not sure that is a good enough reason to bump TF's required absl version, and I have never changed a TF dependency version.
If I put a VLOG(2) on all code paths, the output will be verbose even when this feature isn't used. Are you suggesting doing that? I would rather not add extra useless verbose output here. Do you see value in always printing it, even when the feature is not used?
```
auto whitelist = GetWhitelist();

auto vall_ops = XlaOpRegistry::GetAllRegisteredOps();
```
Explicit type would be useful here
done
```
auto vall_ops = XlaOpRegistry::GetAllRegisteredOps();
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
// Check that user's provided TF operation really exists.
for (auto s = whitelist->begin(); s != whitelist->end(); ++s) {
```
for-each loop instead?
done
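A sketch of the range-based form this converges to, assuming the surrounding function reports failures via Status (the exact error handling in the PR may differ):

```
// Check that each user-provided TF operation name really exists.
for (const string& op : *whitelist) {
  if (!all_ops.contains(op)) {
    return errors::InvalidArgument("Unknown TF operation in whitelist: ", op);
  }
}
```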
@thomasjoerg the issue with the non-trivially destructible global remains; it will break the build.
```
  return true;
}

const absl::flat_hash_map<string, std::vector<string>> whitelist_table = {
```
Please document the format of this table.
We also don't allow non-trivial global destructors because they don't play well with multi-threading. I think a better phrasing is:
```
absl::flat_hash_map<string, std::vector<string>>* CreateWhitelist() {
  absl::flat_hash_map<string, std::vector<string>>* result = new ...;
  // Use explicit code to populate "result", possibly with comments.
  return result;
}

const absl::flat_hash_map<string, std::vector<string>>& GetOrCreateWhitelist() {
  static absl::flat_hash_map<string, std::vector<string>>* whitelist =
      CreateWhitelist();
  return *whitelist;
}
```
I fixed the non-trivial global destructors.
I also added documentation of the format. If you want more than this, just tell me.
I think it's better to have the static in CreateWhitelist, in order not to accidentally create a leak if someone else calls it; also, it seems this is how it is usually done in XLA, e.g. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/transfer_manager.cc#L41
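A minimal sketch of that variant, with the static inside the factory itself (population elided):

```
// Calling this twice cannot leak: the allocation happens once, guarded by
// the function-local static, mirroring the transfer_manager.cc pattern.
const absl::flat_hash_map<string, std::vector<string>>& CreateWhitelist() {
  static const auto* result =
      new absl::flat_hash_map<string, std::vector<string>>{
          // Populate explicitly: {"CATEGORY", {"Op1", "Op2"}}, ...
      };
  return *result;
}
```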
```
absl::flat_hash_map<string, std::vector<string>>* CreateWhitelist() {
  // Table format: category name: {list of TF operations in that category}
  absl::flat_hash_map<string, std::vector<string>>* result =
```
Should this be static? At the moment this seems to leak on every invocation.
I modified it. See other comment: #34655 (comment)
I can't reply to the comment, so replying here.
@rthadur Did CopyBara get stuck somehow? It's been two days since the PR was approved. Can you kick CopyBara?
@thomasjoerg we can also force run copybara by adding the
@cheshire Forcing a Kokoro run did not do the trick. I imported the PR manually.
Note, I forgot to update my test for the interface change, so my test is currently broken.
make it easier to find which operation. For example, the node name can be AvgPool2d, while the TF operation is AvgPool.
PiperOrigin-RevId: 285730750 Change-Id: Ib53f29df2e956b8c4904d08af3d6f33f1c419a9f
I made a new PR with the last commit that was missing:
Add 2 new XLA flags:

- TF_XLA_FLAGS=--tf_xla_ops_to_cluster=[FUSIBLE,...]
- TF_XLA_FLAGS=--tf_xla_auto_jit=fusible

The second is a shortcut to:

TF_XLA_FLAGS=--tf_xla_ops_to_cluster=FUSIBLE TF_XLA_FLAGS=--tf_xla_auto_jit=1

This enables XLA, but only for a subset of TF operations that XLA knows how to fuse together. It allows using XLA's operation-fusion capabilities while removing some of the current XLA slowdown cases. In most cases where XLA is slower than classic TF, XLALite isn't slower than classic TF. In some cases where XLA gives a speedup vs classic TF, XLALite gives a good part of that speedup.

The flag tf_xla_ops_to_cluster can accept TF operation names and/or predefined groups of operations. The group FUSIBLE includes all the groups defined. Multiple values can be passed by separating them with commas.