Skip to content
Snippets Groups Projects
Unverified Commit 1fea82aa authored by liukuikun's avatar liukuikun Committed by GitHub
Browse files

[Docs] update data element tutorials (#431)

* structure tutorials

* refine data element docs

* modify introduce

* fix comment

* fix comment

* fix comment
parent 5a9ac09f
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
...@@ -22,7 +22,7 @@ class LabelData(BaseDataElement): ...@@ -22,7 +22,7 @@ class LabelData(BaseDataElement):
assert isinstance(onehot, torch.Tensor) assert isinstance(onehot, torch.Tensor)
if (onehot.ndim == 1 and onehot.max().item() <= 1 if (onehot.ndim == 1 and onehot.max().item() <= 1
and onehot.min().item() >= 0): and onehot.min().item() >= 0):
return onehot.nonzero().squeeze() return onehot.nonzero().squeeze(-1)
else: else:
raise ValueError( raise ValueError(
'input is not one-hot and can not convert to label') 'input is not one-hot and can not convert to label')
......
...@@ -44,6 +44,11 @@ class TestLabelData(TestCase): ...@@ -44,6 +44,11 @@ class TestLabelData(TestCase):
label = LabelData.onehot_to_label(onehot) label = LabelData.onehot_to_label(onehot)
assert (label == item).all() assert (label == item).all()
assert label.device == item.device assert label.device == item.device
item = torch.tensor([2])
onehot = LabelData.label_to_onehot(item, num_classes=10)
label = LabelData.onehot_to_label(onehot)
assert label == item
assert label.device == item.device
@pytest.mark.skipif( @pytest.mark.skipif(
not torch.cuda.is_available(), reason='GPU is required!') not torch.cuda.is_available(), reason='GPU is required!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment